
# Autogenerated by mlir-tblgen; don't manually edit.

from ._ods_common import _cext as _ods_cext
from ._ods_common import (
    equally_sized_accessor as _ods_equally_sized_accessor,
    get_default_loc_context as _ods_get_default_loc_context,
    get_op_result_or_op_results as _get_op_result_or_op_results,
    get_op_result_or_value as _get_op_result_or_value,
    get_op_results_or_values as _get_op_results_or_values,
    segmented_accessor as _ods_segmented_accessor,
)
_ods_ir = _ods_cext.ir

import builtins
from typing import Sequence as _Sequence, Union as _Union


@_ods_cext.register_dialect
class _Dialect(_ods_ir.Dialect):
  """Dialect class for the ``sdy`` namespace.

  Decorated with ``_ods_cext.register_dialect``. The op classes in this
  module pass it to ``_ods_cext.register_operation``.
  """

  DIALECT_NAMESPACE = "sdy"

@_ods_cext.register_operation(_Dialect)
class DataFlowEdgeOp(_ods_ir.OpView):
  """Python op view for ``sdy.data_flow_edge``.

  One operand (``input``), one result (``result``), an optional ``sharding``
  attribute and no regions.
  """

  OPERATION_NAME = "sdy.data_flow_edge"

  _ODS_REGIONS = (0, True)

  def __init__(self, input, *, sharding=None, loc=None, ip=None):
    """Create the op.

    Args:
      input: The operand, normalised by ``_get_op_result_or_value``.
      sharding: Optional. ``None`` leaves the attribute unset. An
        ``Attribute`` is stored unchanged. Any other value goes through the
        ``Sdy_TensorSharding`` attribute builder when one is registered and
        is otherwise stored unchanged.
      loc: Optional location. ``_ods_get_default_loc_context(loc)`` supplies
        the context handed to the attribute builder.
      ip: Optional insertion point forwarded to ``build_generic``.
    """
    operand_values = [_get_op_result_or_value(input)]
    ctx = _ods_get_default_loc_context(loc)
    attrs = {}
    if sharding is not None:
      if (isinstance(sharding, _ods_ir.Attribute)
          or not _ods_ir.AttrBuilder.contains('Sdy_TensorSharding')):
        attrs["sharding"] = sharding
      else:
        to_attr = _ods_ir.AttrBuilder.get('Sdy_TensorSharding')
        attrs["sharding"] = to_attr(sharding, context=ctx)
    # No ``results=`` is passed, so build_generic is left to determine the
    # result type (presumably via the op's type inference -- confirm).
    super().__init__(
        self.build_generic(
            attributes=attrs,
            operands=operand_values,
            successors=None,
            regions=None,
            loc=loc,
            ip=ip,
        ))

  @builtins.property
  def input(self):
    """Operand 0."""
    return self.operation.operands[0]

  @builtins.property
  def sharding(self):
    """The ``sharding`` attribute, or ``None`` when it is not set."""
    attrs = self.operation.attributes
    return attrs["sharding"] if "sharding" in attrs else None

  @sharding.setter
  def sharding(self, new_value):
    """Set ``sharding``. Assigning ``None`` removes it if present."""
    attrs = self.operation.attributes
    if new_value is None:
      if "sharding" in attrs:
        del attrs["sharding"]
    else:
      attrs["sharding"] = new_value

  @sharding.deleter
  def sharding(self):
    """Remove ``sharding`` without the presence check the setter does."""
    del self.operation.attributes["sharding"]

  @builtins.property
  def result(self):
    """Result 0."""
    return self.operation.results[0]

def data_flow_edge(input, *, sharding=None, loc=None, ip=None) -> _ods_ir.Value:
  """Build an ``sdy.data_flow_edge`` op and return its result.

  All arguments are forwarded unchanged to ``DataFlowEdgeOp``. The return
  value is whatever ``_get_op_result_or_op_results`` yields for the new op.
  """
  new_op = DataFlowEdgeOp(input=input, sharding=sharding, loc=loc, ip=ip)
  return _get_op_result_or_op_results(new_op)

