Skip to content

Unit Test for run_dp_sharded_vision_model #19103

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 6, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 97 additions & 1 deletion tests/multimodal/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,21 @@

import numpy as np
import pytest
import torch
import torch.multiprocessing as mp
from PIL import Image, ImageChops

from tests.utils import multi_gpu_test
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import (init_distributed_environment,
initialize_model_parallel)
from vllm.multimodal.image import convert_image_mode
from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal.utils import (MediaConnector,
merge_and_sort_multimodal_metadata)
merge_and_sort_multimodal_metadata,
run_dp_sharded_vision_model)
from vllm.platforms import current_platform
from vllm.utils import get_open_port, update_environment_variables

if TYPE_CHECKING:
from vllm.multimodal.hasher import MultiModalHashDict
Expand Down Expand Up @@ -399,3 +408,90 @@ def test_merge_and_sort_multimodal_metadata_with_interleaving():
assert modalities == expected_modalities
assert ranges == expected_ranges
assert hashes == expected_hashes


class SimpleLinearModel(torch.nn.Module):
    """Minimal stand-in for a vision model, used only by the tests.

    Each sample is flattened to a vector and passed through one linear
    layer, mapping ``input_dim`` features to ``output_dim`` features.
    """

    def __init__(self, input_dim: int = 3 * 224 * 224, output_dim: int = 32):
        super().__init__()
        # Flatten carries no parameters; all weights live in the projection.
        self.flatten = torch.nn.Flatten()
        self.linear = torch.nn.Linear(in_features=input_dim,
                                      out_features=output_dim)

    def forward(self, x: torch.Tensor):
        # Collapse every non-batch dimension, then project.
        flat = self.flatten(x)
        projected = self.linear(flat)
        return projected


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
    "batch_size",
    [
        1,  # Single image
        4,  # Small batch
        5,  # Odd batch size (for testing padding)
    ],
)
def test_run_dp_sharded_vision_model(batch_size: int):
    """Spawn one worker process per rank and run the comparison in each.

    ``mp.spawn`` supplies the process index as the first positional
    argument of the worker; the remaining arguments come from ``args``.
    """
    num_workers = 2
    # Pick the rendezvous port once in the parent so all workers share it.
    rendezvous_port = get_open_port()
    worker_args = (num_workers, batch_size, rendezvous_port)
    mp.spawn(
        run_dp_sharded_vision_model_vs_direct,
        args=worker_args,
        nprocs=num_workers,
    )


def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int,
                                          batch_size: int, master_port: int):
    """
    Test that run_dp_sharded_vision_model produces the same results as
    calling the model directly.

    Intended to run as a worker under ``mp.spawn``, one process per rank.

    Args:
        local_rank: Index of this worker; also selects the CUDA device.
        world_size: Total number of workers, used as the tensor-parallel size.
        batch_size: Number of random images in the test input.
        master_port: Port used for the distributed rendezvous on localhost.
    """

    # Set random seed for reproducibility
    current_platform.seed_everything(0)

    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    torch.set_default_device(device)

    update_environment_variables({
        'RANK': str(local_rank),
        'LOCAL_RANK': str(local_rank),
        'WORLD_SIZE': str(world_size),
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': str(master_port),
    })

    # initialize distributed
    init_distributed_environment()
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    # Check that the world size is set up correctly *before* exercising the
    # sharded path, so a misconfigured group fails here with a clear message.
    tp_world_size = get_tensor_model_parallel_world_size()
    assert tp_world_size == world_size, (
        f"tensor-parallel world size is {tp_world_size}, "
        f"expected {world_size}")

    # Create a test input tensor
    image_input = torch.randn(batch_size, 3, 224, 224)

    # Create a simple linear model
    vision_model = SimpleLinearModel()

    with torch.inference_mode():
        # Run the model directly on the full input
        direct_output = vision_model(image_input)
        # Run the model through the sharded function
        sharded_output = run_dp_sharded_vision_model(image_input, vision_model)

    # Check that the outputs have the same shape
    assert direct_output.shape == sharded_output.shape, (
        f"shape mismatch: direct {tuple(direct_output.shape)} vs "
        f"sharded {tuple(sharded_output.shape)}")

    # Check that the outputs are close (they should be identical)
    assert torch.allclose(direct_output, sharded_output, rtol=1e-5,
                          atol=1e-5), (
        "sharded output differs from direct output beyond tolerance")