Skip to content

Commit

Permalink
[PYTHON] Various minor codegen fixes (triton-lang#95)
Browse files Browse the repository at this point in the history
  • Loading branch information
ptillet committed Jul 27, 2021
1 parent 2b75158 commit 4290be1
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 22 deletions.
5 changes: 4 additions & 1 deletion lib/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,12 @@ ir::value *module::get_value_recursive(const std::string& name, ir::basic_block
ir::phi_node* phi = make_phi(ty, 1, block);
set_value(name, block, phi);
result = add_phi_operands(name, phi);
if(auto *phi = dynamic_cast<ir::phi_node*>(result))
result = try_remove_trivial_phis(phi);
}
if(auto *phi = dynamic_cast<ir::phi_node*>(result))
if(auto *phi = dynamic_cast<ir::phi_node*>(result)){
result = try_remove_trivial_phis(phi);
}
set_value(name, block, result);
return result;
}
Expand Down
3 changes: 3 additions & 0 deletions python/src/triton.cc
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,9 @@ void init_triton_ir(py::module &&m) {
py::class_<ir::constant_fp, ir::constant>(m, "constant_float")
.def_property_readonly("value", &ir::constant_fp::get_value);

py::class_<ir::instruction, ir::user>(m, "instruction");
py::class_<ir::phi_node, ir::user>(m, "phi_node");

py::class_<ir::type>(m, "type")
.def("is_ptr", &ir::type::is_pointer_ty)
.def("is_int", static_cast<bool (ir::type::*)() const>(&ir::type::is_integer_ty))
Expand Down
22 changes: 15 additions & 7 deletions python/triton/code_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def visit_compound_statement(self, stmts, add_scope=False):
break
if add_scope:
self.module.pop_scope()
return self.last_ret
return stmts and isinstance(stmt, ast.Return)

def __init__(self, context, prototype, gscope, attributes, constants, kwargs):
self.builder = _triton.ir.builder(context)
Expand Down Expand Up @@ -85,7 +85,10 @@ def visit_List(self, node):

# By design, only non-kernel functions can return
def visit_Return(self, node):
return self.visit(node.value)
ret = self.visit(node.value)
if ret is None:
return self.builder.ret_void()
return ret

def visit_FunctionDef(self, node, inline=False, arg_values=None):
arg_names, kwarg_names = self.visit(node.args)
Expand All @@ -112,7 +115,8 @@ def visit_FunctionDef(self, node, inline=False, arg_values=None):
for arg_name, arg_value in zip(arg_names, arg_values):
self.set_value(arg_name, arg_value)
if inline:
return self.visit_compound_statement(node.body, add_scope=True)
self.visit_compound_statement(node.body, add_scope=True)
return self.last_ret
else:
entry = _triton.ir.basic_block.create(self.builder.context, "entry", fn)
self.module.seal_block(entry)
Expand Down Expand Up @@ -140,6 +144,8 @@ def visit_Assign(self, node):
assert len(names) == 1
name = names[0]
value = self.visit(node.value)
if not isinstance(value, triton.language.block):
value = triton.language._to_ir(value, self.builder)
self.set_value(names[0], value)

def visit_AugAssign(self, node):
Expand Down Expand Up @@ -208,14 +214,16 @@ def visit_If(self, node):
else:
self.builder.cond_br(cond.handle, then_bb, endif_bb)
self.builder.set_insert_block(then_bb)
self.visit_compound_statement(node.body, add_scope=True)
is_terminator = self.visit_compound_statement(node.body, add_scope=True)
# TODO: last statement is a terminator?
self.builder.br(endif_bb)
if not is_terminator:
self.builder.br(endif_bb)
if else_bb:
self.builder.set_insert_block(else_bb)
self.visit_compound_statement(node.orelse, add_scope=True)
is_terminator = self.visit_compound_statement(node.orelse, add_scope=True)
#TODO: last statement is a terminator?
self.builder.br(endif_bb)
if not is_terminator:
self.builder.br(endif_bb)
self.module.seal_block(endif_bb)
self.builder.set_insert_block(endif_bb)
else:
Expand Down
29 changes: 15 additions & 14 deletions python/triton/language.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,22 @@
from functools import wraps


def _patch(fn):
# convert block/dtype to ir values
def _to_ir(x, builder):
if isinstance(x, bool):
return builder.get_int1(x)
elif isinstance(x, int):
return builder.get_int32(x)
elif isinstance(x, float):
return builder.get_float32(x)
if isinstance(x, block):
return x.handle
if isinstance(x, dtype):
return x.handle(builder)
return x

# convert block/dtype to ir values
def _to_ir(x, builder):
if isinstance(x, bool):
return builder.get_int1(x)
elif isinstance(x, int):
return builder.get_int32(x)
elif isinstance(x, float):
return builder.get_float32(x)
if isinstance(x, block):
return x.handle
if isinstance(x, dtype):
return x.handle(builder)
return x

def _patch(fn):
def _from_ir(x):
if isinstance(x, ir.value):
if x.type.is_void():
Expand Down Expand Up @@ -306,6 +306,7 @@ def zeros(shape, dtype, builder=None):
:param dtype: Data-type of the new array, e.g., :code:`triton.float16`
:type dtype: DType
"""
shape = [int(x.handle) if isinstance(x, block) else x for x in shape]
return frontend.zeros(shape, dtype, builder)


Expand Down

0 comments on commit 4290be1

Please sign in to comment.