
# Autogenerated by mlir-tblgen; don't manually edit.

from enum import IntEnum, auto, IntFlag
from ._ods_common import _cext as _ods_cext
from ..ir import register_attribute_builder
# Shorthand for the native ``ir`` module exposed by the C extension; the
# attribute builders below construct their results through ``_ods_ir``.
_ods_ir = _ods_cext.ir

class LoadCacheModifierKind(IntEnum):
    """NVVM load cache modifier kind.

    ``str()`` of a member returns its lowercase textual spelling (e.g. ``"ca"``).
    """

    CA = 0
    CG = 1
    CS = 2
    LU = 3
    CV = 4

    def __str__(self):
        spelling = {
            LoadCacheModifierKind.CA: "ca",
            LoadCacheModifierKind.CG: "cg",
            LoadCacheModifierKind.CS: "cs",
            LoadCacheModifierKind.LU: "lu",
            LoadCacheModifierKind.CV: "cv",
        }.get(self)
        if spelling is None:
            raise ValueError("Unknown LoadCacheModifierKind enum entry.")
        return spelling



@register_attribute_builder("LoadCacheModifierKind")
def _loadcachemodifierkind(x, context):
    """Encode ``x`` as a signless 32-bit ``IntegerAttr`` holding ``int(x)``."""
    i32_type = _ods_ir.IntegerType.get_signless(32, context=context)
    return _ods_ir.IntegerAttr.get(i32_type, int(x))

class MMAB1Op(IntEnum):
    """MMA binary operations.

    ``str()`` of a member returns its textual spelling (e.g. ``"xor_popc"``).
    """

    none = 0
    xor_popc = 1
    and_popc = 2

    def __str__(self):
        spelling = {
            MMAB1Op.none: "none",
            MMAB1Op.xor_popc: "xor_popc",
            MMAB1Op.and_popc: "and_popc",
        }.get(self)
        if spelling is None:
            raise ValueError("Unknown MMAB1Op enum entry.")
        return spelling



@register_attribute_builder("MMAB1Op")
def _mmab1op(x, context):
    """Encode ``x`` as a signless 32-bit ``IntegerAttr`` holding ``int(x)``."""
    i32_type = _ods_ir.IntegerType.get_signless(32, context=context)
    return _ods_ir.IntegerAttr.get(i32_type, int(x))

class MMAFrag(IntEnum):
    """NVVM MMA frag type.

    ``str()`` of a member returns its textual spelling (``"a"``, ``"b"`` or ``"c"``).
    """

    a = 0
    b = 1
    c = 2

    def __str__(self):
        spelling = {
            MMAFrag.a: "a",
            MMAFrag.b: "b",
            MMAFrag.c: "c",
        }.get(self)
        if spelling is None:
            raise ValueError("Unknown MMAFrag enum entry.")
        return spelling



@register_attribute_builder("MMAFrag")
def _mmafrag(x, context):
    """Encode ``x`` as a signless 32-bit ``IntegerAttr`` holding ``int(x)``."""
    i32_type = _ods_ir.IntegerType.get_signless(32, context=context)
    return _ods_ir.IntegerAttr.get(i32_type, int(x))

class MMAIntOverflow(IntEnum):
    """MMA overflow options.

    ``str()`` of a member returns its textual spelling (``"satfinite"`` or
    ``"wrapped"``).
    """

    satfinite = 1
    wrapped = 0

    def __str__(self):
        spelling = {
            MMAIntOverflow.satfinite: "satfinite",
            MMAIntOverflow.wrapped: "wrapped",
        }.get(self)
        if spelling is None:
            raise ValueError("Unknown MMAIntOverflow enum entry.")
        return spelling



@register_attribute_builder("MMAIntOverflow")
def _mmaintoverflow(x, context):
    """Encode ``x`` as a signless 32-bit ``IntegerAttr`` holding ``int(x)``."""
    i32_type = _ods_ir.IntegerType.get_signless(32, context=context)
    return _ods_ir.IntegerAttr.get(i32_type, int(x))

class MMALayout(IntEnum):
    """NVVM MMA layout.

    ``str()`` of a member returns its textual spelling (``"row"`` or ``"col"``).
    """

    row = 0
    col = 1

    def __str__(self):
        spelling = {
            MMALayout.row: "row",
            MMALayout.col: "col",
        }.get(self)
        if spelling is None:
            raise ValueError("Unknown MMALayout enum entry.")
        return spelling



@register_attribute_builder("MMALayout")
def _mmalayout(x, context):
    """Encode ``x`` as a signless 32-bit ``IntegerAttr`` holding ``int(x)``."""
    i32_type = _ods_ir.IntegerType.get_signless(32, context=context)
    return _ods_ir.IntegerAttr.get(i32_type, int(x))

