[Bug]: [help wanted] MoE + TP + custom allreduce bug #9774
Comments
I'm seeing this with standard Qwen2.5-72b-Instruct-AWQ with TP=2 on 2x A6000, so this is likely not related to the MoE aspect. Adding the --disable-custom-all-reduce flag allows the server to come up and answer requests normally, with slightly degraded performance. This seems to be a new issue: neither v0.6.3.post1 nor main from last week has it. I can't yet confirm why, but it appears that the upgrade to PyTorch 2.5 in #9588 may be the culprit. Moving to a build just before that commit eliminates the issue, and none of the other commits since then look like they would have affected this.
@cedonley thanks for the great information! Can you try to confirm which commit leads to this bug? We have wheels for main-branch commits that you can install directly, see https://docs.vllm.ai/en/latest/getting_started/installation.html#install-the-latest-code .
@youkaichao Sure. I was just able to confirm the following result: commit 3cb07a3 (#9588) - FAILS. The invalid argument error is coming from a call to cudaIpcOpenMemHandle. I did some internet searching to see what might have changed here between CUDA 12.1 and 12.4 (the other factor introduced in #9588) and am coming up empty at the moment.
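For reference, the failing call is the standard CUDA IPC open path. Below is a minimal sketch of that pattern with illustrative names (not vLLM's actual code); cudaIpcOpenMemHandle reports 'invalid argument' when the bytes it is handed do not form a valid handle for the current device and driver.

```cpp
// Minimal sketch (illustrative names, not vLLM's actual code): open a peer
// rank's CUDA IPC memory handle that arrived as an opaque byte buffer.
#include <cuda_runtime.h>
#include <cstdio>
#include <cstring>

void* open_peer_buffer(const unsigned char* handle_bytes) {
  cudaIpcMemHandle_t handle;
  // The handle is just raw bytes; it must exactly match the binary layout
  // produced by cudaIpcGetMemHandle in the exporting process.
  std::memcpy(&handle, handle_bytes, sizeof(handle));

  void* ptr = nullptr;
  cudaError_t err =
      cudaIpcOpenMemHandle(&ptr, handle, cudaIpcMemLazyEnablePeerAccess);
  if (err != cudaSuccess) {
    // This is where 'invalid argument' shows up when the bytes are not a
    // well-formed handle for this device/driver combination.
    std::fprintf(stderr, "cudaIpcOpenMemHandle failed: %s\n",
                 cudaGetErrorString(err));
    return nullptr;
  }
  return ptr;
}
```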
This might be caused by the new expandable segments feature in PyTorch's CUDA caching allocator.
It looks like we need to figure out the memory allocation strategy in PyTorch 2.5. If you have time, you can also compile vLLM yourself and call CUDA APIs to get the pointer's attributes; that should give you more information.
Yes, just to be safe I had tried setting it to false with PYTORCH_CUDA_ALLOC_CONF="expandable_segments:False", and it produces the same result. I will try to compile later this evening and see if I can figure out why we're getting the error.
I just added a few lines of code to debug. When I look at the pointer attributes for the ipc_handle that is passed into open_ipc_handle, it shows up as unregistered. I'm not a CUDA expert, but I would have expected it to come in as a device handle with an address on the first device.
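A probe along these lines can be written with cudaPointerGetAttributes; the sketch below is illustrative only (not the exact lines that were added) and simply prints what the CUDA runtime knows about a given pointer.

```cpp
// Sketch of a pointer-attribute probe (illustrative only): print what the
// CUDA runtime knows about a given pointer before using it in IPC calls.
#include <cuda_runtime.h>
#include <cstdio>

void describe_pointer(const void* ptr, const char* label) {
  cudaPointerAttributes attr{};
  cudaError_t err = cudaPointerGetAttributes(&attr, ptr);
  if (err != cudaSuccess) {
    std::printf("%s: cudaPointerGetAttributes failed: %s\n", label,
                cudaGetErrorString(err));
    return;
  }
  // attr.type is cudaMemoryTypeUnregistered for plain pageable host memory,
  // cudaMemoryTypeHost for pinned host memory, and cudaMemoryTypeDevice for
  // a real device allocation (attr.device says which GPU it lives on).
  std::printf("%s: type=%d device=%d devicePointer=%p hostPointer=%p\n",
              label, static_cast<int>(attr.type), attr.device,
              attr.devicePointer, attr.hostPointer);
}
```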
I did a little more analysis last night, and it's clear that the flow should be working. Tracing the handle through custom_all_reduce.py, custom_all_reduce.cu, and custom_all_reduce.cuh, it is produced and gathered on the Python side, passed through the C++ bindings, and opened in the .cuh implementation. I'm not finding any changes in recent PyTorch releases related to how storage is allocated. The metadata that is passed in is allocated in Python as, essentially (simplified), a plain CUDA tensor on the local device.
I see both the handles and offsets come back and don't see any issues with them. I tried changing the way processes were spawned and similar tweaks, to see if perhaps a change to multiprocessing created a side effect. I'm stumped; someone with stronger CUDA debugging skills is going to need to have a look at this.
Thanks for the investigation! Let me ask some CUDA experts.
I think pytorch/pytorch#130888 might be the culprit.
Okay, they changed the binary format of the IPC handle in pytorch/pytorch#130890.
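That would explain the failure mode: as observed above, the invalid argument error comes from cudaIpcOpenMemHandle on the received handle bytes, so the receiver is presumably reinterpreting the exported blob as a raw cudaIpcMemHandle_t. The sketch below illustrates that receiving pattern under the assumption that the handle sits at some offset inside the blob (names are illustrative, not the actual PyTorch or vLLM structures).

```cpp
// Sketch (assumed blob layout, not the actual PyTorch/vLLM structures): if
// the exporting side starts packing extra metadata around the raw handle,
// a receiver that reads a cudaIpcMemHandle_t from the wrong offset hands
// garbage to cudaIpcOpenMemHandle and sees 'invalid argument'.
#include <cuda_runtime.h>
#include <cstring>
#include <vector>

void* open_from_blob(const std::vector<unsigned char>& blob,
                     size_t handle_offset) {
  if (blob.size() < handle_offset + sizeof(cudaIpcMemHandle_t)) {
    return nullptr;  // blob too small to contain a handle at this offset
  }
  cudaIpcMemHandle_t handle;
  std::memcpy(&handle, blob.data() + handle_offset, sizeof(handle));

  void* base = nullptr;
  if (cudaIpcOpenMemHandle(&base, handle, cudaIpcMemLazyEnablePeerAccess) !=
      cudaSuccess) {
    return nullptr;  // fails when the bytes at handle_offset are not a handle
  }
  return base;
}
```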
Your current environment
main branch, H100
Model Input Dumps
No response
🐛 Describe the bug
When I try this simple command
vllm serve allenai/OLMoE-1B-7B-0924 -tp 2
on the main branch, it hits an error: Failed: Cuda error /workspace/csrc/custom_all_reduce.cuh:336 'invalid argument'
Removing -tp works, and disabling custom allreduce also works. Not sure what's happening; if anyone is familiar with the MoE code, please help.