
CUDA context not properly initialized in side threads when no compilation is needed #3729

Closed
@albanD

Description

See the code below for the full repro. The kernel is compiled and loaded on the main thread, so when it is launched again from a side thread no compilation or module loading happens there, and the CUDA context is never made current on that thread.
It looks like we're missing a way to initialize the context properly for every launch, not just the first one.

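From @ngimel's expertise: "all it's missing is strategically placed cudaFree(0)", i.e. a cheap CUDA runtime call that makes the device's primary context current on the calling thread before the launch.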
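To make the failure mode concrete, here is a toy Python model of it. Every name below is hypothetical; nothing here is Triton's or CUDA's actual API. It assumes only what the description says: the "current context" is per-thread state, the compile/load path happens to set it as a side effect, and the cached launch path needs it but never sets it.

```python
import threading

# Hypothetical toy model, NOT Triton's or CUDA's API. The "driver" tracks a
# current context per host thread, mirroring CUDA's per-thread context state.
_tls = threading.local()
PRIMARY_CTX = object()  # stands in for the device's primary context
kernel_cache = {}       # stands in for Triton's compiled-kernel cache


class InvalidContextError(RuntimeError):
    pass


def set_current_context(ctx):
    _tls.ctx = ctx


def launch(name):
    # Like a driver-API launch: needs a context current on *this* thread.
    if getattr(_tls, "ctx", None) is None:
        raise InvalidContextError("invalid device context")
    return kernel_cache[name]


def run_kernel(ensure_ctx=False):
    if "k" not in kernel_cache:
        # First call: "compile and load the module", which as a side effect
        # makes the context current on the calling thread only.
        set_current_context(PRIMARY_CTX)
        kernel_cache["k"] = "binary"
    elif ensure_ctx:
        # The proposed fix: make the context current on the cached path too,
        # playing the role of a strategically placed cudaFree(0).
        set_current_context(PRIMARY_CTX)
    return launch("k")


def run_in_thread(**kwargs):
    # Run run_kernel on a fresh thread and report its result or its error.
    result = {}

    def target():
        try:
            result["value"] = run_kernel(**kwargs)
        except InvalidContextError as exc:
            result["error"] = str(exc)

    t = threading.Thread(target=target)
    t.start()
    t.join()
    return result
```

Calling `run_kernel()` once on the main thread and then `run_in_thread()` reproduces the "invalid device context" error, while `run_in_thread(ensure_ctx=True)` succeeds, because the fixed path sets the context on whichever thread is launching.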

import triton
import triton.language as tl
import threading

@triton.jit
def _rua_kernel(hidden_sh_ptr):
    return

def run_kernel():
    _rua_kernel[(1,)](1.)


# Run it once to compile and cache it
run_kernel()

# Run it in a separate thread. The kernel is already cached, so cuModuleLoad
# is not called and the CUDA context is never made current on this thread
t = threading.Thread(target=run_kernel)
t.start()
t.join()

This fails with:

Exception in thread Thread-1 (run_kernel):
Traceback (most recent call last):
  File "/home/albandes/local/installs/python3.11/release/install/lib/python3.11/threading.py", line 1038, in _bootstrap_inner
    self.run()
  File "/home/albandes/local/installs/python3.11/release/install/lib/python3.11/threading.py", line 975, in run
    self._target(*self._args, **self._kwargs)
  File "/home/albandes/local/pytorch/3.11_release_source/test/foo.py", line 10, in run_kernel
    _rua_kernel[(1,)](1.)
  File "/home/albandes/local/pytorch/3.11_release_source_env/lib/python3.11/site-packages/triton/runtime/jit.py", line 550, in run
    bin.c_wrapper(
RuntimeError: Triton Error [CUDA]: invalid device context
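One way to confirm that the side thread really has no current context is to ask the driver directly. This is a minimal sketch, assuming Linux, where the driver library is usually `libcuda.so.1`, and calling the driver function `cuCtxGetCurrent` through `ctypes`; the helper name `has_current_context` is hypothetical.

```python
import ctypes
from typing import Optional


def has_current_context(libname: str = "libcuda.so.1") -> Optional[bool]:
    """Return True/False depending on whether the driver reports a current
    context on the calling thread, or None if the driver library cannot be
    loaded or the query fails."""
    try:
        libcuda = ctypes.CDLL(libname)  # assumed Linux driver soname
    except OSError:
        return None
    # CUresult cuCtxGetCurrent(CUcontext *pctx); 0 == CUDA_SUCCESS.
    # It returns an error (e.g. not initialized) if cuInit() was never
    # called in this process; in the repro Triton calls it on first launch.
    libcuda.cuCtxGetCurrent.argtypes = [ctypes.POINTER(ctypes.c_void_p)]
    libcuda.cuCtxGetCurrent.restype = ctypes.c_int
    ctx = ctypes.c_void_p()
    if libcuda.cuCtxGetCurrent(ctypes.byref(ctx)) != 0:
        return None
    # A NULL context (no current context on this thread) reads back as None.
    return ctx.value is not None
```

Calling it inside `run_kernel` in the repro would be expected to report `True` on the main thread after the first launch and `False` in the side thread, which matches the "invalid device context" error above.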

This is the cause of pytorch/pytorch#124565.
