Skip to content

Commit

Permalink
[Inductor] Support passing module map parameter to Triton make_ir API (
Browse files Browse the repository at this point in the history
…pytorch#134774)

In triton-lang/triton#4539 the `make_ir` API was modified to accept a new `module_map` parameter. Update the Inductor callsite accordingly, preserving backwards compatibility following the existing code.

Fixes pytorch#134674

Pull Request resolved: pytorch#134774
Approved by: https://github.com/EikanWang, https://github.com/zou3519, https://github.com/jansel
  • Loading branch information
alexbaden authored and tolleybot committed Sep 14, 2024
1 parent 2a8ff0c commit c57f085
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions torch/_higher_order_ops/triton_kernel_wrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,13 +179,18 @@ def generate_ttir(kernel, kwargs):

src = ASTSource(kernel, signature, constants, specialization)

# Triton changes ASTSource.make_ir to take 3 arguments. Handle
# Triton changes ASTSource.make_ir to take 3/4 arguments. Handle
# backward compatibility here.
if len(inspect.signature(src.make_ir).parameters) == 2:
make_ir_sig_params = len(inspect.signature(src.make_ir).parameters)
if make_ir_sig_params == 2:
ttir_module = src.make_ir(options, context)
else:
elif make_ir_sig_params == 3:
codegen_fns = backend.get_codegen_implementation()
ttir_module = src.make_ir(options, codegen_fns, context)
else:
codegen_fns = backend.get_codegen_implementation()
module_map = backend.get_module_map()
ttir_module = src.make_ir(options, codegen_fns, module_map, context)
if not ttir_module.verify():
raise RuntimeError("Verification for TTIR module has failed")

Expand Down

0 comments on commit c57f085

Please sign in to comment.