@_ods_cext.register_operation(_Dialect)
class IdentityOp(_ods_ir.OpView):
  """Python op view for ``sdy.identity``.

  Has one operand (``input``) and one result (``result``), with no
  attributes and no regions.
  """

  OPERATION_NAME = "sdy.identity"

  _ODS_REGIONS = (0, True)

  def __init__(self, input, *, loc=None, ip=None):
    """Create the op from ``input`` (normalised by ``_get_op_result_or_value``).

    ``loc`` and ``ip`` are forwarded to ``build_generic``.
    """
    operand_values = [_get_op_result_or_value(input)]
    # The returned context is unused because no attributes are built. The
    # call is kept so this builder behaves exactly like the generated one.
    _ods_get_default_loc_context(loc)
    # No ``results=`` is passed, so build_generic is left to determine the
    # result type (presumably via the op's type inference -- confirm).
    super().__init__(
        self.build_generic(
            attributes={},
            operands=operand_values,
            successors=None,
            regions=None,
            loc=loc,
            ip=ip,
        ))

  @builtins.property
  def input(self):
    """Operand 0."""
    return self.operation.operands[0]

  @builtins.property
  def result(self):
    """Result 0."""
    return self.operation.results[0]

def identity(input, *, loc=None, ip=None) -> _ods_ir.Value:
  """Build an ``sdy.identity`` op and return its result.

  The return value is whatever ``_get_op_result_or_op_results`` yields for
  the new ``IdentityOp``.
  """
  new_op = IdentityOp(input=input, loc=loc, ip=ip)
  return _get_op_result_or_op_results(new_op)

@_ods_cext.register_operation(_Dialect)
class PropagationBarrierOp(_ods_ir.OpView):
  """Python op view for ``sdy.propagation_barrier``.

  One operand (``input``), one result (``result``) and a mandatory
  ``allowed_direction`` attribute.
  """

  OPERATION_NAME = "sdy.propagation_barrier"

  _ODS_REGIONS = (0, True)

  def __init__(self, input, allowed_direction, *, loc=None, ip=None):
    """Create the op.

    Args:
      input: The operand, normalised by ``_get_op_result_or_value``.
      allowed_direction: An ``Attribute`` is stored unchanged. Any other
        value goes through the ``Sdy_PropagationDirection`` attribute builder
        when one is registered and is otherwise stored unchanged.
      loc: Optional location. ``_ods_get_default_loc_context(loc)`` supplies
        the context handed to the attribute builder.
      ip: Optional insertion point forwarded to ``build_generic``.
    """
    operand_values = [_get_op_result_or_value(input)]
    ctx = _ods_get_default_loc_context(loc)
    if (isinstance(allowed_direction, _ods_ir.Attribute)
        or not _ods_ir.AttrBuilder.contains('Sdy_PropagationDirection')):
      direction_attr = allowed_direction
    else:
      direction_attr = _ods_ir.AttrBuilder.get('Sdy_PropagationDirection')(
          allowed_direction, context=ctx)
    # No ``results=`` is passed, so build_generic is left to determine the
    # result type (presumably via the op's type inference -- confirm).
    super().__init__(
        self.build_generic(
            attributes={"allowed_direction": direction_attr},
            operands=operand_values,
            successors=None,
            regions=None,
            loc=loc,
            ip=ip,
        ))

  @builtins.property
  def input(self):
    """Operand 0."""
    return self.operation.operands[0]

  @builtins.property
  def allowed_direction(self):
    """The mandatory ``allowed_direction`` attribute."""
    return self.operation.attributes["allowed_direction"]

  @allowed_direction.setter
  def allowed_direction(self, new_value):
    """Replace ``allowed_direction``. ``None`` is rejected with ``ValueError``."""
    if new_value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["allowed_direction"] = new_value

  @builtins.property
  def result(self):
    """Result 0."""
    return self.operation.results[0]

