-
-
Notifications
You must be signed in to change notification settings - Fork 11.8k
[distributed] add function to create ipc buffers directly #10064
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,59 @@ | ||
| # can only run on machines with p2p access across GPUs | ||
| # can only run with torchrun: | ||
| # torchrun --nproc_per_node=2 tests/distributed/test_ca_buffer_sharing.py | ||
|
|
||
| import ctypes | ||
|
|
||
| import torch | ||
| import torch.distributed as dist | ||
|
|
||
| from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary | ||
| from vllm.distributed.device_communicators.custom_all_reduce import ( # noqa | ||
| CustomAllreduce) | ||
|
|
||
| # create a cpu process group for communicating metadata (ipc handle) | ||
| dist.init_process_group(backend="gloo") | ||
| rank = local_rank = dist.get_rank() | ||
| world_size = dist.get_world_size() | ||
|
|
||
| # every process sets its own device (differently) | ||
| lib = CudaRTLibrary() | ||
| lib.cudaSetDevice(rank) | ||
|
|
||
| buffer_size_in_bytes = 1024 | ||
| byte_value = 2 # the value we write to the buffer for verification | ||
|
|
||
| pointers = CustomAllreduce.create_shared_buffer(buffer_size_in_bytes) | ||
|
|
||
| print(f"Rank {rank} has pointers {pointers}") | ||
|
|
||
| dist.barrier() | ||
| torch.cuda.synchronize() | ||
|
|
||
| if rank == 0: | ||
| # the first rank tries to write to all buffers | ||
| for p in pointers: | ||
| pointer = ctypes.c_void_p(p) | ||
| lib.cudaMemset(pointer, byte_value, buffer_size_in_bytes) | ||
|
|
||
| dist.barrier() | ||
| torch.cuda.synchronize() | ||
|
|
||
| host_data = (ctypes.c_char * buffer_size_in_bytes)() | ||
|
|
||
| # all ranks read from all buffers, and check if the data is correct | ||
| for p in pointers: | ||
| pointer = ctypes.c_void_p(p) | ||
| lib.cudaMemcpy(host_data, pointer, buffer_size_in_bytes) | ||
| for i in range(buffer_size_in_bytes): | ||
| assert ord(host_data[i]) == byte_value, ( | ||
| f"Rank {rank} failed" | ||
| f" to verify buffer {p}. Expected {byte_value}, " | ||
| f"got {ord(host_data[i])}") | ||
|
|
||
| print(f"Rank {rank} verified all buffers") | ||
|
|
||
| dist.barrier() | ||
| torch.cuda.synchronize() | ||
|
|
||
| CustomAllreduce.free_shared_buffer(pointers) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need to use broadcast with
device=cpuin_gather_ipc_metabut not here?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
when you use this function, the
groupargument should becpu_grouppassed to custom allreduce object.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
see
vllm/vllm/distributed/parallel_state.py
Lines 231 to 236 in 4089985
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is
all_gatherfine here but not in_gather_ipc_meta?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh, that is because we met some issues with
all_gatherfor tensors directly. here we are usingall_gather_object, so it should be fine. see pytorch/pytorch#126032 for the pytorch issue.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see. Someone refactored this to return a Tensor
https://github.com/vllm-project/vllm/pull/5047/files#diff-44d9d733ee604800cbce9858a9201db1044aeff2c940fa4a0521d0c9b6541b3eL137
A better way should be returning a string, if torch bindings doesn't support int8 directly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah string should be fine. #5047 aims to get rid of pybind11 so that we can release python version agnostic wheels.