 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import base64
+import math
 import mimetypes
 import os
 from tempfile import NamedTemporaryFile, TemporaryDirectory

 from vllm.multimodal.image import convert_image_mode
 from vllm.multimodal.inputs import PlaceholderRange
 from vllm.multimodal.utils import (MediaConnector, argsort_mm_positions,
+                                   get_load_balance_assignment,
+                                   run_dp_sharded_mrope_vision_model,
                                    run_dp_sharded_vision_model)
 from vllm.platforms import current_platform
 from vllm.utils import get_open_port, update_environment_variables
@@ -425,8 +428,8 @@ def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int,
     # Set random seed for reproducibility
     current_platform.seed_everything(0)

-    device = torch.device(f"cuda:{local_rank}")
-    torch.cuda.set_device(device)
+    device = f"{current_platform.device_name}:{local_rank}"
+    current_platform.set_device(device)
     torch.set_default_device(device)

     update_environment_variables({
@@ -463,3 +466,322 @@ def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int,

     # Check that the outputs are close (they should be identical)
     assert torch.allclose(direct_output, sharded_output, rtol=1e-5, atol=1e-5)
+
+
+@pytest.mark.parametrize(
+    "sizes,num_gpus,expected_shuffle_indices,expected_gpu_sample_counts,"
+    "expected_grouped_sizes_per_gpu,test_description",
+    [
+        # Empty input
+        ([], 2, [], [0, 0], [0, 0], "empty input"),
+
+        # Fewer samples than GPUs
+        ([100, 200], 4, [1, 0], [1, 1, 0, 0], [200, 100, 0, 0],
+         "fewer samples than GPUs"),
+
+        # Single GPU
+        ([100, 200, 300], 1, [2, 1, 0], [3], [600], "single GPU"),
+
+        # Balanced assignment
+        ([100, 100, 100, 100], 2, [0, 2, 1, 3], [2, 2], [200, 200],
+         "balanced assignment"),
+
+        # Unbalanced sizes - this one is trickier since the algorithm is greedy
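+        # (the expected values reflect largest-first placement: 1000 goes to
+        # GPU 0, then 200, 100, and 50 all land on the lighter GPU 1, giving
+        # sample counts [1, 3] and per-GPU totals [1000, 350])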
+        ([1000, 100, 200, 50], 2, [0, 2, 1, 3], [1, 3], [1000, 350],
+         "unbalanced sizes"),
+    ],
+)
+def test_get_load_balance_assignment_cases(sizes, num_gpus,
+                                           expected_shuffle_indices,
+                                           expected_gpu_sample_counts,
+                                           expected_grouped_sizes_per_gpu,
+                                           test_description):
+    """Test get_load_balance_assignment with various input cases."""
+    result = get_load_balance_assignment(sizes, num_gpus=num_gpus)
+    (shuffle_indices, gpu_sample_counts, grouped_sizes_per_gpu) = result
+
+    # Common assertions for all cases
+    assert len(shuffle_indices) == len(sizes)
+    assert len(gpu_sample_counts) == num_gpus
+    assert len(grouped_sizes_per_gpu) == num_gpus
+    assert sum(gpu_sample_counts) == len(sizes)
+
+    assert shuffle_indices == expected_shuffle_indices
+
+    assert gpu_sample_counts == expected_gpu_sample_counts
+    assert grouped_sizes_per_gpu == expected_grouped_sizes_per_gpu
+
+
+class SimpleMRopeVisionModel(torch.nn.Module):
+    """A simple vision model for testing mrope functionality."""
+
+    def __init__(self, spatial_merge_size: int = 2, out_hidden_size: int = 64):
+        super().__init__()
+        self.spatial_merge_size = spatial_merge_size
+        self.out_hidden_size = out_hidden_size
+        self.linear = torch.nn.Linear(768, out_hidden_size)
+
+    def forward(self, pixel_values: torch.Tensor,
+                grid_thw_list: list[list[int]]):
+        """Simple forward pass that simulates spatial merging."""
+        # Apply linear transformation
+        embeddings = self.linear(pixel_values)
+
+        # Simulate spatial merging by reducing the number of patches
+        merge_factor = self.spatial_merge_size * self.spatial_merge_size
+
+        # Group patches and merge spatially
+        merged_embeddings = []
+        start_idx = 0
+
+        for grid_thw in grid_thw_list:
+            num_patches = math.prod(grid_thw)
+            end_idx = start_idx + num_patches
+
+            # Get patches for this image
+            image_patches = embeddings[start_idx:end_idx]
+
+            # Simulate spatial merging by averaging groups of patches
+            merged_patches = num_patches // merge_factor
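+            # Leftover patches that don't fill a complete merge group are
+            # simply dropped by this toy model (hence the floor division).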
+            if merged_patches > 0:
+                # Reshape and average to simulate merging
+                reshaped = image_patches[:merged_patches * merge_factor].view(
+                    merged_patches, merge_factor, -1)
+                merged = reshaped.mean(dim=1)
+                merged_embeddings.append(merged)
+
+            start_idx = end_idx
+
+        if merged_embeddings:
+            return torch.cat(merged_embeddings, dim=0)
+        else:
+            return torch.empty((0, self.out_hidden_size),
+                               device=pixel_values.device,
+                               dtype=pixel_values.dtype)
+
+
+@multi_gpu_test(num_gpus=2)
+@pytest.mark.parametrize(
+    "batch_size",
+    [
+        1,  # Single image
+        3,  # Small batch
+        5,  # Odd batch size (for testing padding)
+    ],
+)
+def test_run_dp_sharded_mrope_vision_model(batch_size: int):
+    world_size = 2
+    # Launch processes
+    mp.spawn(
+        run_dp_sharded_mrope_vision_model_vs_direct,
+        args=(
+            world_size,
+            batch_size,
+            get_open_port(),
+        ),
+        nprocs=world_size,
+    )
+
+
+def run_dp_sharded_mrope_vision_model_vs_direct(local_rank: int,
+                                                world_size: int,
+                                                batch_size: int,
+                                                master_port: int):
+    """
+    Test that run_dp_sharded_mrope_vision_model produces the same results as
+    calling the model directly.
+    """
+    # Set random seed for reproducibility
+    current_platform.seed_everything(0)
+    device = f"{current_platform.device_name}:{local_rank}"
+    current_platform.set_device(device)
+    torch.set_default_device(device)
+
+    update_environment_variables({
+        'RANK': str(local_rank),
+        'LOCAL_RANK': str(local_rank),
+        'WORLD_SIZE': str(world_size),
+        'MASTER_ADDR': 'localhost',
+        'MASTER_PORT': str(master_port),
+    })
+
+    # initialize distributed
+    init_distributed_environment()
+    initialize_model_parallel(tensor_model_parallel_size=world_size)
+
+    # Create test data
+    grid_thw_list = []
+    pixel_values_list = []
+
+    for i in range(batch_size):
+        # Varying image sizes for better testing
+        t, h, w = 1, 4 + i, 4 + i
+        grid_thw_list.append([t, h, w])
+
+        num_patches = t * h * w
+        # Create random pixel values for this image
+        image_pixels = torch.randn(num_patches, 768)
+        pixel_values_list.append(image_pixels)
+
+    # Concatenate all pixel values
+    pixel_values = torch.cat(pixel_values_list, dim=0)
+
+    # Create a simple mrope vision model
+    vision_model = SimpleMRopeVisionModel()
+
+    # Run the model directly on the full input (only on rank 0)
+    if local_rank == 0:
+        with torch.inference_mode():
+            direct_output = vision_model(pixel_values, grid_thw_list)
+
+    # Run the model through the sharded function
+    with torch.inference_mode():
+        sharded_output = run_dp_sharded_mrope_vision_model(
+            vision_model, pixel_values, grid_thw_list)
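+        # The sharded helper returns one tensor per input image; concatenate
+        # them so the result can be compared against the direct output.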
+        sharded_output = torch.cat(sharded_output, dim=0)
+
+    # Check that the world size is set up correctly
+    assert get_tensor_model_parallel_world_size() == world_size
+
+    # Compare outputs (only on rank 0)
+    if local_rank == 0:
+        # Check that the outputs have the same shape
+        assert direct_output.shape == sharded_output.shape
+        # Check that the outputs are close (they should be identical)
+        assert torch.allclose(direct_output,
+                              sharded_output,
+                              rtol=1e-5,
+                              atol=1e-5)
+
+
+@multi_gpu_test(num_gpus=2)
+def test_run_dp_sharded_mrope_vision_model_empty_input():
+    world_size = 2
+    mp.spawn(
+        run_dp_sharded_mrope_vision_model_empty_input_worker,
+        args=(world_size, get_open_port()),
+        nprocs=world_size,
+    )
+
+
+def run_dp_sharded_mrope_vision_model_empty_input_worker(
+        local_rank: int, world_size: int, master_port: int):
+    """Test run_dp_sharded_mrope_vision_model with empty input."""
+    # Set up distributed environment
+    device = f"{current_platform.device_name}:{local_rank}"
+    current_platform.set_device(device)
+    torch.set_default_device(device)
+
+    update_environment_variables({
+        'RANK': str(local_rank),
+        'LOCAL_RANK': str(local_rank),
+        'WORLD_SIZE': str(world_size),
+        'MASTER_ADDR': 'localhost',
+        'MASTER_PORT': str(master_port),
+    })
+
+    init_distributed_environment()
+    initialize_model_parallel(tensor_model_parallel_size=world_size)
+
+    # Create empty inputs
+    pixel_values = torch.empty((0, 768))
+    grid_thw_list: list[list[int]] = []
+
+    vision_model = SimpleMRopeVisionModel()
+
+    # Should handle empty input gracefully
+    with torch.inference_mode():
+        output = run_dp_sharded_mrope_vision_model(vision_model, pixel_values,
+                                                   grid_thw_list)
+
+    assert len(output) == 0
+
+
+@multi_gpu_test(num_gpus=4)
+def test_run_dp_sharded_mrope_vision_model_uneven_load():
+    world_size = 4
+    mp.spawn(
+        run_dp_sharded_mrope_vision_model_uneven_load_worker,
+        args=(world_size, get_open_port()),
+        nprocs=world_size,
+    )
+
+
+def run_dp_sharded_mrope_vision_model_uneven_load_worker(
+        local_rank: int, world_size: int, master_port: int):
+    """Test run_dp_sharded_mrope_vision_model with uneven load distribution."""
+    # Set up distributed environment
+    current_platform.seed_everything(123)
+    device = f"{current_platform.device_name}:{local_rank}"
+    current_platform.set_device(device)
+    torch.set_default_device(device)
+
+    update_environment_variables({
+        'RANK': str(local_rank),
+        'LOCAL_RANK': str(local_rank),
+        'WORLD_SIZE': str(world_size),
+        'MASTER_ADDR': 'localhost',
+        'MASTER_PORT': str(master_port),
+    })
+
+    init_distributed_environment()
+    initialize_model_parallel(tensor_model_parallel_size=world_size)
+
+    # Create images with very different sizes
+    grid_thw_list = [
+        [1, 2, 2],  # Small: 4 patches
+        [1, 8, 8],  # Large: 64 patches
+        [1, 3, 3],  # Medium: 9 patches
+    ]
+
+    pixel_values_list = []
+    for grid_thw in grid_thw_list:
+        num_patches = math.prod(grid_thw)
+        image_pixels = torch.randn(num_patches, 768)
+        pixel_values_list.append(image_pixels)
+
+    pixel_values = torch.cat(pixel_values_list, dim=0)
+    vision_model = SimpleMRopeVisionModel()
+
+    # Should handle uneven distribution without errors
+    with torch.inference_mode():
+        output_tuple = run_dp_sharded_mrope_vision_model(
+            vision_model, pixel_values, grid_thw_list)
+
+    # Verify output shape is reasonable
+    merge_factor = vision_model.spatial_merge_size**2
+    expected_output_patches = list(
+        math.prod(grid_thw) // merge_factor for grid_thw in grid_thw_list)
+
+    for i, output in enumerate(output_tuple):
+        assert output.shape[0] == expected_output_patches[i]
+        assert output.shape[1] == vision_model.out_hidden_size
+
+
+@pytest.mark.parametrize("spatial_merge_size", [2, 4])
+def test_simple_mrope_vision_model_spatial_merge(spatial_merge_size: int):
+    """Test SimpleMRopeVisionModel with different spatial merge sizes."""
+    device = current_platform.device_type
+
+    grid_thw_list = [[1, 4, 4], [1, 6, 6]]  # Two images
+    pixel_values_list = []
+
+    for grid_thw in grid_thw_list:
+        num_patches = math.prod(grid_thw)
+        image_pixels = torch.randn(num_patches, 768, device=device)
+        pixel_values_list.append(image_pixels)
+
+    pixel_values = torch.cat(pixel_values_list, dim=0)
+    vision_model = SimpleMRopeVisionModel(
+        spatial_merge_size=spatial_merge_size).to(device)

+    with torch.inference_mode():
+        output = vision_model(pixel_values, grid_thw_list)
+
+    # Verify output dimensions based on spatial merging
+    total_patches = sum(math.prod(grid_thw) for grid_thw in grid_thw_list)
+    merge_factor = spatial_merge_size**2
+    expected_output_patches = total_patches // merge_factor
+
+    assert output.shape[0] == expected_output_patches
+    assert output.shape[1] == vision_model.out_hidden_size
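
Editor's note: the load-balancing expectations in test_get_load_balance_assignment_cases are consistent with a simple greedy policy: visit samples from largest to smallest and place each one on the GPU with the smallest running total. The sketch below is illustrative only; it is not the actual get_load_balance_assignment from vllm.multimodal.utils, whose implementation details may differ.

# Illustrative greedy balancer (assumption: not vLLM's actual code), shown
# here only to make the expected values in the parametrized cases concrete.
def greedy_load_balance(sizes: list[int], num_gpus: int):
    gpu_totals = [0] * num_gpus
    gpu_samples: list[list[int]] = [[] for _ in range(num_gpus)]
    # Largest samples first, each onto the currently least-loaded GPU.
    for idx in sorted(range(len(sizes)), key=lambda i: sizes[i], reverse=True):
        gpu = min(range(num_gpus), key=lambda g: gpu_totals[g])
        gpu_samples[gpu].append(idx)
        gpu_totals[gpu] += sizes[idx]
    shuffle_indices = [i for samples in gpu_samples for i in samples]
    gpu_sample_counts = [len(samples) for samples in gpu_samples]
    return shuffle_indices, gpu_sample_counts, gpu_totals


# Matches the "unbalanced sizes" case above.
assert greedy_load_balance([1000, 100, 200, 50], 2) == ([0, 2, 1, 3], [1, 3],
                                                        [1000, 350])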