
[Core][Distributed] add coordinator to reduce code duplication in tp and pp #5293

Merged: 43 commits, merged Jun 13, 2024
Changes shown from 1 commit

Commits (43):
e0101c7
add GroupCoordinator
youkaichao Jun 5, 2024
bc09989
declare fields
youkaichao Jun 5, 2024
bece27b
add tp group
youkaichao Jun 5, 2024
446a952
add allgather
youkaichao Jun 5, 2024
292f850
add gather
youkaichao Jun 5, 2024
4d4f53d
move broadcast in coordinator
youkaichao Jun 5, 2024
390fd0f
move graph_capture into parallel state
youkaichao Jun 5, 2024
9caa4bc
move broadcast_tensor_dict into coordinator
youkaichao Jun 5, 2024
0137663
add first rank and last rank in coordinator
youkaichao Jun 5, 2024
6043382
add next rank and prev rank in coordinator
youkaichao Jun 5, 2024
838773b
remove _PP_GLOBAL_RANKS
youkaichao Jun 5, 2024
3cd1926
add barrier in coordinator
youkaichao Jun 5, 2024
e284955
remove get_cpu_world_group
youkaichao Jun 5, 2024
1ea37d5
clean up
youkaichao Jun 5, 2024
26fff33
remove unused functions
youkaichao Jun 5, 2024
06a94a9
keep writing
youkaichao Jun 5, 2024
38121c8
add _WORLD
youkaichao Jun 5, 2024
617a030
remove _LOCAL_RANK
youkaichao Jun 5, 2024
1a1ff60
add destroy_distributed_environment
youkaichao Jun 5, 2024
b1d4973
remove get local rank
youkaichao Jun 5, 2024
30533a0
enforce device arg
youkaichao Jun 5, 2024
846365e
add comments
youkaichao Jun 5, 2024
2a89158
fix import
youkaichao Jun 5, 2024
15a7f15
hf runner does not need cleanup
youkaichao Jun 5, 2024
cc2c446
add check for torch destroy
youkaichao Jun 5, 2024
dff75cc
fix yield
youkaichao Jun 6, 2024
bc215e4
fix tests
youkaichao Jun 6, 2024
666ad18
fix prefix caching test (only one LLM can exist)
youkaichao Jun 6, 2024
07128a1
allow multiple call of init_distributed_environment
youkaichao Jun 6, 2024
78bf7c3
revert prefix caching test
youkaichao Jun 6, 2024
51409a4
fix arg of broadcast_tensor_dict
youkaichao Jun 6, 2024
3e8ba06
allow calling broadcast_tensor_dict without initializing distributed
youkaichao Jun 6, 2024
36d48d1
remove world warmup
youkaichao Jun 6, 2024
65aa2f9
use coordinator check
youkaichao Jun 7, 2024
9857942
Merge branch 'main' into coordinator_impl
youkaichao Jun 8, 2024
9621f60
Merge branch 'main' into coordinator_impl
youkaichao Jun 10, 2024
cd6d652
fix documentation build error
youkaichao Jun 10, 2024
9b02ec0
Merge branch 'main' into coordinator_impl
youkaichao Jun 12, 2024
c7a8614
add
youkaichao Jun 12, 2024
e22b656
add back cleanup
youkaichao Jun 12, 2024
a976d59
add comment
youkaichao Jun 12, 2024
ef78f38
use get_tp_group get_pp_group get_world_group
youkaichao Jun 12, 2024
2afc1f1
add comments
youkaichao Jun 12, 2024
remove get_cpu_world_group
youkaichao committed Jun 5, 2024
commit e284955ca4498a3f6f8c2e3b91a70e9eb8b9c111
8 changes: 4 additions & 4 deletions tests/distributed/test_pynccl.py
@@ -10,7 +10,7 @@
 from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
 from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
 from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
-                                             graph_capture,
+                                             get_world, graph_capture,
                                              init_distributed_environment)
 from vllm.utils import update_environment_variables

@@ -54,7 +54,7 @@ def wrapped_fn(env):

 @worker_fn_wrapper
 def worker_fn():
-    pynccl_comm = PyNcclCommunicator()
+    pynccl_comm = PyNcclCommunicator(get_world().cpu_group)
     tensor = torch.ones(16, 1024, 1024,
                         dtype=torch.float32).cuda(pynccl_comm.rank)
     with pynccl_comm.change_state(enable=True):
@@ -130,7 +130,7 @@ def test_pynccl_multiple_allreduce_with_vllm():
 def worker_fn_with_cudagraph():
     with torch.no_grad():
         graph = torch.cuda.CUDAGraph()
-        pynccl_comm = PyNcclCommunicator()
+        pynccl_comm = PyNcclCommunicator(get_world().cpu_group)
         # run something in the default stream to initialize torch engine
         a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
         torch.cuda.synchronize()
@@ -155,7 +155,7 @@ def test_pynccl_with_cudagraph():

 @worker_fn_wrapper
 def send_recv_worker_fn():
-    pynccl_comm = PyNcclCommunicator()
+    pynccl_comm = PyNcclCommunicator(get_world().cpu_group)
     if pynccl_comm.rank == 0:
         tensor = torch.ones(16, 1024, 1024,
                             dtype=torch.float32).cuda(pynccl_comm.rank)
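
The test changes above show the new calling convention: PyNcclCommunicator no longer falls back to a global CPU group and must be handed one explicitly. A minimal sketch, assuming the distributed environment is already initialized inside a worker process (get_world() is the accessor at this commit; later commits in this PR rename the accessors to get_world_group()/get_tp_group()/get_pp_group()):

# Minimal usage sketch; assumes init_distributed_environment has run.
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.parallel_state import get_world

# The group argument is now required; the communicator attaches to the
# world coordinator's CPU (gloo) group instead of an implicit global one.
pynccl_comm = PyNcclCommunicator(get_world().cpu_group)
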
5 changes: 2 additions & 3 deletions vllm/distributed/device_communicators/pynccl.py
@@ -9,7 +9,7 @@
 from vllm.distributed.device_communicators.pynccl_wrapper import (
     NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum,
     ncclRedOpTypeEnum, ncclUniqueId)
-from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank
+from vllm.distributed.parallel_state import get_local_rank
 from vllm.logger import init_logger

 logger = init_logger(__name__)
@@ -19,7 +19,7 @@ class PyNcclCommunicator:

     def __init__(
         self,
-        group: Optional[ProcessGroup] = None,
+        group: ProcessGroup,
         device: Optional[Union[int, str, torch.device]] = None,
         library_path: Optional[str] = None,
     ):
@@ -35,7 +35,6 @@ def __init__(
         is bind to a unique device.
         """
         assert dist.is_initialized()
-        group = get_cpu_world_group() if group is None else group
         assert dist.get_backend(group) != dist.Backend.NCCL, (
             "PyNcclCommunicator should be attached to a non-NCCL group.")
         self.group = group
6 changes: 0 additions & 6 deletions vllm/distributed/parallel_state.py
@@ -678,12 +678,6 @@ def model_parallel_is_initialized():
     return (_TP is not None and _PP_DEVICE_GROUP is not None)


-def get_cpu_world_group():
-    """Get the CPU world group."""
-    assert _CPU_WORLD_GROUP is not None, ("CPU world group is not initialized")
-    return _CPU_WORLD_GROUP
-
-
 def get_pipeline_model_parallel_group():
     """Get the pipeline model parallel group the caller rank belongs to."""
     assert _PP_DEVICE_GROUP is not None, (
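
For any other callers of the removed helper, the migration is mechanical: ask the world coordinator for its CPU group. A short before/after sketch, under the same naming caveat as above:

# Before this commit:
#   from vllm.distributed.parallel_state import get_cpu_world_group
#   cpu_group = get_cpu_world_group()

# After this commit, the CPU group lives on the world coordinator:
from vllm.distributed.parallel_state import get_world

cpu_group = get_world().cpu_group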