class MMATypes(IntEnum):
    """NVVM MMA types.

    ``str()`` of a member returns its textual spelling (e.g. ``"bf16"``).
    Note the integer values are not in declaration order.
    """

    f16 = 0
    f32 = 1
    tf32 = 2
    bf16 = 9
    s8 = 4
    u8 = 3
    s32 = 5
    s4 = 8
    u4 = 7
    b1 = 6
    f64 = 10

    def __str__(self):
        spelling = {
            MMATypes.f16: "f16",
            MMATypes.f32: "f32",
            MMATypes.tf32: "tf32",
            MMATypes.bf16: "bf16",
            MMATypes.s8: "s8",
            MMATypes.u8: "u8",
            MMATypes.s32: "s32",
            MMATypes.s4: "s4",
            MMATypes.u4: "u4",
            MMATypes.b1: "b1",
            MMATypes.f64: "f64",
        }.get(self)
        if spelling is None:
            raise ValueError("Unknown MMATypes enum entry.")
        return spelling



@register_attribute_builder("MMATypes")
def _mmatypes(x, context):
    """Encode ``x`` as a signless 32-bit ``IntegerAttr`` holding ``int(x)``."""
    i32_type = _ods_ir.IntegerType.get_signless(32, context=context)
    return _ods_ir.IntegerAttr.get(i32_type, int(x))

class ProxyKind(IntEnum):
    """Proxy kind.

    ``str()`` of a member returns its textual spelling. Python-safe member
    names differ from the spelling (``async_`` -> ``"async"``,
    ``async_global`` -> ``"async.global"``).
    """

    alias = 0
    async_ = 1
    async_global = 2
    async_shared = 3

    def __str__(self):
        spelling = {
            ProxyKind.alias: "alias",
            ProxyKind.async_: "async",
            ProxyKind.async_global: "async.global",
            ProxyKind.async_shared: "async.shared",
        }.get(self)
        if spelling is None:
            raise ValueError("Unknown ProxyKind enum entry.")
        return spelling



@register_attribute_builder("ProxyKind")
def _proxykind(x, context):
    """Encode ``x`` as a signless 32-bit ``IntegerAttr`` holding ``int(x)``."""
    i32_type = _ods_ir.IntegerType.get_signless(32, context=context)
    return _ods_ir.IntegerAttr.get(i32_type, int(x))

class ReduxKind(IntEnum):
    """NVVM redux kind.

    Values start at 1. ``str()`` of a member returns its lowercase textual
    spelling (e.g. ``ReduxKind.UMAX`` -> ``"umax"``).
    """

    ADD = 1
    AND = 2
    MAX = 3
    MIN = 4
    OR = 5
    UMAX = 6
    UMIN = 7
    XOR = 8

    def __str__(self):
        spelling = {
            ReduxKind.ADD: "add",
            ReduxKind.AND: "and",
            ReduxKind.MAX: "max",
            ReduxKind.MIN: "min",
            ReduxKind.OR: "or",
            ReduxKind.UMAX: "umax",
            ReduxKind.UMIN: "umin",
            ReduxKind.XOR: "xor",
        }.get(self)
        if spelling is None:
            raise ValueError("Unknown ReduxKind enum entry.")
        return spelling



@register_attribute_builder("ReduxKind")
def _reduxkind(x, context):
    """Encode ``x`` as a signless 32-bit ``IntegerAttr`` holding ``int(x)``."""
    i32_type = _ods_ir.IntegerType.get_signless(32, context=context)
    return _ods_ir.IntegerAttr.get(i32_type, int(x))

class SetMaxRegisterAction(IntEnum):
    """NVVM set max register action.

    ``str()`` of a member returns its textual spelling (``"decrease"`` or
    ``"increase"``).
    """

    decrease = 1
    increase = 0

    def __str__(self):
        spelling = {
            SetMaxRegisterAction.decrease: "decrease",
            SetMaxRegisterAction.increase: "increase",
        }.get(self)
        if spelling is None:
            raise ValueError("Unknown SetMaxRegisterAction enum entry.")
        return spelling



@register_attribute_builder("SetMaxRegisterAction")
def _setmaxregisteraction(x, context):
    """Encode ``x`` as a signless 32-bit ``IntegerAttr`` holding ``int(x)``."""
    i32_type = _ods_ir.IntegerType.get_signless(32, context=context)
    return _ods_ir.IntegerAttr.get(i32_type, int(x))

class SharedSpace(IntEnum):
    """Shared memory space.

    ``str()`` drops the ``shared_`` prefix of the member name:
    ``shared_cta`` -> ``"cta"``, ``shared_cluster`` -> ``"cluster"``.
    """

    shared_cta = 0
    shared_cluster = 1

    def __str__(self):
        spelling = {
            SharedSpace.shared_cta: "cta",
            SharedSpace.shared_cluster: "cluster",
        }.get(self)
        if spelling is None:
            raise ValueError("Unknown SharedSpace enum entry.")
        return spelling



