Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PYTHON] Fixed bug in scoping mechanism #111

Merged
merged 1 commit into from
May 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 4 additions & 15 deletions include/triton/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,6 @@ class global_value;
class alloc_const;

/* Module */
struct scope {
public:
const std::map<std::string, ir::value*>& get_values() { return values; }
void set_type(const std::string& name, ir::type* ty) { types[name] = ty; }
ir::type* get_type(const std::string& name) { return types.at(name); }
private:
std::map<std::string, ir::type*> types;
std::map<std::string, ir::value*> values;
};

class module {
typedef std::pair<std::string, basic_block*> val_key_t;
Expand Down Expand Up @@ -74,8 +65,11 @@ class module {
void set_const(const std::string& name);
void set_continue_fn(std::function<ir::value*()> fn);
// Getters
const std::map<val_key_t, value*>& get_values() { return values_; }
void set_values(const std::map<val_key_t, value*>& values) { values_ = values; }
value *get_value(const std::string& name, basic_block* block);
value *get_value(const std::string& name);
void set_type(const std::string& name, ir::type* ty) { types_[name] = ty; }
const std::string& get_name();
std::function<ir::value*()> get_continue_fn();
// Seal block -- no more predecessors will be added
Expand All @@ -84,10 +78,6 @@ class module {
const functions_list_t &get_function_list() const { return functions_; }
functions_list_t &get_function_list() { return functions_; }
function *get_or_insert_function(const std::string &name, function_type *ty);
// Scope
void add_new_scope() { if(scopes_.empty()) scopes_.push(scope()); else scopes_.push(scope(get_scope())); }
void pop_scope() { scopes_.pop(); }
scope& get_scope() { return scopes_.top(); }
// Const allocation
void add_alloc(ir::alloc_const* x) { allocs_.push_back(x); }
const std::vector<ir::alloc_const*>& allocs() { return allocs_; }
Expand All @@ -101,15 +91,14 @@ class module {
std::string name_;
builder& builder_;
std::map<val_key_t, value*> values_;
std::map<val_key_t, type*> types_;
std::map<std::string, type*> types_;
std::set<std::string> const_;
std::set<basic_block*> sealed_blocks_;
std::map<basic_block*, std::map<std::string, phi_node*>> incomplete_phis_;
functions_list_t functions_;
symbols_map_t symbols_;
std::function<ir::value*()> continue_fn_;
std::map<value*, value**> current_phi_;
std::stack<scope> scopes_;
std::vector<ir::alloc_const*> allocs_;
std::map<std::string, ir::value*> globals_;
std::map<std::string, md_pair_t> metadatas_;
Expand Down
2 changes: 1 addition & 1 deletion lib/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ ir::value *module::get_value_recursive(const std::string& name, ir::basic_block
ir::value *result;
bool is_const = const_.find(name) != const_.end();
auto &preds = block->get_predecessors();
ir::type *ty = get_scope().get_type(name);
ir::type *ty = types_.at(name);
if(block && !is_const && sealed_blocks_.find(block) == sealed_blocks_.end()){
incomplete_phis_[block][name] = make_phi(ty, 1, block);
result = (ir::value*)incomplete_phis_[block][name];
Expand Down
11 changes: 3 additions & 8 deletions python/src/triton.cc
Original file line number Diff line number Diff line change
Expand Up @@ -228,20 +228,15 @@ void init_triton_ir(py::module &&m) {
.def_property_readonly("shape", &ir::block_type::get_shapes)
.def_property_readonly("numel", &ir::type::get_tile_num_elements);

py::class_<ir::scope>(m, "scope")
.def(py::init<>())
.def_property_readonly("values", &ir::scope::get_values)
.def("set_type", &ir::scope::set_type);

py::class_<ir::module>(m, "module")
.def(py::init<std::string, ir::builder &>())
.def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference)
.def("add_new_scope", &ir::module::add_new_scope, ret::reference)
.def("seal_block", &ir::module::seal_block)
.def("set_value", (void (ir::module::*)(const std::string &, ir::value *)) & ir::module::set_value)
.def("set_type", &ir::module::set_type)
.def("get_value", (ir::value * (ir::module::*)(const std::string &)) & ir::module::get_value, ret::reference)
.def("pop_scope", &ir::module::pop_scope)
.def_property_readonly("scope", &ir::module::get_scope, ret::reference)
.def("get_values", &ir::module::get_values, ret::reference)
.def("set_values", &ir::module::set_values)
.def_property_readonly("builder", &ir::module::get_builder, ret::reference);

using eattr = ir::attribute_kind_t;
Expand Down
29 changes: 14 additions & 15 deletions python/triton/code_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,21 +38,17 @@ def set_value(self, name, value):
value = triton.language.block(value)
if isinstance(value, triton.language.block):
self.module.set_value(name, value.handle)
self.module.scope.set_type(name, value.handle.type)
self.module.set_type(name, value.handle.type)
self.lscope[name] = value

def is_triton_object(self, value):
return isinstance(value, triton.language.block)

def visit_compound_statement(self, stmts, add_scope=False):
if add_scope:
self.module.add_new_scope()
def visit_compound_statement(self, stmts):
for stmt in stmts:
self.last_ret = self.visit(stmt)
if isinstance(stmt, ast.Return):
break
if add_scope:
self.module.pop_scope()
return stmts and isinstance(stmt, ast.Return)

def __init__(self, context, prototype, gscope, attributes, constants, kwargs):
Expand All @@ -75,9 +71,7 @@ def __init__(self, context, prototype, gscope, attributes, constants, kwargs):
}

def visit_Module(self, node):
self.module.add_new_scope()
ast.NodeVisitor.generic_visit(self, node)
self.module.pop_scope()

def visit_List(self, node):
ctx = self.visit(node.ctx)
Expand Down Expand Up @@ -117,14 +111,14 @@ def visit_FunctionDef(self, node, inline=False, arg_values=None):
for arg_name, arg_value in zip(arg_names, arg_values):
self.set_value(arg_name, arg_value)
if inline:
self.visit_compound_statement(node.body, add_scope=True)
self.visit_compound_statement(node.body)
return self.last_ret
else:
entry = _triton.ir.basic_block.create(self.builder.context, "entry", fn)
self.module.seal_block(entry)
self.builder.set_insert_block(entry)
# visit function body
self.visit_compound_statement(node.body, add_scope=True)
self.visit_compound_statement(node.body)
# finalize function
self.builder.ret_void()

Expand Down Expand Up @@ -216,13 +210,13 @@ def visit_If(self, node):
else:
self.builder.cond_br(cond.handle, then_bb, endif_bb)
self.builder.set_insert_block(then_bb)
is_terminator = self.visit_compound_statement(node.body, add_scope=True)
is_terminator = self.visit_compound_statement(node.body)
# TODO: last statement is a terminator?
if not is_terminator:
self.builder.br(endif_bb)
if else_bb:
self.builder.set_insert_block(else_bb)
is_terminator = self.visit_compound_statement(node.orelse, add_scope=True)
is_terminator = self.visit_compound_statement(node.orelse)
#TODO: last statement is a terminator?
if not is_terminator:
self.builder.br(endif_bb)
Expand Down Expand Up @@ -289,7 +283,7 @@ def continue_fn():

continue_fn()
self.builder.set_insert_block(loop_bb)
self.visit_compound_statement(node.body, add_scope=True)
self.visit_compound_statement(node.body)
continue_fn()
stop_bb = self.builder.get_insert_block()
self.module.seal_block(stop_bb)
Expand Down Expand Up @@ -344,7 +338,7 @@ def continue_fn():
cond = build_cond()
self.builder.cond_br(cond.handle, loop_bb, next_bb)
self.builder.set_insert_block(loop_bb)
self.visit_compound_statement(node.body, add_scope=True)
self.visit_compound_statement(node.body)
# TODO: handle case where body breaks control flow
continue_fn()
stop_bb = self.builder.get_insert_block()
Expand Down Expand Up @@ -643,7 +637,12 @@ def parse(self):

def __call__(self, *args, generator: CodeGenerator, **meta):
try:
return generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=args)
lscope = generator.lscope.copy()
values = generator.module.get_values().copy()
ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=args)
generator.lscope = lscope
generator.module.set_values(values)
return ret
except Exception as e:
node = generator.last_node
if node is None or isinstance(e, (NotImplementedError, CompilationError)):
Expand Down