Skip to content

[BugFix] Avoid initializing CUDA too early #3487

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 19, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,12 +577,12 @@ class DeviceConfig:
def __init__(self, device: str = "auto") -> None:
if device == "auto":
# Automated device type detection
if torch.cuda.is_available():
self.device_type = "cuda"
elif is_neuron():
if is_neuron():
self.device_type = "neuron"
else:
raise RuntimeError("No supported device detected.")
# We don't call torch.cuda.is_available() here to
# avoid initializing CUDA before workers are forked
self.device_type = "cuda"
else:
# Device type is assigned explicitly
self.device_type = device
Expand Down