Add mechanism for remapping device-specific module imports #4539
I initially tried to tackle #4509 with a purely "userspace" solution by changing …
fixed test
This is motivated by #4509. The crux of the problem is that the Triton code generator must inspect a function's arguments, attributes, and types to determine how it should be called. As a result, "implementation details" such as whether a function is a builtin had to be exposed in the "interface" module `tl.extra.libdevice` instead of residing only in `tl.extra.cuda.libdevice`. It also meant that libdevice functions marked as `@core.extern` in the interface could not be implemented via JitFunctions. Allowing each backend to provide its own module map solves this problem, because the code generator can then inspect the actual function implementation.
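To make the idea concrete, here is a minimal, self-contained sketch of a module map. It does not use Triton itself: the module names `demo.extra.libdevice` and `demo.extra.cuda.libdevice`, the helper `remap_globals`, and the `is_builtin` attribute are all hypothetical stand-ins chosen for illustration, and the real code generator may differ in detail.

```python
import types


def _make_module(name, builtin):
    """Build a throwaway module exposing one function, `exp` (illustrative only)."""
    mod = types.ModuleType(name)

    def exp(x):
        return (name, x)

    exp.__module__ = name
    exp.is_builtin = builtin  # hypothetical "implementation detail" flag
    mod.exp = exp
    return mod


# Device-agnostic interface and one backend implementation (both stand-ins).
interface = _make_module("demo.extra.libdevice", builtin=False)
cuda_impl = _make_module("demo.extra.cuda.libdevice", builtin=True)


def remap_globals(gscope, module_map):
    """Return a copy of `gscope` with interface objects swapped for backend ones."""
    remapped = {}
    for name, value in gscope.items():
        if isinstance(value, types.ModuleType):
            # `import libdevice` style: swap the whole module.
            remapped[name] = module_map.get(value.__name__, value)
            continue
        # `from libdevice import exp` style: swap the individual object,
        # looked up under its own name so aliased imports also resolve.
        target = module_map.get(getattr(value, "__module__", None))
        if target is None:
            remapped[name] = value
        else:
            remapped[name] = getattr(target, getattr(value, "__name__", name))
    return remapped


module_map = {interface.__name__: cuda_impl}
kernel_globals = {"libdevice": interface, "exp": interface.exp, "BLOCK": 128}
resolved = remap_globals(kernel_globals, module_map)
```

After remapping, anything that inspects `resolved["exp"]` sees the backend function and its attributes rather than the interface placeholder, which is the property the description above relies on.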
…#134774)
In triton-lang/triton#4539 the `make_ir` API was modified to accept a new `module_map` parameter. Update the Inductor call site accordingly, preserving backwards compatibility in the same way as the existing code.
Fixes #134674
Pull Request resolved: #134774
Approved by: https://github.com/EikanWang, https://github.com/zou3519, https://github.com/jansel
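One common way to keep a call site compatible with both the old and new `make_ir` signatures is to inspect the signature before calling. The sketch below is an assumption about the general approach, not the actual Inductor code: `OldSource`, `NewSource`, and `call_make_ir` are hypothetical, and the parameter order shown is assumed for illustration.

```python
import inspect


class OldSource:
    """Hypothetical stand-in for a version that predates `module_map`."""

    def make_ir(self, options, codegen_fns, context):
        return ("old", options, codegen_fns, context)


class NewSource:
    """Hypothetical stand-in for a version that accepts `module_map`."""

    def make_ir(self, options, codegen_fns, module_map, context):
        return ("new", options, codegen_fns, module_map, context)


def call_make_ir(src, options, codegen_fns, module_map, context):
    """Call `src.make_ir`, passing `module_map` only if the signature accepts it."""
    params = inspect.signature(src.make_ir).parameters
    if "module_map" in params:
        return src.make_ir(options, codegen_fns, module_map, context)
    return src.make_ir(options, codegen_fns, context)


old_result = call_make_ir(OldSource(), "opts", {}, {"m": 1}, "ctx")
new_result = call_make_ir(NewSource(), "opts", {}, {"m": 1}, "ctx")
```

Checking for the parameter by name, rather than counting parameters, keeps the branch readable and tolerant of unrelated signature changes.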
Context: In `CodeGenerator.__init__`, the globals of a given Triton function are modified to remap the libdevice module to its CUDA or HIP implementation (from triton-lang#4539). In particular, this logic:

```python
for k, v in gscope.items():  # gscope is a dict of fn.__globals__
    ...
    self.gscope[k] = getattr(module_map[module_name], k)
```

was failing when a libdevice function is imported under an alias in the global scope, e.g. `from triton.language.extra.libdevice import fast_dividef as my_fast_dividef`. The lookup uses the global's name `k` (here `my_fast_dividef`), which is not an attribute of the backend module.
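The failure can be reproduced without Triton. In this sketch the modules `demo.libdevice` and `demo.cuda.libdevice` and both `remap_*` helpers are hypothetical; looking the attribute up by the object's own `__name__` is shown as one possible fix, not necessarily the one that was merged.

```python
import types


def _make_module(name, tag):
    """Build a throwaway module exposing `fast_dividef` (illustrative only)."""
    mod = types.ModuleType(name)

    def fast_dividef(x, y):
        return (tag, x / y)

    fast_dividef.__module__ = name
    mod.fast_dividef = fast_dividef
    return mod


interface = _make_module("demo.libdevice", "interface")
backend = _make_module("demo.cuda.libdevice", "cuda")
module_map = {interface.__name__: backend}

# Equivalent of `from demo.libdevice import fast_dividef as my_fast_dividef`.
gscope = {"my_fast_dividef": interface.fast_dividef}


def remap_by_key(gscope, module_map):
    """Look the replacement up by the global's name; breaks on aliases."""
    out = dict(gscope)
    for k, v in gscope.items():
        module_name = getattr(v, "__module__", None)
        if module_name in module_map:
            out[k] = getattr(module_map[module_name], k)
    return out


def remap_by_object_name(gscope, module_map):
    """Look the replacement up by the object's own name; tolerates aliases."""
    out = dict(gscope)
    for k, v in gscope.items():
        module_name = getattr(v, "__module__", None)
        if module_name in module_map:
            out[k] = getattr(module_map[module_name], v.__name__)
    return out


try:
    remap_by_key(gscope, module_map)
    alias_lookup_failed = False
except AttributeError:
    alias_lookup_failed = True

fixed = remap_by_object_name(gscope, module_map)
```

Here `remap_by_key` raises `AttributeError` because the backend module has no attribute named `my_fast_dividef`, while `remap_by_object_name` keeps the alias as the key and still resolves to the backend's `fast_dividef`.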