Skip to content

Commit

Permalink
Add mechanism for remapping device-specific module imports (#4539)
Browse files Browse the repository at this point in the history
This is motivated by #4509. The crux of the problem is that the Triton
code generator needs to inspect a function's arguments / attributes /
types in order to determine how it should be called. This meant that
"implementation details" like whether a function is a builtin needed to
be exposed in the "interface" `tl.extra.libdevice` module, instead of
just residing in `tl.extra.cuda.libdevice`. Moreover, this meant that
libdevice functions marked as `@core.extern` in the interface could not
be implemented via JitFunctions.

Allowing each backend to provide its own module map solves this problem
as the code generator can inspect the actual function implementation.
  • Loading branch information
int3 authored and Jokeren committed Aug 24, 2024
1 parent 17a30bd commit c2acb7b
Show file tree
Hide file tree
Showing 7 changed files with 246 additions and 638 deletions.
3 changes: 2 additions & 1 deletion python/test/unit/runtime/test_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,11 @@ def walk_fn(op):
backend = triton.compiler.compiler.make_backend(target)
options = backend.parse_options(dict())
codegen_fns = dict()
module_map = backend.get_module_map()
triton._C.libtriton.ir.load_dialects(context)
backend.load_dialects(context)

ttir_module = src.make_ir(options, codegen_fns, context)
ttir_module = src.make_ir(options, codegen_fns, module_map, context)
ttir_module.walk(walk_fn)


Expand Down
10 changes: 9 additions & 1 deletion python/triton/backends/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

from abc import ABCMeta, abstractmethod, abstractclassmethod
from dataclasses import dataclass
from typing import Union
from typing import Dict, Union
from types import ModuleType


@dataclass(frozen=True)
Expand Down Expand Up @@ -74,3 +75,10 @@ def load_dialects(self, context):
Load additional MLIR dialects into the provided `context`
"""
raise NotImplementedError

@abstractmethod
def get_module_map(self) -> Dict[str, ModuleType]:
"""
Return a map of interface modules to their device-specific implementations.
"""
raise NotImplementedError
26 changes: 20 additions & 6 deletions python/triton/compiler/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,8 @@ def visit_Call(self, node: ast.Call) -> bool:
class CodeGenerator(ast.NodeVisitor):

def __init__(self, context, prototype, gscope, attributes, constants, function_name, jit_fn: JITFunction, options,
codegen_fns, debug=None, module=None, is_kernel=False, function_types: Optional[Dict] = None,
noinline=False, file_name: Optional[str] = None, begin_line=0):
codegen_fns, module_map, debug=None, module=None, is_kernel=False,
function_types: Optional[Dict] = None, noinline=False, file_name: Optional[str] = None, begin_line=0):
self.context = context
self.builder = ir.builder(context)
self.file_name = file_name
Expand All @@ -201,10 +201,23 @@ def __init__(self, context, prototype, gscope, attributes, constants, function_n
# Convert custom types not natively supported on HW.
# convert_custom_types(intput_tensor, dtype, fp_downcast_rounding=None, _builder=None)
self.builder.codegen_fns = codegen_fns
self.builder.module_map = {} if module_map is None else module_map
self.module = self.builder.create_module() if module is None else module
self.function_ret_types = {} if function_types is None else function_types
self.prototype = prototype
self.gscope = gscope

self.gscope = {}
for k, v in gscope.items():
if isinstance(v, ModuleType):
self.gscope[k] = module_map.get(v.__name__, v)
continue

module_name = getattr(v, "__module__", "")
if module_name in module_map:
self.gscope[k] = getattr(module_map[module_name], k)
else:
self.gscope[k] = v

self.lscope = {}
self.attributes = attributes
self.constants = constants
Expand Down Expand Up @@ -1054,7 +1067,8 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs):
generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, module=self.module,
jit_fn=fn, function_name=fn_name, function_types=self.function_ret_types,
noinline=fn.noinline, file_name=file_name, begin_line=begin_line,
options=self.builder.options, codegen_fns=self.builder.codegen_fns, debug=debug)
options=self.builder.options, codegen_fns=self.builder.codegen_fns,
module_map=self.builder.module_map, debug=debug)
try:
generator.visit(fn.parse())
except Exception as e:
Expand Down Expand Up @@ -1257,7 +1271,7 @@ def kernel_suffix(signature, specialization):
return suffix


def ast_to_ttir(fn, specialization, context, options, codegen_fns):
def ast_to_ttir(fn, specialization, context, options, codegen_fns, module_map):
attrs = specialization.attrs
# create kernel prototype
cst_key = lambda i: fn.arg_names.index(i) if isinstance(i, str) else i
Expand All @@ -1277,7 +1291,7 @@ def ast_to_ttir(fn, specialization, context, options, codegen_fns):
prototype = language.function_type([], arg_types)
generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name,
jit_fn=fn, attributes=new_attrs, is_kernel=True, file_name=file_name,
begin_line=begin_line, options=options, codegen_fns=codegen_fns)
begin_line=begin_line, options=options, codegen_fns=codegen_fns, module_map=module_map)
generator.visit(fn.parse())

ret = generator.module
Expand Down
10 changes: 6 additions & 4 deletions python/triton/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,9 @@ def hash(self):
key = f"{self.fn.cache_key}-{self.attrs.hash()}-{sorted_sig}-{sorted_constants}"
return hashlib.sha256(key.encode("utf-8")).hexdigest()

def make_ir(self, options, codegen_fns, context):
return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns)
def make_ir(self, options, codegen_fns, module_map, context):
return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
module_map=module_map)

def parse_options(self):
return dict()
Expand All @@ -132,7 +133,7 @@ def __init__(self, path):
def hash(self):
return hashlib.sha256(self.src.encode("utf-8")).hexdigest()

def make_ir(self, options, codegen_fns, context):
def make_ir(self, options, codegen_fns, module_map, context):
module = ir.parse_mlir_module(self.path, context)
module.context = context
return module
Expand Down Expand Up @@ -277,8 +278,9 @@ def compile(src, target=None, options=None):
ir.load_dialects(context)
backend.load_dialects(context)
codegen_fns = backend.get_codegen_implementation()
module_map = backend.get_module_map()
try:
module = src.make_ir(options, codegen_fns, context)
module = src.make_ir(options, codegen_fns, module_map, context)
except Exception as e:
filter_traceback(e)
raise
Expand Down
Loading

0 comments on commit c2acb7b

Please sign in to comment.