[CI Failure]: Samplers Test - samplers/test_beam_search.py::test_beam_search_passes_multimodal_data #19736

Closed

Description

@mgoin

Name of failing test

samplers/test_beam_search.py::test_beam_search_passes_multimodal_data[False-2-64-half]

Basic information

  • Flaky test
  • Can reproduce locally
  • Caused by external libraries (e.g. bug in transformers)

🧪 Describe the failing test

It seems the issue is that we are now passing empty lists to `_flatten_embeddings`:

FAILED samplers/test_beam_search.py::test_beam_search_passes_multimodal_data[False-2-64-half] - RuntimeError: torch.cat(): expected a non-empty list of Tensors
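
For reference, the underlying mechanism is easy to reproduce in isolation: `torch.cat` rejects an empty sequence, so flattening an empty list of multimodal embeddings fails with exactly this error. The snippet below is a simplified standalone copy of `_flatten_embeddings` (taken from the traceback further down), not the engine code path itself:

```python
import torch


def _flatten_embeddings(embeddings):
    # Simplified from vllm/model_executor/models/utils.py (see traceback below):
    # recursively flatten and concatenate nested tensors on all but the last dim.
    if isinstance(embeddings, torch.Tensor):
        return embeddings.flatten(0, -2)
    return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings))


try:
    # Empty multimodal embeddings, as appears to happen during CUDA graph capture.
    _flatten_embeddings([])
except RuntimeError as e:
    print(e)  # torch.cat(): expected a non-empty list of Tensors
```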

Full output:

pytest -s -v "samplers/test_beam_search.py::test_beam_search_passes_multimodal_data[False-2-64-half]"
INFO 06-17 09:19:56 [__init__.py:244] Automatically detected platform cuda.
/home/mgoin/venvs/vllm/lib/python3.12/site-packages/pytest_asyncio/plugin.py:208: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset.
The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. Valid fixture loop scopes are: "function", "class", "module", "package", "session"

  warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET))
============================================================================================ test session starts =============================================================================================
platform linux -- Python 3.12.4, pytest-8.3.3, pluggy-1.5.0 -- /home/mgoin/venvs/vllm/bin/python3
cachedir: .pytest_cache
hypothesis profile 'default' -> database=DirectoryBasedExampleDatabase(PosixPath('/home/mgoin/code/vllm/tests/.hypothesis/examples'))
rootdir: /home/mgoin/code/vllm
configfile: pyproject.toml
plugins: forked-1.6.0, subtests-0.14.1, asyncio-0.24.0, shard-0.1.2, buildkite-test-collector-0.1.9, timeout-2.3.1, schemathesis-3.39.15, anyio-4.6.2.post1, mock-3.14.0, hypothesis-6.131.0, rerunfailures-14.0
asyncio: mode=Mode.STRICT, default_loop_scope=None
collected 1 item                                                                                                                                                                                             
Running 1 items in this shard: tests/samplers/test_beam_search.py::test_beam_search_passes_multimodal_data[False-2-64-half]

samplers/test_beam_search.py::test_beam_search_passes_multimodal_data[False-2-64-half] WARNING 06-17 09:19:58 [config.py:3273] Casting torch.bfloat16 to torch.float16.
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:01<00:00,  4.90it/s]
The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
INFO 06-17 09:20:16 [config.py:831] This model supports multiple tasks: {'generate', 'score', 'classify', 'reward', 'embed'}. Defaulting to 'generate'.
WARNING 06-17 09:20:16 [config.py:3273] Casting torch.bfloat16 to torch.float16.
INFO 06-17 09:20:16 [config.py:1444] Using max model len 1024
INFO 06-17 09:20:16 [llm_engine.py:230] Initializing a V0 LLM engine (v0.9.1.dev287+g89b1388d8) with config: model='Qwen/Qwen2-Audio-7B-Instruct', speculative_config=None, tokenizer='Qwen/Qwen2-Audio-7B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config={}, tokenizer_revision=None, trust_remote_code=True, dtype=torch.float16, max_seq_len=1024, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=Qwen/Qwen2-Audio-7B-Instruct, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=None, chunked_prefill_enabled=False, use_async_output_proc=True, pooler_config=None, compilation_config={"level":0,"debug_dump_path":"","cache_dir":"","backend":"","custom_ops":[],"splitting_ops":[],"use_inductor":true,"compile_sizes":[],"inductor_compile_config":{"enable_auto_functionalized_v2":false},"inductor_passes":{},"use_cudagraph":false,"cudagraph_num_of_warmups":0,"cudagraph_capture_sizes":[256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"cudagraph_copy_inputs":false,"full_cuda_graph":false,"max_capture_size":256,"local_cache_dir":null}, use_cached_outputs=False, 
INFO 06-17 09:20:18 [cuda.py:336] Using Flash Attention backend.
INFO 06-17 09:20:18 [parallel_state.py:1065] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0
INFO 06-17 09:20:18 [model_runner.py:1171] Starting to load model Qwen/Qwen2-Audio-7B-Instruct...
INFO 06-17 09:20:19 [weight_utils.py:292] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/5 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  20% Completed | 1/5 [00:00<00:02,  1.59it/s]
Loading safetensors checkpoint shards:  40% Completed | 2/5 [00:01<00:02,  1.49it/s]
Loading safetensors checkpoint shards:  60% Completed | 3/5 [00:02<00:01,  1.44it/s]
Loading safetensors checkpoint shards:  80% Completed | 4/5 [00:02<00:00,  1.48it/s]
Loading safetensors checkpoint shards: 100% Completed | 5/5 [00:02<00:00,  1.88it/s]
Loading safetensors checkpoint shards: 100% Completed | 5/5 [00:02<00:00,  1.68it/s]

INFO 06-17 09:20:22 [default_loader.py:272] Loading weights took 3.00 seconds
INFO 06-17 09:20:22 [model_runner.py:1203] Model loading took 15.6455 GiB and 3.447517 seconds
INFO 06-17 09:20:25 [worker.py:294] Memory profiling takes 2.68 seconds
INFO 06-17 09:20:25 [worker.py:294] the current vLLM instance can use total_gpu_memory (79.19GiB) x gpu_memory_utilization (0.90) = 71.27GiB
INFO 06-17 09:20:25 [worker.py:294] model weights take 15.65GiB; non_torch_memory takes 0.00GiB; PyTorch activation peak memory takes 0.51GiB; the rest of the memory reserved for KV Cache is 55.11GiB.
INFO 06-17 09:20:25 [executor_base.py:113] # cuda blocks: 7054, # CPU blocks: 512
INFO 06-17 09:20:25 [executor_base.py:118] Maximum concurrency for 1024 tokens per request: 110.22x
INFO 06-17 09:20:27 [model_runner.py:1513] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
Capturing CUDA graph shapes:   0%|                                                                                                                     | 0/35 [00:00<?, ?it/s]
FAILED

================================================================================== FAILURES ==================================================================================
__________________________________________________________ test_beam_search_passes_multimodal_data[False-2-64-half] __________________________________________________________

hf_runner = <class 'tests.conftest.HfRunner'>, vllm_runner = <class 'tests.conftest.VllmRunner'>, dtype = 'half', max_tokens = 64, beam_width = 2

    @pytest.mark.parametrize("dtype", ["half"])
    @pytest.mark.parametrize("max_tokens", MAX_TOKENS)
    @pytest.mark.parametrize("beam_width", MM_BEAM_WIDTHS)
    def test_beam_search_passes_multimodal_data(
        hf_runner,
        vllm_runner,
        dtype: str,
        max_tokens: int,
        beam_width: int,
    ) -> None:
        """Ensure that beam search passes multimodal data through correctly."""
        # NOTE - this test is primarily to check that mm data is passed to beams
        # correctly. As such, we just need to check one extra modality to make
        # sure things pass through properly.
        audios = [AudioAsset("mary_had_lamb").audio_and_sample_rate]
        model = "Qwen/Qwen2-Audio-7B-Instruct"
        audio_seq = "<|audio_bos|><|AUDIO|><|audio_eos|>"
        prompts = [
            f"<|im_start|>user\n{audio_seq}Can you transcribe this?<|im_end|>\n<|im_start|>assistant\n"  #noqa: E501
        ]
    
        with hf_runner(model, dtype=dtype,
                       auto_cls=AutoModelForSeq2SeqLM) as hf_model:
            audio_token_id = hf_model.config.audio_token_index
            eos_token_id = hf_model.tokenizer.eos_token_id  # <|im_end|>
            hf_outputs = hf_model.generate_beam_search(
                prompts,
                beam_width=beam_width,
                max_tokens=max_tokens,
                audios=audios,
            )
    
>       with vllm_runner(model, dtype=dtype) as vllm_model:

samplers/test_beam_search.py:102: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
conftest.py:782: in __init__
    self.model = LLM(
../vllm/entrypoints/llm.py:262: in __init__
    self.llm_engine = LLMEngine.from_engine_args(
../vllm/engine/llm_engine.py:501: in from_engine_args
    return engine_cls.from_vllm_config(
../vllm/engine/llm_engine.py:477: in from_vllm_config
    return cls(
../vllm/engine/llm_engine.py:268: in __init__
    self._initialize_kv_caches()
../vllm/engine/llm_engine.py:426: in _initialize_kv_caches
    self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
../vllm/executor/executor_base.py:124: in initialize_cache
    self.collective_rpc("initialize_cache",
../vllm/executor/uniproc_executor.py:57: in collective_rpc
    answer = run_method(self.driver_worker, method, args, kwargs)
../vllm/utils.py:2690: in run_method
    return func(*args, **kwargs)
../vllm/worker/worker.py:335: in initialize_cache
    self._warm_up_model()
../vllm/worker/worker.py:365: in _warm_up_model
    self.model_runner.capture_model(self.gpu_cache)
../../../venvs/vllm/lib/python3.12/site-packages/torch/utils/_contextlib.py:116: in decorate_context
    return func(*args, **kwargs)
../vllm/worker/model_runner.py:1658: in capture_model
    graph_runner.capture(**capture_inputs)
../vllm/worker/model_runner.py:2059: in capture
    self.model(
../../../venvs/vllm/lib/python3.12/site-packages/torch/nn/modules/module.py:1751: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../../venvs/vllm/lib/python3.12/site-packages/torch/nn/modules/module.py:1762: in _call_impl
    return forward_call(*args, **kwargs)
../vllm/model_executor/models/qwen2_audio.py:389: in forward
    inputs_embeds = self.get_input_embeddings(input_ids,
../vllm/model_executor/models/qwen2_audio.py:368: in get_input_embeddings
    inputs_embeds = merge_multimodal_embeddings(
../vllm/model_executor/models/utils.py:498: in merge_multimodal_embeddings
    return _merge_multimodal_embeddings(
../vllm/model_executor/models/utils.py:411: in _merge_multimodal_embeddings
    flattened = _flatten_embeddings(multimodal_embeddings)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

embeddings = []

    def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor:
        """
        Recursively flattens and concatenates NestedTensors on all but the last
        dimension.
        """
    
        if isinstance(embeddings, torch.Tensor):
            # Flatten all but the last dimension.
            return embeddings.flatten(0, -2)
    
>       return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings))
E       RuntimeError: torch.cat(): expected a non-empty list of Tensors

../vllm/model_executor/models/utils.py:363: RuntimeError
============================================================================== warnings summary ==============================================================================
../../../venvs/vllm/lib/python3.12/site-packages/schemathesis/generation/coverage.py:305
  /home/mgoin/venvs/vllm/lib/python3.12/site-packages/schemathesis/generation/coverage.py:305: DeprecationWarning: jsonschema.exceptions.RefResolutionError is deprecated as of version 4.18.0. If you wish to catch potential reference resolution errors, directly catch referencing.exceptions.Unresolvable.
    ref_error: type[Exception] = jsonschema.RefResolutionError,

tests/samplers/test_beam_search.py::test_beam_search_passes_multimodal_data[False-2-64-half]
  /home/mgoin/venvs/vllm/lib/python3.12/site-packages/librosa/core/intervals.py:15: DeprecationWarning: path is deprecated. Use files() instead. Refer to https://importlib-resources.readthedocs.io/en/latest/using.html#migrating-from-legacy for migration advice.
    with resources.path("librosa.core", "intervals.msgpack") as imsgpack:

tests/samplers/test_beam_search.py::test_beam_search_passes_multimodal_data[False-2-64-half]
  /home/mgoin/venvs/vllm/lib/python3.12/site-packages/audioread/rawread.py:16: DeprecationWarning: 'aifc' is deprecated and slated for removal in Python 3.13
    import aifc

tests/samplers/test_beam_search.py::test_beam_search_passes_multimodal_data[False-2-64-half]
  /home/mgoin/venvs/vllm/lib/python3.12/site-packages/audioread/rawread.py:17: DeprecationWarning: 'audioop' is deprecated and slated for removal in Python 3.13
    import audioop

tests/samplers/test_beam_search.py::test_beam_search_passes_multimodal_data[False-2-64-half]
  /home/mgoin/venvs/vllm/lib/python3.12/site-packages/audioread/rawread.py:19: DeprecationWarning: 'sunau' is deprecated and slated for removal in Python 3.13
    import sunau

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
========================================================================== short test summary info ===========================================================================
FAILED samplers/test_beam_search.py::test_beam_search_passes_multimodal_data[False-2-64-half] - RuntimeError: torch.cat(): expected a non-empty list of Tensors
======================================================================= 1 failed, 5 warnings in 29.93s =======================================================================
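
For what it's worth, the crash could be avoided by guarding the merge so that an empty embeddings list is treated as "nothing to merge". The sketch below is a simplified, hypothetical stand-in for the merge step in `get_input_embeddings` / `_merge_multimodal_embeddings`; the names and signature are mine, and this is not the actual vLLM code nor necessarily the right fix:

```python
from typing import Optional, Sequence

import torch


def merge_if_present(
    inputs_embeds: torch.Tensor,    # [..., seq_len, hidden] text embeddings
    is_mm_token: torch.Tensor,      # bool mask marking <|AUDIO|> placeholder positions
    multimodal_embeddings: Optional[Sequence[torch.Tensor]],
) -> torch.Tensor:
    # Guard: None or an empty list means there is nothing to merge, so return
    # the text embeddings untouched instead of ending up in torch.cat(()).
    if not multimodal_embeddings:
        return inputs_embeds
    flattened = torch.cat([e.flatten(0, -2) for e in multimodal_embeddings])
    merged = inputs_embeds.clone()
    merged[is_mm_token] = flattened
    return merged
```

With a guard like this, the CUDA graph capture path (which appears to run with no real multimodal inputs, hence the empty list) would simply fall through to the plain text embeddings.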

📝 History of failing test

This was introduced by #19446

https://buildkite.com/organizations/vllm/analytics/suites/ci-1/tests/1f2a99b2-fbc9-89fc-a08f-c4d431429893?period=7days&tags=scm.branch%3Amain

I believe it was not caught because the test was not triggered by that change.

CC List.

@russellb

Metadata

    Labels

    ci-failure: Issue about an unexpected test failure in CI
