Avoid poisoning the process with CUDA calls as soon as deepspeed is imported #6810
Conversation
@HollowMan6, thanks for diagnosing the problem and sharing a PR. However, I don't understand why this is the correct solution.
Can you please explain what I am missing? Thanks!
@HollowMan6, I want to share my thoughts on this problem. Building on your great analysis here and there, I did some further digging to get a better appreciation of the painful state of CUDA availability discovery in PyTorch. However, I think this is a problem that should be fixed in PyTorch rather than DeepSpeed, because DeepSpeed builds on PyTorch and should preserve its semantics (for good or bad) as much as practical. So, my three takeaways are as follows: 1. In this case, my understanding is that …
Hi @tjruwase! Thank you for reviewing the PR! Some comments from me:
Yes, that's correct, but I don't think it's a good idea to force users to set the flag in the environment, since https://pytorch.org/docs/stable/notes/cuda.html notes that the NVML-based CUDA availability assessment provides a weaker guarantee than the default CUDA Runtime API approach (which requires CUDA initialization to succeed): in some circumstances the NVML-based check may succeed while later CUDA initialization fails (as you noted). From my understanding of the comments in the code, the intent is only to determine whether we are on a GPU or an x86 CPU with torch, not to guarantee that CUDA initialization will succeed. But to address your concern, if you do want to ensure that CUDA initialization also works, I can change the check to `if torch.cuda.device_count() > 0 and torch.cuda.is_available(): #ignore-cuda` so that we first ensure devices are available, and only then check CUDA initialization; if there are no devices, we shouldn't make any CUDA call at all. Will update the PR for this.
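For illustration, a minimal sketch of the proposed ordering (a hypothetical standalone helper, not DeepSpeed's actual accelerator-detection code):

```python
import torch


def cuda_probably_usable() -> bool:
    # Hypothetical helper illustrating the proposed ordering.
    # torch.cuda.device_count() is checked first: on recent PyTorch versions it
    # can answer via NVML without initializing the CUDA runtime in this process.
    # Thanks to short-circuit evaluation, torch.cuda.is_available() (which by
    # default calls cudaGetDeviceCount and therefore initializes the CUDA
    # driver) is only reached when at least one device is reported.
    return torch.cuda.device_count() > 0 and torch.cuda.is_available()  #ignore-cuda
```

On a GPU-less machine, `device_count()` returns 0 and the CUDA runtime is never touched, so the process stays safe to fork and `CUDA_VISIBLE_DEVICES` can still be changed later.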
This is unfortunately a CUDA issue, and I don't think PyTorch can do much on their side either: pytorch/pytorch#141678 (comment). For DeepSpeed, what makes it worse is that we make a CUDA call as soon as deepspeed is imported, which puts developers in a very tough situation when something like OpenRLHF/OpenRLHF#524 (comment) happens again. So, personally, I do hope this particular case can get fixed on the DeepSpeed side (don't make any CUDA calls when no CUDA device is available).
Yes, that's true and unfortunate, but the issue can be mitigated when NVML works, so it would still be great to do something on the DeepSpeed side.
Actually, you are correct: the intention of this logic is device identification rather than initialization. I missed/forgot this nuance when I first reviewed. However, it seems unclear whether CUDA device identification can be generally performed without fork poisoning, since …
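For reference, NVML itself can report device presence without any CUDA Runtime call; a minimal sketch using the pynvml bindings (purely illustrative and assuming the nvidia-ml-py package is installed; this PR does not add such a dependency):

```python
import pynvml  # from the nvidia-ml-py package (assumed available here)


def nvml_device_count() -> int:
    # Query the driver through NVML. This never calls cuInit, so it does not
    # poison later forks the way a CUDA Runtime API call would.
    try:
        pynvml.nvmlInit()
    except pynvml.NVMLError:
        return 0  # no driver / NVML not available
    try:
        return pynvml.nvmlDeviceGetCount()
    finally:
        pynvml.nvmlShutdown()
```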
Will fix the formatting issue now. This failure doesn't seem to be related to this PR (permission denied error): https://github.com/microsoft/DeepSpeed/actions/runs/12240278508/job/34142717128
Call `torch.cuda.device_count() > 0` before `torch.cuda.is_available()` to give priority to NVML-based availability, so that we avoid poisoning the process with CUDA calls as soon as we execute `import deepspeed`.

https://github.com/pytorch/pytorch/blob/v2.5.1/torch/cuda/__init__.py#L120-L124

There are two reasons to make this change.

Firstly, the CUDA runtime initializes when the first CUDA API call is made and caches the device list, so if we accidentally import deepspeed, changing `CUDA_VISIBLE_DEVICES` within the same process after initialization won't have any effect on the visible devices. The specific case: OpenRLHF/OpenRLHF#524 (comment)

A demo for reproduction before the fix is applied:

```python
import torch
import os

os.environ["CUDA_VISIBLE_DEVICES"] = ""  # hide all GPUs before the import
import deepspeed  # importing deepspeed makes a CUDA call, caching the (empty) device list
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"  # has no effect after initialization
torch.cuda.set_device('cuda:0')  # fails: the runtime still sees no devices
```

Secondly, from https://pytorch.org/docs/stable/notes/cuda.html:

> When assessing the availability of CUDA in a given environment (`is_available()`), PyTorch's default behavior is to call the CUDA Runtime API method `cudaGetDeviceCount`. Because this call in turn initializes the CUDA Driver API (via `cuInit`) if it is not already initialized, subsequent forks of a process that has run `is_available()` will fail with a CUDA initialization error.

Signed-off-by: Hollow Man <hollowman@opensuse.org>
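To make the second reason concrete, here is a minimal repro sketch of the fork-poisoning behavior described in the quoted note (not part of the PR; it assumes a machine with at least one visible CUDA device, and the exact error text depends on the PyTorch/CUDA versions):

```python
import torch
import torch.multiprocessing as mp


def use_cuda_in_child():
    # The parent already initialized the CUDA driver via is_available(), so
    # CUDA initialization in a forked child is expected to fail with a CUDA
    # initialization error (or a "Cannot re-initialize CUDA in forked
    # subprocess" RuntimeError, depending on the code path).
    torch.zeros(1, device="cuda")


if __name__ == "__main__":
    torch.cuda.is_available()  # default path: cudaGetDeviceCount -> cuInit
    child = mp.get_context("fork").Process(target=use_cuda_in_child)
    child.start()
    child.join()
    print("child exit code:", child.exitcode)  # non-zero once the fork is poisoned
```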