CUDA context not properly initialized in side threads when no compilation is needed #3729
Closed
Description
See the code below for the full repro. In this case, the CUDA context is never initialized on the side thread, since no compilation or module loading is needed in the process: the kernel was already compiled and cached by the first call on the main thread, so nothing from the new thread ever touches the CUDA driver before the launch.
I guess we're missing a way to initialize the context properly before any launch.
From @ngimel's expertise: "all it's missing is strategically placed cudaFree(0)"
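For illustration, a minimal sketch of what that suggestion amounts to from the Python side (a hypothetical helper, not Triton's actual fix; resolving the CUDA runtime library through ctypes is an assumption and may need adjusting for a given install):

import ctypes
import ctypes.util

def ensure_cuda_context():
    # Hypothetical helper: cudaFree(NULL) is documented as a no-op, but like
    # any CUDA runtime API call it triggers lazy runtime initialization and
    # makes the device's primary context current on the calling thread.
    libname = ctypes.util.find_library("cudart") or "libcudart.so"  # assumed lookup
    ctypes.CDLL(libname).cudaFree(None)  # cudaError_t cudaFree(void*), NULL is a no-op

The full repro: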
import triton
import triton.language as tl
import threading

@triton.jit
def _rua_kernel(hidden_sh_ptr):
    return

def run_kernel():
    _rua_kernel[(1,)](1.)

# Run it once to compile and cache it
run_kernel()

# Run it in a separate thread where cuModuleLoad is not called and
# so the CUDA context is never initialized
t = threading.Thread(target=run_kernel)
t.start()
t.join()
Fails with
Exception in thread Thread-1 (run_kernel):
Traceback (most recent call last):
File "/home/albandes/local/installs/python3.11/release/install/lib/python3.11/threading.py", line 1038, in _bootstrap_inner
self.run()
File "/home/albandes/local/installs/python3.11/release/install/lib/python3.11/threading.py", line 975, in run
self._target(*self._args, **self._kwargs)
File "/home/albandes/local/pytorch/3.11_release_source/test/foo.py", line 10, in run_kernel
_rua_kernel[(1,)](1.)
File "/home/albandes/local/pytorch/3.11_release_source_env/lib/python3.11/site-packages/triton/runtime/jit.py", line 550, in run
bin.c_wrapper(
RuntimeError: Triton Error [CUDA]: invalid device context
This is causing pytorch/pytorch#124565
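For what it's worth, the root cause can be checked from the repro itself with a small driver-API query (a sketch only; the libcuda.so.1 soname and the lack of error handling are assumptions): cuCtxGetCurrent reports which context, if any, is current on the calling thread, and on the side thread it is expected to come back NULL, which matches the "invalid device context" error.

import ctypes
import threading

def report_current_context():
    # Hypothetical diagnostic: ask the CUDA driver which context is current
    # on *this* thread. Expected to print a non-None value on the main thread
    # (a context was made current when the kernel was first compiled/loaded)
    # and None on the side thread.
    libcuda = ctypes.CDLL("libcuda.so.1")  # assumed Linux driver soname
    ctx = ctypes.c_void_p()
    libcuda.cuCtxGetCurrent(ctypes.byref(ctx))
    print(f"{threading.current_thread().name}: current CUcontext = {ctx.value}")

Calling report_current_context() at the top of run_kernel in the repro above would show the difference between the two threads.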