def propagation_barrier(input, allowed_direction, *, loc=None, ip=None) -> _ods_ir.Value:
  """Build an ``sdy.propagation_barrier`` op and return its result.

  The return value is whatever ``_get_op_result_or_op_results`` yields for
  the new ``PropagationBarrierOp``.
  """
  new_op = PropagationBarrierOp(
      input=input, allowed_direction=allowed_direction, loc=loc, ip=ip)
  return _get_op_result_or_op_results(new_op)

@_ods_cext.register_operation(_Dialect)
class ConstantOp(_ods_ir.OpView):
  """Python op view for ``sdy.constant``.

  No operands, a mandatory ``value`` attribute and a result exposed as
  ``output``.
  """

  OPERATION_NAME = "sdy.constant"

  _ODS_REGIONS = (0, True)

  def __init__(self, value, *, loc=None, ip=None):
    """Create the op.

    Args:
      value: An ``Attribute`` is stored unchanged. Any other value goes
        through the ``ElementsAttr`` attribute builder when one is registered
        and is otherwise stored unchanged.
      loc: Optional location. ``_ods_get_default_loc_context(loc)`` supplies
        the context handed to the attribute builder.
      ip: Optional insertion point forwarded to ``build_generic``.
    """
    ctx = _ods_get_default_loc_context(loc)
    if (isinstance(value, _ods_ir.Attribute)
        or not _ods_ir.AttrBuilder.contains('ElementsAttr')):
      value_attr = value
    else:
      value_attr = _ods_ir.AttrBuilder.get('ElementsAttr')(value, context=ctx)
    # No ``results=`` is passed, so build_generic is left to determine the
    # result type (presumably from ``value`` via type inference -- confirm).
    super().__init__(
        self.build_generic(
            attributes={"value": value_attr},
            operands=[],
            successors=None,
            regions=None,
            loc=loc,
            ip=ip,
        ))

  @builtins.property
  def value(self):
    """The mandatory ``value`` attribute."""
    return self.operation.attributes["value"]

  @value.setter
  def value(self, new_value):
    """Replace ``value``. ``None`` is rejected with ``ValueError``."""
    if new_value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["value"] = new_value

  @builtins.property
  def output(self):
    """Result 0."""
    return self.operation.results[0]

def constant(value, *, loc=None, ip=None) -> _ods_ir.Value:
  """Build an ``sdy.constant`` op and return its result.

  The return value is whatever ``_get_op_result_or_op_results`` yields for
  the new ``ConstantOp``.
  """
  new_op = ConstantOp(value=value, loc=loc, ip=ip)
  return _get_op_result_or_op_results(new_op)