@register_attribute_builder("SharedSpace")
def _sharedspace(x, context):
    """Encode ``x`` as a signless 32-bit ``IntegerAttr`` holding ``int(x)``."""
    # NOTE(review): confirm 32 bits matches the storage width of the
    # SharedSpace enum in the dialect's ODS definition.
    i32_type = _ods_ir.IntegerType.get_signless(32, context=context)
    return _ods_ir.IntegerAttr.get(i32_type, int(x))

class ShflKind(IntEnum):
    """NVVM shuffle kind.

    ``str()`` of a member returns its textual spelling (e.g. ``"bfly"``).
    """

    bfly = 0
    up = 1
    down = 2
    idx = 3

    def __str__(self):
        spelling = {
            ShflKind.bfly: "bfly",
            ShflKind.up: "up",
            ShflKind.down: "down",
            ShflKind.idx: "idx",
        }.get(self)
        if spelling is None:
            raise ValueError("Unknown ShflKind enum entry.")
        return spelling



@register_attribute_builder("ShflKind")
def _shflkind(x, context):
    """Encode ``x`` as a signless 32-bit ``IntegerAttr`` holding ``int(x)``."""
    i32_type = _ods_ir.IntegerType.get_signless(32, context=context)
    return _ods_ir.IntegerAttr.get(i32_type, int(x))

class WGMMAScaleIn(IntEnum):
    """WGMMA input scale factor.

    Each member's integer value is the signed scale itself: ``one`` is ``1``
    and ``neg`` is ``-1``. ``str()`` of a member returns its textual spelling
    (``"one"`` or ``"neg"``).
    """

    one = 1
    # Must be an explicit -1. ``auto()`` would produce ``one + 1 == 2``, which
    # is not the value the NVVM dialect's WGMMAScaleIn definition uses for
    # ``neg``. The integer attribute builder for this enum would then encode
    # the wrong value.
    neg = -1

    def __str__(self):
        if self is WGMMAScaleIn.one:
            return "one"
        if self is WGMMAScaleIn.neg:
            return "neg"
        raise ValueError("Unknown WGMMAScaleIn enum entry.")



@register_attribute_builder("WGMMAScaleIn")
def _wgmmascalein(x, context):
    """Encode ``x`` as a signless 32-bit ``IntegerAttr`` holding ``int(x)``."""
    i32_type = _ods_ir.IntegerType.get_signless(32, context=context)
    return _ods_ir.IntegerAttr.get(i32_type, int(x))

class WGMMAScaleOut(IntEnum):
    """WGMMA input predicate.

    Two-valued selector (``zero`` = 0, ``one`` = 1). ``str()`` of a member
    returns its textual spelling.
    """

    zero = 0
    one = 1

    def __str__(self):
        spelling = {
            WGMMAScaleOut.zero: "zero",
            WGMMAScaleOut.one: "one",
        }.get(self)
        if spelling is None:
            raise ValueError("Unknown WGMMAScaleOut enum entry.")
        return spelling



@register_attribute_builder("WGMMAScaleOut")
def _wgmmascaleout(x, context):
    """Encode ``x`` as a signless 32-bit ``IntegerAttr`` holding ``int(x)``."""
    i32_type = _ods_ir.IntegerType.get_signless(32, context=context)
    return _ods_ir.IntegerAttr.get(i32_type, int(x))

class WGMMATypes(IntEnum):
    """NVVM WGMMA types.

    ``str()`` of a member returns its textual spelling (e.g. ``"e4m3"``).
    """

    f16 = 0
    tf32 = 1
    u8 = 2
    s8 = 3
    b1 = 4
    bf16 = 5
    e4m3 = 6
    e5m2 = 7
    f32 = 8
    s32 = 9

    def __str__(self):
        spelling = {
            WGMMATypes.f16: "f16",
            WGMMATypes.tf32: "tf32",
            WGMMATypes.u8: "u8",
            WGMMATypes.s8: "s8",
            WGMMATypes.b1: "b1",
            WGMMATypes.bf16: "bf16",
            WGMMATypes.e4m3: "e4m3",
            WGMMATypes.e5m2: "e5m2",
            WGMMATypes.f32: "f32",
            WGMMATypes.s32: "s32",
        }.get(self)
        if spelling is None:
            raise ValueError("Unknown WGMMATypes enum entry.")
        return spelling



@register_attribute_builder("WGMMATypes")
def _wgmmatypes(x, context):
    """Encode ``x`` as a signless 32-bit ``IntegerAttr`` holding ``int(x)``."""
    i32_type = _ods_ir.IntegerType.get_signless(32, context=context)
    return _ods_ir.IntegerAttr.get(i32_type, int(x))

