[Core][Distributed] add coordinator to reduce code duplication in tp and pp #5293

Merged: 43 commits into main from coordinator_impl, Jun 13, 2024.
This view shows changes from 37 of the 43 commits.

Commits
All 43 commits are authored by youkaichao.

e0101c7  add GroupCoordinator (Jun 5, 2024)
bc09989  declare fields (Jun 5, 2024)
bece27b  add tp group (Jun 5, 2024)
446a952  add allgather (Jun 5, 2024)
292f850  add gather (Jun 5, 2024)
4d4f53d  move broadcast in coordinator (Jun 5, 2024)
390fd0f  move graph_capture into parallel state (Jun 5, 2024)
9caa4bc  move broadcast_tensor_dict into coordinator (Jun 5, 2024)
0137663  add first rank and last rank in coordinator (Jun 5, 2024)
6043382  add next rank and prev rank in coordinator (Jun 5, 2024)
838773b  remove _PP_GLOBAL_RANKS (Jun 5, 2024)
3cd1926  add barrier in coordinator (Jun 5, 2024)
e284955  remove get_cpu_world_group (Jun 5, 2024)
1ea37d5  clean up (Jun 5, 2024)
26fff33  remove unused functions (Jun 5, 2024)
06a94a9  keep writing (Jun 5, 2024)
38121c8  add _WORLD (Jun 5, 2024)
617a030  remove _LOCAL_RANK (Jun 5, 2024)
1a1ff60  add destroy_distributed_environment (Jun 5, 2024)
b1d4973  remove get local rank (Jun 5, 2024)
30533a0  enforce device arg (Jun 5, 2024)
846365e  add comments (Jun 5, 2024)
2a89158  fix import (Jun 5, 2024)
15a7f15  hf runner does not need cleanup (Jun 5, 2024)
cc2c446  add check for torch destroy (Jun 5, 2024)
dff75cc  fix yield (Jun 6, 2024)
bc215e4  fix tests (Jun 6, 2024)
666ad18  fix prefix caching test (only one LLM can exist) (Jun 6, 2024)
07128a1  allow multiple call of init_distributed_environment (Jun 6, 2024)
78bf7c3  revert prefix caching test (Jun 6, 2024)
51409a4  fix arg of broadcast_tensor_dict (Jun 6, 2024)
3e8ba06  allow calling broadcast_tensor_dict without initializing distributed (Jun 6, 2024)
36d48d1  remove world warmup (Jun 6, 2024)
65aa2f9  use coordinator check (Jun 7, 2024)
9857942  Merge branch 'main' into coordinator_impl (Jun 8, 2024)
9621f60  Merge branch 'main' into coordinator_impl (Jun 10, 2024)
cd6d652  fix documentation build error (Jun 10, 2024)
9b02ec0  Merge branch 'main' into coordinator_impl (Jun 12, 2024)
c7a8614  add (Jun 12, 2024)
e22b656  add back cleanup (Jun 12, 2024)
a976d59  add comment (Jun 12, 2024)
ef78f38  use get_tp_group get_pp_group get_world_group (Jun 12, 2024)
2afc1f1  add comments (Jun 12, 2024)
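
Taken together, these commits replace the separate TP and PP helper functions with a single GroupCoordinator object per process group (world, tensor-parallel, pipeline-parallel) that owns the group's ranks, device, CPU and device process groups, optional accelerated communicators, and the collective helpers. The following is a minimal illustrative sketch reconstructed from the commit messages and the test diffs below; it is not the actual vLLM implementation, and any name not visible in the diffs (device_group, the constructor signature, the exact collective methods) is an assumption.

from typing import List

import torch
import torch.distributed


class GroupCoordinator:
    """One object per process group (world / TP / PP) that owns the group's
    ranks, device, and communication helpers. Sketch only."""

    def __init__(self, ranks: List[int], local_rank: int, backend: str):
        self.ranks = ranks
        self.local_rank = local_rank
        self.device = torch.device(f"cuda:{local_rank}")
        # one CPU (gloo) group and one device group over the same ranks
        self.cpu_group = torch.distributed.new_group(ranks, backend="gloo")
        self.device_group = torch.distributed.new_group(ranks, backend=backend)
        self.rank = torch.distributed.get_rank()
        self.rank_in_group = ranks.index(self.rank)
        self.world_size = len(ranks)
        # optional accelerated communicators, attached later (assumption)
        self.pynccl_comm = None
        self.ca_comm = None

    # convenience ranks used by pipeline parallelism
    @property
    def first_rank(self) -> int:
        return self.ranks[0]

    @property
    def last_rank(self) -> int:
        return self.ranks[-1]

    @property
    def next_rank(self) -> int:
        return self.ranks[(self.rank_in_group + 1) % self.world_size]

    @property
    def prev_rank(self) -> int:
        return self.ranks[(self.rank_in_group - 1) % self.world_size]

    def all_reduce(self, tensor: torch.Tensor) -> torch.Tensor:
        torch.distributed.all_reduce(tensor, group=self.device_group)
        return tensor

    def broadcast(self, tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
        # src is the rank inside the group; translate to the global rank
        torch.distributed.broadcast(tensor, src=self.ranks[src],
                                    group=self.device_group)
        return tensor

    def barrier(self) -> None:
        # use the CPU group so the barrier works before CUDA is initialized
        torch.distributed.barrier(group=self.cpu_group)

With this abstraction, the TP and PP groups are just two instances of the same class, which is why the tests below switch from dedicated accessors such as get_tp_ca_communicator() to attributes on the group object returned by get_tp() / get_world() (renamed to get_tp_group() / get_pp_group() / get_world_group() in commit ef78f38).
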
5 changes: 3 additions & 2 deletions tests/conftest.py
@@ -15,7 +15,8 @@

 from vllm import LLM, SamplingParams
 from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
-from vllm.distributed import destroy_model_parallel
+from vllm.distributed import (destroy_distributed_environment,
+                              destroy_model_parallel)
 from vllm.inputs import TextPrompt
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalData
@@ -54,6 +55,7 @@ def _read_prompts(filename: str) -> List[str]:

 def cleanup():
     destroy_model_parallel()
+    destroy_distributed_environment()
     with contextlib.suppress(AssertionError):
         torch.distributed.destroy_process_group()
     gc.collect()
@@ -359,7 +361,6 @@ def __enter__(self):

     def __exit__(self, exc_type, exc_value, traceback):
         del self.model
-        cleanup()


@pytest.fixture
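
With destroy_distributed_environment() added to cleanup(), a test process can tear one engine down completely before building another. A hedged usage sketch, assuming the cleanup() helper above and a placeholder model name (not taken from this PR):

from vllm import LLM

llm = LLM(model="facebook/opt-125m")      # placeholder model
outputs = llm.generate(["Hello, my name is"])
del llm
cleanup()  # model parallel groups, then the distributed env, then the process group

# a second engine can now be created in the same process
llm = LLM(model="facebook/opt-125m")
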
6 changes: 3 additions & 3 deletions tests/distributed/test_custom_all_reduce.py
@@ -7,9 +7,9 @@
import torch.distributed as dist

 from vllm.distributed.communication_op import ( # noqa
-    graph_capture, tensor_model_parallel_all_reduce)
+    tensor_model_parallel_all_reduce)
 from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
-                                             get_tp_ca_communicator)
+                                             get_tp, graph_capture)

from ..utils import (init_test_distributed_environment,
multi_process_tensor_parallel)
@@ -91,7 +91,7 @@ def eager_allreduce(tp_size, pp_size, rank, distributed_init_port):
     # communicate independently
     num_communication = rank // tp_size + 1
     sz = 1024
-    fa = get_tp_ca_communicator()
+    fa = get_tp().ca_comm
     inp = torch.ones(sz, dtype=torch.float32, device=device)
     out = inp
     for _ in range(num_communication):
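
The import change above also reflects commit 390fd0f, which moves graph_capture from communication_op into parallel_state so it can use the group coordinators directly. A hedged sketch of capturing a tensor-parallel all-reduce into a CUDA graph under the new import path; the .stream attribute on the capture context is an assumption about its interface, and the worker setup is omitted:

import torch

from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import graph_capture

# assumes init_distributed_environment(...) and the model-parallel groups are set up
inp = torch.ones(1024, dtype=torch.float32, device="cuda")
graph = torch.cuda.CUDAGraph()
with graph_capture() as capture_context:
    # capture the all-reduce on the coordinator-managed capture stream
    with torch.cuda.graph(graph, stream=capture_context.stream):
        out = tensor_model_parallel_all_reduce(inp)
graph.replay()
torch.cuda.synchronize()
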
12 changes: 8 additions & 4 deletions tests/distributed/test_pynccl.py
@@ -6,10 +6,11 @@
import torch.distributed

 from vllm.distributed.communication_op import ( # noqa
-    graph_capture, tensor_model_parallel_all_reduce)
+    tensor_model_parallel_all_reduce)
 from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
 from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
 from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
+                                             get_world, graph_capture,
                                              init_distributed_environment)
from vllm.utils import update_environment_variables

@@ -53,7 +54,8 @@ def wrapped_fn(env):

 @worker_fn_wrapper
 def worker_fn():
-    pynccl_comm = PyNcclCommunicator()
+    pynccl_comm = PyNcclCommunicator(get_world().cpu_group,
+                                     device=get_world().device)
     tensor = torch.ones(16, 1024, 1024,
                         dtype=torch.float32).cuda(pynccl_comm.rank)
     with pynccl_comm.change_state(enable=True):
@@ -129,7 +131,8 @@ def test_pynccl_multiple_allreduce_with_vllm():
 def worker_fn_with_cudagraph():
     with torch.no_grad():
         graph = torch.cuda.CUDAGraph()
-        pynccl_comm = PyNcclCommunicator()
+        pynccl_comm = PyNcclCommunicator(get_world().cpu_group,
+                                         device=get_world().device)
         # run something in the default stream to initialize torch engine
         a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
         torch.cuda.synchronize()
@@ -154,7 +157,8 @@ def test_pynccl_with_cudagraph():

 @worker_fn_wrapper
 def send_recv_worker_fn():
-    pynccl_comm = PyNcclCommunicator()
+    pynccl_comm = PyNcclCommunicator(get_world().cpu_group,
+                                     device=get_world().device)
     if pynccl_comm.rank == 0:
         tensor = torch.ones(16, 1024, 1024,
                             dtype=torch.float32).cuda(pynccl_comm.rank)
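
These hunks show the effect of commit 30533a0 ("enforce device arg"): PyNcclCommunicator no longer infers its group and device but is constructed from the world coordinator's cpu_group and device. A hedged sketch of the resulting pattern inside an initialized worker; the in-place all_reduce call mirrors how the communicator is exercised in this test file, but treat its exact signature as an assumption:

import torch

from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.parallel_state import get_world

world = get_world()  # renamed get_world_group() in commit ef78f38
pynccl_comm = PyNcclCommunicator(world.cpu_group, device=world.device)

tensor = torch.ones(1024, dtype=torch.float32, device=world.device)
with pynccl_comm.change_state(enable=True):
    pynccl_comm.all_reduce(tensor)  # assumed in-place all-reduce over the world group
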
23 changes: 13 additions & 10 deletions tests/lora/conftest.py
@@ -12,7 +12,10 @@

 import vllm
 from vllm.config import LoRAConfig
-from vllm.distributed import destroy_model_parallel, initialize_model_parallel
+from vllm.distributed import (destroy_distributed_environment,
+                              destroy_model_parallel,
+                              init_distributed_environment,
+                              initialize_model_parallel)
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
                                                RowParallelLinear)
@@ -35,6 +38,7 @@

 def cleanup():
     destroy_model_parallel()
+    destroy_distributed_environment()
     with contextlib.suppress(AssertionError):
         torch.distributed.destroy_process_group()
     gc.collect()
@@ -64,15 +68,14 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool):

 @pytest.fixture
 def dist_init():
-    if not torch.distributed.is_initialized():
-        temp_file = tempfile.mkstemp()[1]
-        torch.distributed.init_process_group(
-            backend="nccl",
-            world_size=1,
-            rank=0,
-            init_method=f"file://{temp_file}",
-        )
-        torch.distributed.all_reduce(torch.zeros(1).cuda())
+    temp_file = tempfile.mkstemp()[1]
+    init_distributed_environment(
+        world_size=1,
+        rank=0,
+        distributed_init_method=f"file://{temp_file}",
+        local_rank=0,
+        backend="nccl",
+    )
     initialize_model_parallel(1, 1)
     yield
     cleanup()
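
The rewritten dist_init fixture shows the single-GPU lifecycle that the coordinator refactor standardizes: init_distributed_environment, then initialize_model_parallel, with the mirrored teardown in cleanup(). A condensed sketch of that lifecycle outside pytest, using only the names imported in the diff above:

import tempfile

from vllm.distributed import (destroy_distributed_environment,
                              destroy_model_parallel,
                              init_distributed_environment,
                              initialize_model_parallel)

temp_file = tempfile.mkstemp()[1]
init_distributed_environment(world_size=1,
                             rank=0,
                             distributed_init_method=f"file://{temp_file}",
                             local_rank=0,
                             backend="nccl")
initialize_model_parallel(1, 1)

# ... exercise code that needs the TP/PP groups ...

destroy_model_parallel()
destroy_distributed_environment()
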
4 changes: 3 additions & 1 deletion tests/worker/test_model_runner.py
@@ -1,7 +1,8 @@
 import pytest
 import torch

-from vllm.distributed.parallel_state import init_distributed_environment
+from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
+                                             init_distributed_environment)
 from vllm.engine.arg_utils import EngineArgs
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
@@ -292,6 +293,7 @@ def distributed_init():
         rank=0,
         distributed_init_method=f"tcp://127.0.0.1:{get_open_port()}",
         local_rank=0)
+    ensure_model_parallel_initialized(1, 1)


@pytest.mark.parametrize("batch_size", list(range(2, 128)))