@_ods_cext.register_operation(_Dialect)
class ManualComputationOp(_ods_ir.OpView):
  """Python op view for ``sdy.manual_computation``.

  The op has:
    - variadic operands (``tensors``) and variadic results (``results_``);
    - three mandatory attributes: ``in_shardings``, ``out_shardings`` and
      ``manual_axes``;
    - one region, exposed as ``body``.
  """

  OPERATION_NAME = "sdy.manual_computation"

  _ODS_REGIONS = (1, True)

  def __init__(self, results_, tensors, in_shardings, out_shardings, manual_axes, *, loc=None, ip=None):
    """Create the op.

    Args:
      results_: Iterable of result types. It is copied into a list and
        passed to ``build_generic`` as ``results``.
      tensors: Operands, normalised by ``_get_op_results_or_values``.
      in_shardings: An ``Attribute``, or a value for the
        ``Sdy_TensorShardingPerValue`` attribute builder.
      out_shardings: Same handling as ``in_shardings``.
      manual_axes: An ``Attribute``, or a value for the ``Sdy_ManualAxes``
        attribute builder.
      loc: Optional location. ``_ods_get_default_loc_context(loc)`` supplies
        the context handed to the attribute builders.
      ip: Optional insertion point forwarded to ``build_generic``.
    """
    operand_values = list(_get_op_results_or_values(tensors))
    ctx = _ods_get_default_loc_context(loc)
    # Each entry is (attribute name, raw argument, AttrBuilder key). Entries
    # are processed in this order.
    attr_specs = (
        ("in_shardings", in_shardings, 'Sdy_TensorShardingPerValue'),
        ("out_shardings", out_shardings, 'Sdy_TensorShardingPerValue'),
        ("manual_axes", manual_axes, 'Sdy_ManualAxes'),
    )
    attrs = {}
    for attr_name, raw, builder_key in attr_specs:
      # Attributes, and values with no registered builder, are stored as-is.
      if (isinstance(raw, _ods_ir.Attribute)
          or not _ods_ir.AttrBuilder.contains(builder_key)):
        attrs[attr_name] = raw
      else:
        attrs[attr_name] = _ods_ir.AttrBuilder.get(builder_key)(raw, context=ctx)
    result_types = list(results_)
    super().__init__(
        self.build_generic(
            attributes=attrs,
            results=result_types,
            operands=operand_values,
            successors=None,
            regions=None,
            loc=loc,
            ip=ip,
        ))

  @builtins.property
  def tensors(self):
    """All operands of the op, as a slice of the operand list."""
    operand_list = self.operation.operands
    return operand_list[0:len(operand_list)]

  @builtins.property
  def in_shardings(self):
    """The mandatory ``in_shardings`` attribute."""
    return self.operation.attributes["in_shardings"]

  @in_shardings.setter
  def in_shardings(self, new_value):
    """Replace ``in_shardings``. ``None`` is rejected with ``ValueError``."""
    if new_value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["in_shardings"] = new_value

  @builtins.property
  def out_shardings(self):
    """The mandatory ``out_shardings`` attribute."""
    return self.operation.attributes["out_shardings"]

  @out_shardings.setter
  def out_shardings(self, new_value):
    """Replace ``out_shardings``. ``None`` is rejected with ``ValueError``."""
    if new_value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["out_shardings"] = new_value

  @builtins.property
  def manual_axes(self):
    """The mandatory ``manual_axes`` attribute."""
    return self.operation.attributes["manual_axes"]

  @manual_axes.setter
  def manual_axes(self, new_value):
    """Replace ``manual_axes``. ``None`` is rejected with ``ValueError``."""
    if new_value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["manual_axes"] = new_value

  @builtins.property
  def results_(self):
    """All results of the op, as a slice of the result list."""
    result_list = self.operation.results
    return result_list[0:len(result_list)]

  @builtins.property
  def body(self):
    """Region 0 of the op."""
    return self.regions[0]

def manual_computation(results_, tensors, in_shardings, out_shardings, manual_axes, *, loc=None, ip=None) -> _Union[_ods_ir.Value, _Sequence[_ods_ir.Value], _ods_ir.Operation]:
  """Build an ``sdy.manual_computation`` op and return its result(s).

  All arguments are forwarded unchanged to ``ManualComputationOp``. That op
  is created with one result per type in ``results_``, so the result count
  can be zero, one or more. The annotation therefore covers all three shapes
  ``_get_op_result_or_op_results`` presumably returns: a single result, the
  result list, or the operation when there are no results. TODO confirm
  against ``_ods_common``.
  """
  new_op = ManualComputationOp(
      results_=results_,
      tensors=tensors,
      in_shardings=in_shardings,
      out_shardings=out_shardings,
      manual_axes=manual_axes,
      loc=loc,
      ip=ip,
  )
  return _get_op_result_or_op_results(new_op)