@register_attribute_builder("LoadCacheModifierAttr")
def _loadcachemodifierattr(x, context):
    """Parse ``#nvvm<load_cache_modifier ...>`` with ``str(x)`` as the value."""
    asm = f'#nvvm<load_cache_modifier {str(x)}>'
    return _ods_ir.Attribute.parse(asm, context=context)

@register_attribute_builder("MMAB1OpAttr")
def _mmab1opattr(x, context):
    """Parse ``#nvvm.mma_b1op<...>`` with ``str(x)`` as the value."""
    asm = f'#nvvm.mma_b1op<{str(x)}>'
    return _ods_ir.Attribute.parse(asm, context=context)

@register_attribute_builder("MMAFragAttr")
def _mmafragattr(x, context):
    """Parse ``#nvvm.mma_frag<...>`` with ``str(x)`` as the value."""
    asm = f'#nvvm.mma_frag<{str(x)}>'
    return _ods_ir.Attribute.parse(asm, context=context)

@register_attribute_builder("MMAIntOverflowAttr")
def _mmaintoverflowattr(x, context):
    """Parse ``#nvvm.mma_int_overflow<...>`` with ``str(x)`` as the value."""
    asm = f'#nvvm.mma_int_overflow<{str(x)}>'
    return _ods_ir.Attribute.parse(asm, context=context)

@register_attribute_builder("MMALayoutAttr")
def _mmalayoutattr(x, context):
    """Parse ``#nvvm.mma_layout<...>`` with ``str(x)`` as the value."""
    asm = f'#nvvm.mma_layout<{str(x)}>'
    return _ods_ir.Attribute.parse(asm, context=context)

@register_attribute_builder("MMATypesAttr")
def _mmatypesattr(x, context):
    """Parse ``#nvvm.mma_type<...>`` with ``str(x)`` as the value."""
    asm = f'#nvvm.mma_type<{str(x)}>'
    return _ods_ir.Attribute.parse(asm, context=context)

@register_attribute_builder("ProxyKindAttr")
def _proxykindattr(x, context):
    """Parse ``#nvvm.proxy_kind<...>`` with ``str(x)`` as the value."""
    asm = f'#nvvm.proxy_kind<{str(x)}>'
    return _ods_ir.Attribute.parse(asm, context=context)

@register_attribute_builder("ReduxKindAttr")
def _reduxkindattr(x, context):
    """Parse ``#nvvm<redux_kind ...>`` with ``str(x)`` as the value."""
    asm = f'#nvvm<redux_kind {str(x)}>'
    return _ods_ir.Attribute.parse(asm, context=context)

@register_attribute_builder("SetMaxRegisterActionAttr")
def _setmaxregisteractionattr(x, context):
    """Parse ``#nvvm<action ...>`` with ``str(x)`` as the value."""
    asm = f'#nvvm<action {str(x)}>'
    return _ods_ir.Attribute.parse(asm, context=context)

@register_attribute_builder("SharedSpaceAttr")
def _sharedspaceattr(x, context):
    """Parse ``#nvvm.shared_space<...>`` with ``str(x)`` as the value."""
    asm = f'#nvvm.shared_space<{str(x)}>'
    return _ods_ir.Attribute.parse(asm, context=context)

@register_attribute_builder("ShflKindAttr")
def _shflkindattr(x, context):
    """Parse ``#nvvm<shfl_kind ...>`` with ``str(x)`` as the value."""
    asm = f'#nvvm<shfl_kind {str(x)}>'
    return _ods_ir.Attribute.parse(asm, context=context)

@register_attribute_builder("WGMMAScaleInAttr")
def _wgmmascaleinattr(x, context):
    """Parse ``#nvvm.wgmma_scale_in<...>`` with ``str(x)`` as the value."""
    asm = f'#nvvm.wgmma_scale_in<{str(x)}>'
    return _ods_ir.Attribute.parse(asm, context=context)

@register_attribute_builder("WGMMAScaleOutAttr")
def _wgmmascaleoutattr(x, context):
    """Parse ``#nvvm.wgmma_scale_out<...>`` with ``str(x)`` as the value."""
    asm = f'#nvvm.wgmma_scale_out<{str(x)}>'
    return _ods_ir.Attribute.parse(asm, context=context)

@register_attribute_builder("WGMMATypesAttr")
def _wgmmatypesattr(x, context):
    """Parse ``#nvvm.wgmma_type<...>`` with ``str(x)`` as the value."""
    asm = f'#nvvm.wgmma_type<{str(x)}>'
    return _ods_ir.Attribute.parse(asm, context=context)

