
[Core][Distributed] add coordinator to reduce code duplication in tp and pp #5293

Merged: 43 commits, merged Jun 13, 2024

Changes from 1 commit
Commits (43)
e0101c7
add GroupCoordinator
youkaichao Jun 5, 2024
bc09989
declare fields
youkaichao Jun 5, 2024
bece27b
add tp group
youkaichao Jun 5, 2024
446a952
add allgather
youkaichao Jun 5, 2024
292f850
add gather
youkaichao Jun 5, 2024
4d4f53d
move broadcast in coordinator
youkaichao Jun 5, 2024
390fd0f
move graph_capture into parallel state
youkaichao Jun 5, 2024
9caa4bc
move broadcast_tensor_dict into coordinator
youkaichao Jun 5, 2024
0137663
add first rank and last rank in coordinator
youkaichao Jun 5, 2024
6043382
add next rank and prev rank in coordinator
youkaichao Jun 5, 2024
838773b
remove _PP_GLOBAL_RANKS
youkaichao Jun 5, 2024
3cd1926
add barrier in coordinator
youkaichao Jun 5, 2024
e284955
remove get_cpu_world_group
youkaichao Jun 5, 2024
1ea37d5
clean up
youkaichao Jun 5, 2024
26fff33
remove unused functions
youkaichao Jun 5, 2024
06a94a9
keep writing
youkaichao Jun 5, 2024
38121c8
add _WORLD
youkaichao Jun 5, 2024
617a030
remove _LOCAL_RANK
youkaichao Jun 5, 2024
1a1ff60
add destroy_distributed_environment
youkaichao Jun 5, 2024
b1d4973
remove get local rank
youkaichao Jun 5, 2024
30533a0
enforce device arg
youkaichao Jun 5, 2024
846365e
add comments
youkaichao Jun 5, 2024
2a89158
fix import
youkaichao Jun 5, 2024
15a7f15
hf runner does not need cleanup
youkaichao Jun 5, 2024
cc2c446
add check for torch destroy
youkaichao Jun 5, 2024
dff75cc
fix yield
youkaichao Jun 6, 2024
bc215e4
fix tests
youkaichao Jun 6, 2024
666ad18
fix prefix caching test (only one LLM can exist)
youkaichao Jun 6, 2024
07128a1
allow multiple call of init_distributed_environment
youkaichao Jun 6, 2024
78bf7c3
revert prefix caching test
youkaichao Jun 6, 2024
51409a4
fix arg of broadcast_tensor_dict
youkaichao Jun 6, 2024
3e8ba06
allow calling broadcast_tensor_dict without initializing distributed
youkaichao Jun 6, 2024
36d48d1
remove world warmup
youkaichao Jun 6, 2024
65aa2f9
use coordinator check
youkaichao Jun 7, 2024
9857942
Merge branch 'main' into coordinator_impl
youkaichao Jun 8, 2024
9621f60
Merge branch 'main' into coordinator_impl
youkaichao Jun 10, 2024
cd6d652
fix documentation build error
youkaichao Jun 10, 2024
9b02ec0
Merge branch 'main' into coordinator_impl
youkaichao Jun 12, 2024
c7a8614
add
youkaichao Jun 12, 2024
e22b656
add back cleanup
youkaichao Jun 12, 2024
a976d59
add comment
youkaichao Jun 12, 2024
ef78f38
use get_tp_group get_pp_group get_world_group
youkaichao Jun 12, 2024
2afc1f1
add comments
youkaichao Jun 12, 2024
add gather
youkaichao committed Jun 5, 2024
commit 292f850e75e2046e7157eb9a5a4119601d102441
36 changes: 3 additions & 33 deletions vllm/distributed/communication_op.py
@@ -5,10 +5,7 @@
import torch
from torch.distributed import ProcessGroup

from .parallel_state import (get_cpu_world_group, get_pp,
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tp)
from .parallel_state import get_cpu_world_group, get_pp, get_tp


@contextmanager
@@ -44,35 +41,8 @@ def tensor_model_parallel_all_gather(input_: torch.Tensor,
def tensor_model_parallel_gather(input_: torch.Tensor,
dst: int = 0,
dim: int = -1) -> torch.Tensor:
"""Gather the input tensor across model parallel group.

NOTE: We assume that the input tensor is on the same device across
all the ranks.
"""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
# Allocate output tensor.
if get_tensor_model_parallel_rank() == dst:
gather_list = [torch.empty_like(input_) for _ in range(world_size)]
else:
gather_list = None
# Gather.
torch.distributed.gather(input_,
gather_list,
dst=dst,
group=get_tensor_model_parallel_group())
if get_tensor_model_parallel_rank() == dst:
output_tensor = torch.cat(gather_list, dim=dim)
else:
output_tensor = None
return output_tensor
"""Gather the input tensor across model parallel group."""
return get_tp().gather(input_, dst, dim)


def broadcast(input_: torch.Tensor,
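For readers skimming the diff above: the public helper keeps its signature and docstring, but the implementation now lives on the TP coordinator object. Below is a minimal, runnable sketch of that delegation pattern; `FakeCoordinator` is an invented single-rank stand-in for vLLM's `GroupCoordinator`, used only so the snippet works without initializing `torch.distributed`.

```python
# Sketch of the wrapper pattern in this commit (not the actual vLLM module).
# The module-level helper keeps its public signature; the real work lives on
# the process-group coordinator object returned by get_tp().
import torch


class FakeCoordinator:
    """Hypothetical stand-in for GroupCoordinator with world_size == 1."""

    world_size = 1
    rank_in_group = 0

    def gather(self, input_: torch.Tensor, dst: int = 0, dim: int = -1):
        # With a single rank there is nothing to gather; return the input.
        return input_


_TP = FakeCoordinator()


def get_tp() -> FakeCoordinator:
    return _TP


def tensor_model_parallel_gather(input_: torch.Tensor,
                                 dst: int = 0,
                                 dim: int = -1) -> torch.Tensor:
    """Gather the input tensor across model parallel group."""
    return get_tp().gather(input_, dst, dim)


print(tensor_model_parallel_gather(torch.arange(4)))  # tensor([0, 1, 2, 3])
```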
47 changes: 38 additions & 9 deletions vllm/distributed/parallel_state.py
@@ -190,6 +190,39 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
input_size[dim + 1:])
return output_tensor

def gather(self,
input_: torch.Tensor,
dst: int = 0,
dim: int = -1) -> torch.Tensor:
"""
NOTE: We assume that the input tensor is on the same device across
all the ranks.
"""
world_size = self.world_size
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
# Allocate output tensor.
if self.rank_in_group == dst:
gather_list = [torch.empty_like(input_) for _ in range(world_size)]
else:
gather_list = None
# Gather.
torch.distributed.gather(input_,
gather_list,
dst=dst,
group=self.device_group)
if self.rank_in_group == dst:
output_tensor = torch.cat(gather_list, dim=dim)
else:
output_tensor = None
return output_tensor

def destroy(self):
if self.device_group is not None:
torch.distributed.destroy_process_group(self.device_group)
@@ -219,6 +252,9 @@ def get_tp() -> GroupCoordinator:
return _TP


# kept for backward compatibility
get_tensor_model_parallel_group = get_tp

_PP: Optional[GroupCoordinator] = None


@@ -447,12 +483,6 @@ def get_cpu_world_group():
return _CPU_WORLD_GROUP


def get_tensor_model_parallel_group():
"""Get the tensor model parallel group the caller rank belongs to."""
assert _TP is not None, ("tensor model parallel group is not initialized")
return _TP


def get_pipeline_model_parallel_group():
"""Get the pipeline model parallel group the caller rank belongs to."""
assert _PP_DEVICE_GROUP is not None, (
@@ -469,8 +499,7 @@ def get_pipeline_model_parallel_cpu_group():

def get_tensor_model_parallel_world_size():
"""Return world size for the tensor model parallel group."""
return torch.distributed.get_world_size(
group=get_tensor_model_parallel_group())
return get_tp().world_size


def get_pipeline_model_parallel_world_size():
@@ -481,7 +510,7 @@ def get_pipeline_model_parallel_world_size():

def get_tensor_model_parallel_rank():
"""Return my rank for the tensor model parallel group."""
return torch.distributed.get_rank(group=get_tensor_model_parallel_group())
return get_tp().rank_in_group


def get_pipeline_model_parallel_rank():
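To make the new `GroupCoordinator.gather` concrete, here is a hedged, self-contained sketch of the same gather-then-concatenate pattern on two CPU processes with the gloo backend. The helper names (`gather_concat`, `worker`) and the master address/port are invented for this demo; the real method additionally passes `group=self.device_group` and uses `rank_in_group` rather than the global rank.

```python
# Demo of the gather-then-concat pattern used by GroupCoordinator.gather,
# run on CPU with the "gloo" backend. Illustrative sketch, not vLLM code.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def gather_concat(input_: torch.Tensor, rank: int, world_size: int,
                  dst: int = 0, dim: int = -1):
    # Bypass the collective if there is only one rank.
    if world_size == 1:
        return input_
    assert -input_.dim() <= dim < input_.dim()
    if dim < 0:
        dim += input_.dim()
    # Only the destination rank allocates the receive buffers.
    gather_list = ([torch.empty_like(input_) for _ in range(world_size)]
                   if rank == dst else None)
    dist.gather(input_, gather_list, dst=dst)
    # The destination rank concatenates; every other rank gets None.
    return torch.cat(gather_list, dim=dim) if rank == dst else None


def worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    out = gather_concat(torch.full((2,), float(rank)), rank, world_size)
    if rank == 0:
        print(out)  # tensor([0., 0., 1., 1.])
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```

As in the method added by this commit, only the destination rank allocates receive buffers and returns the concatenated tensor; all other ranks return None.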