[Bug]: [help wanted] MoE + TP + custom allreduce bug #9774

Closed
youkaichao opened this issue Oct 28, 2024 · 11 comments · Fixed by #9815

Labels
bug Something isn't working

Comments

@youkaichao (Member)

Your current environment

main branch, H100

Model Input Dumps

No response

🐛 Describe the bug

When I try this simple command vllm serve allenai/OLMoE-1B-7B-0924 -tp 2 on the main branch, it hits an error:

Failed: Cuda error /workspace/csrc/custom_all_reduce.cuh:336 'invalid argument'

Removing -tp 2 works, and disabling custom allreduce also works.

Not sure what's happening. If anyone is familiar with the MoE code, please help.
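
(For reference, not part of the original report: a minimal sketch of the same workaround through the offline Python API, assuming the disable_custom_all_reduce engine argument, which is the Python-side counterpart of the --disable-custom-all-reduce flag mentioned below.)

# Hypothetical sketch: keep TP=2 but fall back to the NCCL all-reduce path
# by disabling vLLM's custom all-reduce kernel.
from vllm import LLM

llm = LLM(
    model="allenai/OLMoE-1B-7B-0924",
    tensor_parallel_size=2,
    disable_custom_all_reduce=True,
)
print(llm.generate("Hello, my name is")[0].outputs[0].text)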

@cedonley commented Oct 28, 2024

I'm seeing this with standard Qwen2.5-72b-Instruct-AWQ with TP=2 on 2x A6000.

So this is likely not related to the MoE aspect. Adding the --disable-custom-all-reduce flag allows the server to come up and answer requests as normal, with slightly degraded performance.

This seems to be a new issue, as neither v0.6.3.post1 nor main from last week has this issue.

I can't yet confirm why, but it appears that the upgrade to PyTorch 2.5 in #9588 may be the culprit. Moving to a build just before that commit eliminates the issue, and none of the other commits since then look like they would have affected this.

@youkaichao (Member, Author)

@cedonley thanks for the great information! Can you try to confirm which commit leads to this bug? We have wheels for every commit on the main branch that are ready to install directly; see https://docs.vllm.ai/en/latest/getting_started/installation.html#install-the-latest-code .

@cedonley commented Oct 28, 2024

@youkaichao Sure. I was just able to confirm the following results:

commit# 3cb07a3 (#9588) - FAILS
commit# 8549c82 (just before) - SUCCESS

The invalid argument error is coming from a call to cudaIpcOpenMemHandle. I did some internet searching to see what might have changed here between CUDA 12.1 and 12.4 (the other factor introduced in #9588) and am coming up empty at the moment.

@youkaichao (Member, Author)

It looks like expandable_segments is False by default.

We need to figure out the memory allocation strategy in PyTorch 2.5.

If you have time, you can also compile it yourself and call CUDA APIs to get the pointer's attributes; that should tell you more. One way to do this from Python is sketched below.
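
(For illustration, not from the thread: a sketch of querying a pointer's attributes from Python without recompiling, via ctypes and the CUDA runtime. It assumes libcudart.so.12 is loadable and that the CUDA 11+/12 layout of cudaPointerAttributes applies.)

# Sketch: query CUDA pointer attributes for a tensor's data pointer via ctypes.
# Assumes libcudart.so.12 and the CUDA 11+/12 cudaPointerAttributes layout:
#   enum cudaMemoryType type; int device; void *devicePointer; void *hostPointer;
import ctypes
import torch

class CudaPointerAttributes(ctypes.Structure):
    _fields_ = [("type", ctypes.c_int),
                ("device", ctypes.c_int),
                ("devicePointer", ctypes.c_void_p),
                ("hostPointer", ctypes.c_void_p)]

MEMORY_TYPES = {0: "unregistered", 1: "host", 2: "device", 3: "managed"}

def pointer_attributes(ptr: int) -> str:
    cudart = ctypes.CDLL("libcudart.so.12")  # assumption: CUDA 12 runtime on the path
    attrs = CudaPointerAttributes()
    err = cudart.cudaPointerGetAttributes(ctypes.byref(attrs), ctypes.c_void_p(ptr))
    if err != 0:
        return f"cudaPointerGetAttributes failed with error {err}"
    return f"type={MEMORY_TYPES.get(attrs.type, attrs.type)}, device={attrs.device}"

meta = torch.zeros(1024, dtype=torch.uint8, device="cuda")
print(pointer_attributes(meta.data_ptr()))  # expected: type=device, device=0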

@cedonley

Yes, just to be safe I had tried setting it to False with PYTORCH_CUDA_ALLOC_CONF="expandable_segments:False", and it gives the same result. I'll try to compile later this evening and see if I can figure out why we're getting the error.

@cedonley

I just added a few lines of code to debug. When I look at the pointer attributes for the ipc_handle that is passed into open_ipc_handle, it shows as unregistered. I'm not a CUDA expert, but I would have expected it to come in as a device handle with an address on the first device.

@cedonley

I did a little more analysis last night, and as far as I can tell the flow should be working. Here's what I see:

custom_all_reduce.py

  • allocates the IPC metadata using PyTorch storage (and I see it come back)
  • handles/offsets look fine, rank/world size is correct
  • it then passes this all as parameters to init_custom_ar
  • the error occurs before it registers the buffers and returns

custom_all_reduce.cu

  • init_custom_ar - sees the correct parameters, makes a copy of the handles allocated by PyTorch on the Python side, and then passes them into the CustomAllreduce constructor defined in custom_all_reduce.cuh

custom_all_reduce.cuh

  • CustomAllreduce - Immediately tries to open the handles, which results in the error we see

I'm not finding any changes in recent PyTorch releases related to how storage is allocated. The metadata that is passed in is allocated on the Python side with essentially (simplified):

meta = torch.zeros(ops.meta_size() + max_size,
                   dtype=torch.uint8,
                   device=self.device)
data = meta.untyped_storage()._share_cuda_()
shard_data = (
    data[1],  # ipc handle to base ptr
    data[3],  # offset of base ptr
)
handles, offsets = self._gather_ipc_meta(shard_data)

I see both the handles and offsets come back and don't see any issues with them.

I tried changing the way processes were spawned, etc., to see if perhaps a change to multiprocessing created a side effect, but I'm stumped. Someone with stronger CUDA debugging skills is going to need to have a look at this.

@youkaichao (Member, Author)

Thanks for the investigation! Let me ask some CUDA experts.

@youkaichao (Member, Author)

I think pytorch/pytorch#130888 might be the culprit.

@youkaichao (Member, Author)

Okay, they changed the binary format of the IPC handle in pytorch/pytorch#130890.
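
(For illustration, not from the thread: a sketch of how the format change could be spotted from Python. As far as I understand, the handle returned by untyped_storage()._share_cuda_() used to be the raw cudaIpcMemHandle_t bytes, which are exactly 64 bytes (CUDA_IPC_HANDLE_SIZE); a different length would mean the bytes can no longer be handed to cudaIpcOpenMemHandle as-is, which would produce exactly this kind of invalid argument error.)

# Sketch: check whether the shared handle is still a raw 64-byte cudaIpcMemHandle_t.
# A different length would indicate the new PyTorch 2.5 format from
# pytorch/pytorch#130890 and explain why cudaIpcOpenMemHandle rejects it.
import torch

CUDA_IPC_HANDLE_SIZE = 64  # sizeof(cudaIpcMemHandle_t) from CUDA's driver_types.h

meta = torch.zeros(1024, dtype=torch.uint8, device="cuda")
data = meta.untyped_storage()._share_cuda_()
handle = data[1]  # ipc handle to base ptr, as in the snippet above
if len(handle) == CUDA_IPC_HANDLE_SIZE:
    print("raw cudaIpcMemHandle_t (pre-2.5 format)")
else:
    print(f"{len(handle)} bytes: not a raw cudaIpcMemHandle_t")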
