Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core][distributed] fix custom allreduce in pytorch 2.5 #9815

Merged
merged 4 commits into from
Oct 30, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion vllm/distributed/device_communicators/custom_all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,20 @@ def capture(self):

def _get_ipc_meta(self, inp: torch.Tensor):
data = inp.untyped_storage()._share_cuda_()
handle = data[1]
# https://github.com/pytorch/pytorch/pull/130890 changes
# the binary format of the ipc handle
# it starts from pytorch 2.5
if len(handle) > 64:
assert len(handle) == 66
# only support SHAREABLE_HANDLE_VERSION = 1
assert int(handle[0]) == 1
# only support SHAREABLE_CUDA_MALLOC = 'c'
assert handle[1] == ord("c")
handle = handle[2:]
# TODO: support expandable segment
shard_data = (
data[1], # ipc handle to base ptr
handle, # ipc handle to base ptr
data[3], # offset of base ptr
)
return self._gather_ipc_meta(shard_data)
Expand Down