Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Distributed] Add send and recv helpers #5719

Merged
merged 14 commits into from
Jun 23, 2024
Prev Previous commit
Next Next commit
Refactor send and recv functions and add new test
Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
  • Loading branch information
andoorve committed Jun 22, 2024
commit 31ce144feedb6239b36cc1c4435ab42aeb3e6016
25 changes: 24 additions & 1 deletion tests/distributed/test_comm_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,28 @@ def send_recv_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
assert torch.allclose(recv_dict["f"], test_dict["f"])


@ray.remote(num_gpus=1, max_calls=1)
def send_recv_test_worker(tp_size: int, pp_size: int, rank: int,
                          distributed_init_port: str):
    """Exercise the pipeline-parallel point-to-point send/recv helpers.

    Every rank except the last pipeline rank sends a fixed tensor to its
    successor; every rank except the first receives one from its
    predecessor and checks it matches the sent payload.

    Args:
        tp_size: tensor-parallel world size for the test environment.
        pp_size: pipeline-parallel world size for the test environment.
        rank: this worker's global rank.
        distributed_init_port: free port used to rendezvous the process group.
    """
    # Ray pins CUDA_VISIBLE_DEVICES per actor; clear it so set_device below
    # can address the GPU by global rank.
    del os.environ["CUDA_VISIBLE_DEVICES"]
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    init_test_distributed_environment(tp_size, pp_size, rank,
                                      distributed_init_port)

    size = 64
    # Use `size` (not a second hard-coded 64) and the device selected above,
    # so the payload length and placement cannot drift apart.
    test_tensor = torch.arange(size, dtype=torch.float32, device=device)

    if not is_pipeline_model_parallel_first_rank():
        recv_tensor = get_pp_group().recv(size, dtype=torch.float32)

    if not is_pipeline_model_parallel_last_rank():
        get_pp_group().send(test_tensor)

    if not is_pipeline_model_parallel_first_rank():
        assert torch.allclose(test_tensor, recv_tensor)

@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("tp_size", [2])
Expand All @@ -160,6 +182,7 @@ def test_multi_process_tensor_parallel(tp_size, test_target):
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("pp_size", [2])
@pytest.mark.parametrize("test_target", [send_recv_tensor_dict_test_worker])
@pytest.mark.parametrize(
"test_target", [send_recv_test_worker, send_recv_tensor_dict_test_worker])
def test_multi_process_pipeline_parallel(pp_size, test_target):
    """Run *test_target* across ``pp_size`` pipeline-parallel worker
    processes (tensor-parallel size fixed at 1).

    ``test_target`` is one of the ray-remote worker functions defined above;
    ``multi_process_parallel`` launches one worker per rank and waits for
    all of them to finish.
    """
    multi_process_parallel(1, pp_size, test_target)
Loading