Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .parser import parse_file, parse_path, parse_string, Parser
from . import astnodes
from . import dialects
from .visitors import NodeVisitor, NodeTransformer
10 changes: 5 additions & 5 deletions mlir/astnodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,11 +485,15 @@ def dump(self, indent: int = 0) -> str:
return self.value.dump(indent) + (
(':' + dump_or_value(self.count, indent)) if self.count else '')

class Op(Node):
pass



@dataclass
class Operation(Node):
result_list: List[OpResult]
op: "Op"
op: Node
location: Optional["Location"] = None

def dump(self, indent: int = 0) -> str:
Expand All @@ -503,10 +507,6 @@ def dump(self, indent: int = 0) -> str:
return result


class Op(Node):
pass


@dataclass
class GenericOperation(Op):
name: str
Expand Down
13 changes: 7 additions & 6 deletions mlir/builder/builder.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
""" MLIR IR Builder."""

import mlir.astnodes as mast
import mlir.dialects.standard as std
import mlir.dialects.arith as arith
import mlir.dialects.memref as memref
import mlir.dialects.affine as affine
import mlir.dialects.func as func
from typing import Optional, Tuple, Union, List, Any
Expand Down Expand Up @@ -436,28 +437,28 @@ def goto_after(self, query: MatchExpressionBase,

def addf(self, op_a: mast.SsaId, op_b: mast.SsaId, type: mast.Type,
name: Optional[str] = None):
op = std.AddfOperation(match=0, operand_a=op_a, operand_b=op_b, type=type)
op = arith.AddFOperation(match=0, operand_a=op_a, operand_b=op_b, type=type)
return self._insert_op_in_block([name], op)

def mulf(self, op_a: mast.SsaId, op_b: mast.SsaId, type: mast.Type,
name: Optional[str] = None):
op = std.MulfOperation(match=0, operand_a=op_a, operand_b=op_b, type=type)
op = arith.MulFOperation(match=0, operand_a=op_a, operand_b=op_b, type=type)
return self._insert_op_in_block([name], op)

def dim(self, memref_or_tensor: mast.SsaId, index: mast.SsaId,
memref_type: Union[mast.MemRefType, mast.TensorType],
name: Optional[str] = None):
op = std.DimOperation(match=0, operand=memref_or_tensor, index=index,
op = memref.DimOperation(match=0, operand=memref_or_tensor, index=index,
type=memref_type)
return self._insert_op_in_block([name], op)

def index_constant(self, value: int, name: Optional[str] = None):
op = std.ConstantOperation(match=0, value=value, type=mast.IndexType())
op = arith.ConstantOperation(match=0, value=value, type=mast.IndexType())
return self._insert_op_in_block([name], op)

def float_constant(self, value: float, type: mast.FloatType,
name: Optional[str] = None):
op = std.ConstantOperation(match=0, value=value, type=type)
op = arith.ConstantOperation(match=0, value=value, type=type)
return self._insert_op_in_block([name], op)

# }}}
Expand Down
8 changes: 6 additions & 2 deletions mlir/dialects/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from .affine import affine as affine_dialect
from .standard import standard as std_dialect
from .cf import cf as cf_dialect
from .math import math as math_dialect
from .tensor import tensor as tensor_dialect
from .arith import arith as arith_dialect
from .scf import scf as scf_dialect
from .linalg import linalg
from .func import func as func_dialect
from .memref import memref as memref_dialect


STANDARD_DIALECTS = [affine_dialect, std_dialect, scf_dialect, linalg, func_dialect]
STANDARD_DIALECTS = [affine_dialect, cf_dialect, math_dialect, tensor_dialect, arith_dialect, scf_dialect, linalg, func_dialect, memref_dialect]
121 changes: 121 additions & 0 deletions mlir/dialects/arith.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
""" Implementation of the arith (Arithmetic) dialect. """

import inspect
import sys
from mlir.dialect import Dialect, DialectOp, is_op, UnaryOperation, BinaryOperation
import mlir.astnodes as mast
from dataclasses import dataclass
from typing import Optional, List, Tuple, Union

Literal = Union[mast.StringLiteral, float, int, bool]
SsaUse = Union[mast.SsaId, Literal]


# Unary Operations

class BitcastOperation(UnaryOperation): _opname_ = 'arith.bitcast'
class ExtFOperation(UnaryOperation): _opname_ = 'arith.extf'
class ExtSIOperation(UnaryOperation): _opname_ = 'arith.extsi'
class ExtUIOperation(UnaryOperation): _opname_ = 'arith.extui'
class FPToSIOperation(UnaryOperation): _opname_ = 'arith.fptosi'
class FPToUIOperation(UnaryOperation): _opname_ = 'arith.fptoui'
class NegFOperation(UnaryOperation): _opname_ = 'arith.negf'
class SIToFPOperation(UnaryOperation): _opname_ = 'arith.sitofp'
class UIToFPOperation(UnaryOperation): _opname_ = 'arith.uitofp'

# Arithmetic Operations
class AddFOperation(BinaryOperation): _opname_ = 'arith.addf'
class AddIOperation(BinaryOperation): _opname_ = 'arith.addi'
class AndIOperation(BinaryOperation): _opname_ = 'arith.andi'
class CeilDivSIOperation(BinaryOperation): _opname_ = 'arith.ceildivsi'
class CeilDivUIOperation(BinaryOperation): _opname_ = 'arith.ceildivui'
class DivFOperation(BinaryOperation): _opname_ = 'arith.divf'
class DivSIOperation(BinaryOperation): _opname_ = 'arith.divsi'
class DivUIOperation(BinaryOperation): _opname_ = 'arith.divui'
class FloorDivSIOperation(BinaryOperation): _opname_ = 'arith.floordivsi'
class MaximumFOperation(BinaryOperation): _opname_ = 'arith.maximumf'
class MaxNumFOperation(BinaryOperation): _opname_ = 'arith.maxnumf'
class MaxSIOperation(BinaryOperation): _opname_ = 'arith.maxsi'
class MaxUIOperation(BinaryOperation): _opname_ = 'arith.maxui'
class MinimumFOperation(BinaryOperation): _opname_ = 'arith.minimumf'
class MinNumFOperation(BinaryOperation): _opname_ = 'arith.minnumf'
class MinSIOperation(BinaryOperation): _opname_ = 'arith.minsi'
class MinUIOperation(BinaryOperation): _opname_ = 'arith.minui'
class MulFOperation(BinaryOperation): _opname_ = 'arith.mulf'
class MulIOperation(BinaryOperation): _opname_ = 'arith.muli'
class MulSIExtendedOp(BinaryOperation): _opname_ = 'arith.mulsi_extended'
class MulUIExtendedOp(BinaryOperation): _opname_ = 'arith.mului_extended'
class OrIOperation(BinaryOperation): _opname_ = 'arith.ori'
class RemFOperation(BinaryOperation): _opname_ = 'arith.remf'
class RemSIOperation(BinaryOperation): _opname_ = 'arith.remsi'
class RemUIOperation(BinaryOperation): _opname_ = 'arith.remui'
class ShLIOperation(BinaryOperation): _opname_ = 'arith.shli'
class ShRSIOperation(BinaryOperation): _opname_ = 'arith.shrsi'
class ShRUIOperation(BinaryOperation): _opname_ = 'arith.shrui'
class SubIOperation(BinaryOperation): _opname_ = 'arith.subi'
class SubFOperation(BinaryOperation): _opname_ = 'arith.subf'
class TruncFOperation(BinaryOperation): _opname_ = 'arith.truncf'
class TruncIOperation(BinaryOperation): _opname_ = 'arith.trunci'
class XorIOperation(BinaryOperation): _opname_ = 'arith.xori'


@dataclass
class AddUIExtendedOperation(DialectOp):
lhs_operand: mast.SsaId
rhs_operand: mast.SsaId
sum_type: mast.Type
ovf_type = mast.Type
_syntax_ = 'arith.addui_extended {lhs_operand.ssa_id} , {rhs_operand.ssa_id} : {sum_type.type} , {ovf_type.type}'


@dataclass
class CmpiOperation(DialectOp):
comptype: str
operand_a: mast.SsaId
operand_b: mast.SsaId
type: mast.Type
_syntax_ = 'arith.cmpi {comptype.string_literal} , {operand_a.ssa_id} , {operand_b.ssa_id} : {type.type}'


@dataclass
class CmpfOperation(DialectOp):
comptype: str
operand_a: mast.SsaId
operand_b: mast.SsaId
type: mast.Type
_syntax_ = 'arith.cmpf {comptype.string_literal} , {operand_a.ssa_id} , {operand_b.ssa_id} : {type.type}'


@dataclass
class ConstantOperation(DialectOp):
value: Literal
type: mast.Type
_syntax_ = ['arith.constant {value.constant_literal} : {type.type}', 'arith.constant {value.constant_literal}']



@dataclass
class IndexCastOperation(DialectOp):
arg: SsaUse
src_type: mast.Type
dst_type: mast.Type
_syntax_ = 'arith.index_cast {arg.ssa_use} : {src_type.type} to {dst_type.type}'

@dataclass
class IndexCastUIOperation(DialectOp):
arg: SsaUse
src_type: mast.Type
dst_type: mast.Type
_syntax_ = 'arith.index_castui {arg.ssa_use} : {src_type.type} to {dst_type.type}'

@dataclass
class SelectOperation(DialectOp):
cond: SsaUse
arg_true: SsaUse
arg_false: SsaUse
_syntax_ = 'arith.select {cond.ssa_use} , {arg_true.ssa_use} , {arg_false.ssa_use} : {type.type}'


# Inspect current module to get all classes defined above
arith = Dialect('arith', ops=[m[1] for m in inspect.getmembers(
sys.modules[__name__], lambda obj: is_op(obj, __name__))])
32 changes: 32 additions & 0 deletions mlir/dialects/cf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
""" Implementation of the CF (Control Flow) dialect. """

import inspect
import sys
from mlir.dialect import Dialect, DialectOp, is_op, UnaryOperation
import mlir.astnodes as mast
from dataclasses import dataclass
from typing import Optional, List, Tuple, Union

Literal = Union[mast.StringLiteral, float, int, bool]
SsaUse = Union[mast.SsaId, Literal]


@dataclass
class BrOperation(DialectOp):
block_id: mast.BlockId
args: Optional[List[Tuple[mast.SsaId, mast.Type]]] = None
_syntax_ = ['cf.br {block.block_id}',
'cf.br {block.block_id} {args.block_arg_list}']


@dataclass
class CondBrOperation(DialectOp):
cond: SsaUse
block_true: mast.BlockId
block_false: mast.BlockId
_syntax_ = ['cf.cond_br {cond.ssa_use} , {block_true.block_id} , {block_false.block_id}']


# Inspect current module to get all classes defined above
cf = Dialect('cf', ops=[m[1] for m in inspect.getmembers(
sys.modules[__name__], lambda obj: is_op(obj, __name__))])
14 changes: 14 additions & 0 deletions mlir/dialects/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,20 @@ class ConstantOperation(DialectOp):
_syntax_ = ['func.constant {value.symbol_ref_id} : {type.type}']

# Note: The 'func.func' operation is defined as 'function' in mlir.lark.
# not anymore lmfaooooo
@dataclass
class FuncOperation(DialectOp):
name: mast.SymbolRefId
args: Optional[List[mast.NamedArgument]]
result_list: Optional[List[mast.OpResult]] | mast.OpResult
func_mod_attrs: Optional[mast.AttributeDict]
body: Optional[mast.Region]
trail: Optional[mast.Location] = None

_syntax_ = [
'func.func {name.symbol_ref_id} ( {args.optional_arg_list} ) {result_list.optional_fn_result_list} {func_mod_attrs.optional_func_mod_attrs} {body.optional_fn_body}',
'func.func {name.symbol_ref_id} ( {args.optional_arg_list} ) {result_list.optional_fn_result_list} {func_mod_attrs.optional_func_mod_attrs} {body.optional_fn_body} (loc ({trail.optional_location}))']


@dataclass
class ReturnOperation(DialectOp):
Expand Down
20 changes: 20 additions & 0 deletions mlir/dialects/math.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
""" Implementation of the math (Mathematics) dialect. """

import inspect
import sys
from mlir.dialect import Dialect, DialectOp, is_op, UnaryOperation
import mlir.astnodes as mast
from dataclasses import dataclass
from typing import Optional, List, Tuple


# Unary Operations
class AbsfOperation(UnaryOperation): _opname_ = 'math.absf'
class CosOperation(UnaryOperation): _opname_ = 'math.cos'
class ExpOperation(UnaryOperation): _opname_ = 'math.exp'
class TanhOperation(UnaryOperation): _opname_ = 'math.tanh'
class CopysignOperation(UnaryOperation): _opname_ = 'math.copysign'

# Inspect current module to get all classes defined above
math = Dialect('math', ops=[m[1] for m in inspect.getmembers(
sys.modules[__name__], lambda obj: is_op(obj, __name__))])
Loading