@_ods_cext.register_operation(_Dialect)
class MeshOp(_ods_ir.OpView):
  """Python op view for ``sdy.mesh``.

  No operands and no results. Has mandatory ``sym_name`` and ``mesh``
  attributes.
  """

  OPERATION_NAME = "sdy.mesh"

  _ODS_REGIONS = (0, True)

  def __init__(self, sym_name, mesh, *, loc=None, ip=None):
    """Create the op.

    Args:
      sym_name: An ``Attribute``, or a value for the ``SymbolNameAttr``
        attribute builder.
      mesh: An ``Attribute``, or a value for the ``Sdy_Mesh`` attribute
        builder.
      loc: Optional location. ``_ods_get_default_loc_context(loc)`` supplies
        the context handed to the attribute builders.
      ip: Optional insertion point forwarded to ``build_generic``.
    """
    ctx = _ods_get_default_loc_context(loc)
    # Each entry is (attribute name, raw argument, AttrBuilder key). Entries
    # are processed in this order.
    attr_specs = (
        ("sym_name", sym_name, 'SymbolNameAttr'),
        ("mesh", mesh, 'Sdy_Mesh'),
    )
    attrs = {}
    for attr_name, raw, builder_key in attr_specs:
      # Attributes, and values with no registered builder, are stored as-is.
      if (isinstance(raw, _ods_ir.Attribute)
          or not _ods_ir.AttrBuilder.contains(builder_key)):
        attrs[attr_name] = raw
      else:
        attrs[attr_name] = _ods_ir.AttrBuilder.get(builder_key)(raw, context=ctx)
    # An explicit empty ``results`` list states that the op has no results.
    super().__init__(
        self.build_generic(
            attributes=attrs,
            results=[],
            operands=[],
            successors=None,
            regions=None,
            loc=loc,
            ip=ip,
        ))

  @builtins.property
  def sym_name(self):
    """The mandatory ``sym_name`` attribute."""
    return self.operation.attributes["sym_name"]

  @sym_name.setter
  def sym_name(self, new_value):
    """Replace ``sym_name``. ``None`` is rejected with ``ValueError``."""
    if new_value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["sym_name"] = new_value

  @builtins.property
  def mesh(self):
    """The mandatory ``mesh`` attribute."""
    return self.operation.attributes["mesh"]

  @mesh.setter
  def mesh(self, new_value):
    """Replace ``mesh``. ``None`` is rejected with ``ValueError``."""
    if new_value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["mesh"] = new_value

def mesh(sym_name, mesh, *, loc=None, ip=None) -> _ods_ir.Operation:
  """Build an ``sdy.mesh`` op, which has no results.

  Returns whatever ``_get_op_result_or_op_results`` yields for the new
  ``MeshOp``. That is annotated as an ``Operation`` because there are no
  results to return.
  """
  mesh_op = MeshOp(sym_name=sym_name, mesh=mesh, loc=loc, ip=ip)
  return _get_op_result_or_op_results(mesh_op)

@_ods_cext.register_operation(_Dialect)
class ReshardOp(_ods_ir.OpView):
  """Python op view for ``sdy.reshard``.

  One operand (``input``), one result (``result``) and a mandatory
  ``sharding`` attribute.
  """

  OPERATION_NAME = "sdy.reshard"

  _ODS_REGIONS = (0, True)

  def __init__(self, input, sharding, *, loc=None, ip=None):
    """Create the op.

    Args:
      input: The operand, normalised by ``_get_op_result_or_value``.
      sharding: An ``Attribute`` is stored unchanged. Any other value goes
        through the ``Sdy_TensorSharding`` attribute builder when one is
        registered and is otherwise stored unchanged.
      loc: Optional location. ``_ods_get_default_loc_context(loc)`` supplies
        the context handed to the attribute builder.
      ip: Optional insertion point forwarded to ``build_generic``.
    """
    operand_values = [_get_op_result_or_value(input)]
    ctx = _ods_get_default_loc_context(loc)
    if (isinstance(sharding, _ods_ir.Attribute)
        or not _ods_ir.AttrBuilder.contains('Sdy_TensorSharding')):
      sharding_attr = sharding
    else:
      sharding_attr = _ods_ir.AttrBuilder.get('Sdy_TensorSharding')(
          sharding, context=ctx)
    # No ``results=`` is passed, so build_generic is left to determine the
    # result type (presumably via the op's type inference -- confirm).
    super().__init__(
        self.build_generic(
            attributes={"sharding": sharding_attr},
            operands=operand_values,
            successors=None,
            regions=None,
            loc=loc,
            ip=ip,
        ))

  @builtins.property
  def input(self):
    """Operand 0."""
    return self.operation.operands[0]

  @builtins.property
  def sharding(self):
    """The mandatory ``sharding`` attribute."""
    return self.operation.attributes["sharding"]

  @sharding.setter
  def sharding(self, new_value):
    """Replace ``sharding``. ``None`` is rejected with ``ValueError``."""
    if new_value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["sharding"] = new_value

  @builtins.property
  def result(self):
    """Result 0."""
    return self.operation.results[0]

