Skip to content

Commit

Permalink
Add mechanism for remapping device-specific module imports
Browse files Browse the repository at this point in the history
This is motivated by #4509. The crux of the problem is that the Triton
code generator needs to inspect a function's arguments / attributes /
types in order to determine how it should be called. This meant that
"implementation details" like whether a function is a builtin needed to
be exposed in the "interface" `tl.extra.libdevice` module, instead of
just residing in `tl.extra.cuda.libdevice`. Moreover, this meant that
libdevice functions marked as @core.extern in the interface could not be
implemented via JitFunctions.

Allowing each backend to provide its own module map solves this problem
as the code generator can inspect the actual function implementation.
  • Loading branch information
int3 committed Aug 19, 2024
1 parent 6a5638e commit ffc68da
Show file tree
Hide file tree
Showing 6 changed files with 233 additions and 635 deletions.
3 changes: 2 additions & 1 deletion python/test/unit/runtime/test_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,11 @@ def walk_fn(op):
backend = triton.compiler.compiler.make_backend(target)
options = backend.parse_options(dict())
codegen_fns = dict()
module_map = backend.get_module_map()
triton._C.libtriton.ir.load_dialects(context)
backend.load_dialects(context)

ttir_module = src.make_ir(options, codegen_fns, context)
ttir_module = src.make_ir(options, codegen_fns, module_map, context)
ttir_module.walk(walk_fn)


Expand Down
26 changes: 20 additions & 6 deletions python/triton/compiler/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,8 @@ def visit_Call(self, node: ast.Call) -> bool:
class CodeGenerator(ast.NodeVisitor):

def __init__(self, context, prototype, gscope, attributes, constants, function_name, jit_fn: JITFunction, options,
codegen_fns, debug=None, module=None, is_kernel=False, function_types: Optional[Dict] = None,
noinline=False, file_name: Optional[str] = None, begin_line=0):
codegen_fns, module_map, debug=None, module=None, is_kernel=False,
function_types: Optional[Dict] = None, noinline=False, file_name: Optional[str] = None, begin_line=0):
self.context = context
self.builder = ir.builder(context)
self.file_name = file_name
Expand All @@ -201,10 +201,23 @@ def __init__(self, context, prototype, gscope, attributes, constants, function_n
# Convert custom types not natively supported on HW.
# convert_custom_types(intput_tensor, dtype, fp_downcast_rounding=None, _builder=None)
self.builder.codegen_fns = codegen_fns
self.builder.module_map = {} if module_map is None else module_map
self.module = self.builder.create_module() if module is None else module
self.function_ret_types = {} if function_types is None else function_types
self.prototype = prototype
self.gscope = gscope

self.gscope = {}
for k, v in gscope.items():
if isinstance(v, ModuleType):
self.gscope[k] = module_map.get(v.__name__, v)
continue

module_name = getattr(v, "__module__", "")
if module_name in module_map:
self.gscope[k] = getattr(module_map[module_name], k)
else:
self.gscope[k] = v

self.lscope = {}
self.attributes = attributes
self.constants = constants
Expand Down Expand Up @@ -1049,7 +1062,8 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs):
generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, module=self.module,
jit_fn=fn, function_name=fn_name, function_types=self.function_ret_types,
noinline=fn.noinline, file_name=file_name, begin_line=begin_line,
options=self.builder.options, codegen_fns=self.builder.codegen_fns, debug=debug)
options=self.builder.options, codegen_fns=self.builder.codegen_fns,
module_map=self.builder.module_map, debug=debug)
try:
generator.visit(fn.parse())
except Exception as e:
Expand Down Expand Up @@ -1252,7 +1266,7 @@ def kernel_suffix(signature, specialization):
return suffix


def ast_to_ttir(fn, specialization, context, options, codegen_fns):
def ast_to_ttir(fn, specialization, context, options, codegen_fns, module_map):
attrs = specialization.attrs
# create kernel prototype
cst_key = lambda i: fn.arg_names.index(i) if isinstance(i, str) else i
Expand All @@ -1272,7 +1286,7 @@ def ast_to_ttir(fn, specialization, context, options, codegen_fns):
prototype = language.function_type([], arg_types)
generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name,
jit_fn=fn, attributes=new_attrs, is_kernel=True, file_name=file_name,
begin_line=begin_line, options=options, codegen_fns=codegen_fns)
begin_line=begin_line, options=options, codegen_fns=codegen_fns, module_map=module_map)
generator.visit(fn.parse())

ret = generator.module
Expand Down
10 changes: 6 additions & 4 deletions python/triton/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,9 @@ def hash(self):
key = f"{self.fn.cache_key}-{self.attrs.hash()}-{sorted_sig}-{sorted_constants}"
return hashlib.sha256(key.encode("utf-8")).hexdigest()

def make_ir(self, options, codegen_fns, context):
return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns)
def make_ir(self, options, codegen_fns, module_map, context):
return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
module_map=module_map)

def parse_options(self):
return dict()
Expand All @@ -132,7 +133,7 @@ def __init__(self, path):
def hash(self):
return hashlib.sha256(self.src.encode("utf-8")).hexdigest()

def make_ir(self, options, codegen_fns, context):
def make_ir(self, options, codegen_fns, module_map, context):
module = ir.parse_mlir_module(self.path, context)
module.context = context
return module
Expand Down Expand Up @@ -277,8 +278,9 @@ def compile(src, target=None, options=None):
ir.load_dialects(context)
backend.load_dialects(context)
codegen_fns = backend.get_codegen_implementation()
module_map = backend.get_module_map()
try:
module = src.make_ir(options, codegen_fns, context)
module = src.make_ir(options, codegen_fns, module_map, context)
except Exception as e:
filter_traceback(e)
raise
Expand Down
Loading

0 comments on commit ffc68da

Please sign in to comment.