def reshard(input, sharding, *, loc=None, ip=None) -> _ods_ir.Value:
  """Build an ``sdy.reshard`` op and return its result.

  The return value is whatever ``_get_op_result_or_op_results`` yields for
  the new ``ReshardOp``.
  """
  new_op = ReshardOp(input=input, sharding=sharding, loc=loc, ip=ip)
  return _get_op_result_or_op_results(new_op)

@_ods_cext.register_operation(_Dialect)
class ReturnOp(_ods_ir.OpView):
  """Python op view for ``sdy.return``.

  Variadic operands (``results_``), no results, no attributes.
  """

  OPERATION_NAME = "sdy.return"

  _ODS_REGIONS = (0, True)

  def __init__(self, results_, *, loc=None, ip=None):
    """Create the op.

    ``results_`` becomes the operands, normalised by
    ``_get_op_results_or_values``. ``loc`` and ``ip`` are forwarded to
    ``build_generic``.
    """
    operand_values = list(_get_op_results_or_values(results_))
    # The returned context is unused because no attributes are built. The
    # call is kept so this builder behaves exactly like the generated one.
    _ods_get_default_loc_context(loc)
    # An explicit empty ``results`` list states that the op has no results.
    super().__init__(
        self.build_generic(
            attributes={},
            results=[],
            operands=operand_values,
            successors=None,
            regions=None,
            loc=loc,
            ip=ip,
        ))

  @builtins.property
  def results_(self):
    """All operands of the op, as a slice of the operand list."""
    operand_list = self.operation.operands
    return operand_list[0:len(operand_list)]

def return_(results_, *, loc=None, ip=None) -> _ods_ir.Operation:
  """Build an ``sdy.return`` op, which has no results.

  Returns whatever ``_get_op_result_or_op_results`` yields for the new
  ``ReturnOp``. That is annotated as an ``Operation``.
  """
  return_op = ReturnOp(results_=results_, loc=loc, ip=ip)
  return _get_op_result_or_op_results(return_op)

@_ods_cext.register_operation(_Dialect)
class ShardingConstraintOp(_ods_ir.OpView):
  """Python op view for ``sdy.sharding_constraint``.

  One operand (``input``), one result (``result``) and a mandatory
  ``sharding`` attribute.
  """

  OPERATION_NAME = "sdy.sharding_constraint"

  _ODS_REGIONS = (0, True)

  def __init__(self, input, sharding, *, loc=None, ip=None):
    """Create the op.

    Args:
      input: The operand, normalised by ``_get_op_result_or_value``.
      sharding: An ``Attribute`` is stored unchanged. Any other value goes
        through the ``Sdy_TensorSharding`` attribute builder when one is
        registered and is otherwise stored unchanged.
      loc: Optional location. ``_ods_get_default_loc_context(loc)`` supplies
        the context handed to the attribute builder.
      ip: Optional insertion point forwarded to ``build_generic``.
    """
    operand_values = [_get_op_result_or_value(input)]
    ctx = _ods_get_default_loc_context(loc)
    if (isinstance(sharding, _ods_ir.Attribute)
        or not _ods_ir.AttrBuilder.contains('Sdy_TensorSharding')):
      sharding_attr = sharding
    else:
      sharding_attr = _ods_ir.AttrBuilder.get('Sdy_TensorSharding')(
          sharding, context=ctx)
    # No ``results=`` is passed, so build_generic is left to determine the
    # result type (presumably via the op's type inference -- confirm).
    super().__init__(
        self.build_generic(
            attributes={"sharding": sharding_attr},
            operands=operand_values,
            successors=None,
            regions=None,
            loc=loc,
            ip=ip,
        ))

  @builtins.property
  def input(self):
    """Operand 0."""
    return self.operation.operands[0]

  @builtins.property
  def sharding(self):
    """The mandatory ``sharding`` attribute."""
    return self.operation.attributes["sharding"]

  @sharding.setter
  def sharding(self, new_value):
    """Replace ``sharding``. ``None`` is rejected with ``ValueError``."""
    if new_value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["sharding"] = new_value

  @builtins.property
  def result(self):
    """Result 0."""
    return self.operation.results[0]

def sharding_constraint(input, sharding, *, loc=None, ip=None) -> _ods_ir.Value:
  """Build an ``sdy.sharding_constraint`` op and return its result.

  The return value is whatever ``_get_op_result_or_op_results`` yields for
  the new ``ShardingConstraintOp``.
  """
  new_op = ShardingConstraintOp(input=input, sharding=sharding, loc=loc, ip=ip)
  return _get_op_result_or_op_results(new_op)

@_ods_cext.register_operation(_Dialect)
class ShardingGroupOp(_ods_ir.OpView):
  """Python op view for ``sdy.sharding_group``.

  One operand (``input``), no results and a mandatory ``group_id``
  attribute.
  """

  OPERATION_NAME = "sdy.sharding_group"

  _ODS_REGIONS = (0, True)

  def __init__(self, input, group_id, *, loc=None, ip=None):
    """Create the op.

    Args:
      input: The operand, normalised by ``_get_op_result_or_value``.
      group_id: An ``Attribute`` is stored unchanged. Any other value goes
        through the ``I64Attr`` attribute builder when one is registered and
        is otherwise stored unchanged.
      loc: Optional location. ``_ods_get_default_loc_context(loc)`` supplies
        the context handed to the attribute builder.
      ip: Optional insertion point forwarded to ``build_generic``.
    """
    operand_values = [_get_op_result_or_value(input)]
    ctx = _ods_get_default_loc_context(loc)
    if (isinstance(group_id, _ods_ir.Attribute)
        or not _ods_ir.AttrBuilder.contains('I64Attr')):
      group_attr = group_id
    else:
      group_attr = _ods_ir.AttrBuilder.get('I64Attr')(group_id, context=ctx)
    # An explicit empty ``results`` list states that the op has no results.
    super().__init__(
        self.build_generic(
            attributes={"group_id": group_attr},
            results=[],
            operands=operand_values,
            successors=None,
            regions=None,
            loc=loc,
            ip=ip,
        ))

  @builtins.property
  def input(self):
    """Operand 0."""
    return self.operation.operands[0]

  @builtins.property
  def group_id(self):
    """The mandatory ``group_id`` attribute."""
    return self.operation.attributes["group_id"]

  @group_id.setter
  def group_id(self, new_value):
    """Replace ``group_id``. ``None`` is rejected with ``ValueError``."""
    if new_value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["group_id"] = new_value

def sharding_group(input, group_id, *, loc=None, ip=None) -> _ods_ir.Operation:
  """Build an ``sdy.sharding_group`` op, which has no results.

  Returns whatever ``_get_op_result_or_op_results`` yields for the new
  ``ShardingGroupOp``. That is annotated as an ``Operation``.
  """
  group_op = ShardingGroupOp(input=input, group_id=group_id, loc=loc, ip=ip)
  return _get_op_result_or_op_results(group_op)
