
Pass s_aux through flash_attn_with_kvcache #79

Merged: 1 commit merged into vllm-project:main on Aug 8, 2025

Conversation

@tdoublep (Member) commented on Aug 8, 2025

Trying to run the hybrid models test is giving me grief without this change:

python -m pytest models/language/generation/test_hybrid.py::test_models[5-64-ai21labs/Jamba-tiny-dev] -xsv

produces:

models/language/generation/test_hybrid.py:97: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
conftest.py:788: in __init__
    self.llm = LLM(
../vllm/entrypoints/llm.py:277: in __init__
    self.llm_engine = LLMEngine.from_engine_args(
../vllm/engine/llm_engine.py:494: in from_engine_args
    return engine_cls.from_vllm_config(
../vllm/engine/llm_engine.py:470: in from_vllm_config
    return cls(
../vllm/engine/llm_engine.py:263: in __init__
    self._initialize_kv_caches()
../vllm/engine/llm_engine.py:419: in _initialize_kv_caches
    self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
../vllm/executor/executor_base.py:125: in initialize_cache
    self.collective_rpc("initialize_cache",
../vllm/executor/uniproc_executor.py:58: in collective_rpc
    answer = run_method(self.driver_worker, method, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../vllm/utils/__init__.py:2959: in run_method
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
../vllm/worker/worker.py:336: in initialize_cache
    self._warm_up_model()
../vllm/worker/worker.py:387: in _warm_up_model
    self.model_runner.capture_model(self.gpu_cache)
../../miniforge3/envs/new-env/lib/python3.12/site-packages/torch/utils/_contextlib.py:116: in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
../vllm/worker/model_runner.py:1526: in capture_model
    graph_runner.capture(**capture_inputs)
../vllm/worker/model_runner.py:1928: in capture
    self.model(
../../miniforge3/envs/new-env/lib/python3.12/site-packages/torch/nn/modules/module.py:1751: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../../miniforge3/envs/new-env/lib/python3.12/site-packages/torch/nn/modules/module.py:1762: in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../vllm/model_executor/models/jamba.py:530: in forward
    hidden_states = self.model(input_ids, positions, mamba_cache_params,
../../miniforge3/envs/new-env/lib/python3.12/site-packages/torch/nn/modules/module.py:1751: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../../miniforge3/envs/new-env/lib/python3.12/site-packages/torch/nn/modules/module.py:1762: in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../vllm/model_executor/models/jamba.py:357: in forward
    hidden_states, residual = layer(
../../miniforge3/envs/new-env/lib/python3.12/site-packages/torch/nn/modules/module.py:1751: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../../miniforge3/envs/new-env/lib/python3.12/site-packages/torch/nn/modules/module.py:1762: in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../vllm/model_executor/models/jamba.py:261: in forward
    hidden_states = self.self_attention(
../vllm/model_executor/models/jamba.py:243: in self_attention
    attn_output = self.attn(q, k, v)
                  ^^^^^^^^^^^^^^^^^^
../../miniforge3/envs/new-env/lib/python3.12/site-packages/torch/nn/modules/module.py:1751: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../../miniforge3/envs/new-env/lib/python3.12/site-packages/torch/nn/modules/module.py:1762: in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../vllm/attention/layer.py:272: in forward
    torch.ops.vllm.unified_attention_with_output(
../../miniforge3/envs/new-env/lib/python3.12/site-packages/torch/_ops.py:1158: in __call__
    return self._op(*args, **(kwargs or {}))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../vllm/attention/layer.py:488: in unified_attention_with_output
    self.impl.forward(self,
../vllm/attention/backends/flash_attn.py:905: in forward
    flash_attn_with_kvcache(
../vllm/vllm_flash_attn/flash_attn_interface.py:446: in flash_attn_with_kvcache
    out, softmax_lse, _, _ = torch.ops._vllm_fa3_C.fwd(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <OpOverloadPacket(op='_vllm_fa3_C.fwd')>
args = (tensor([[[[-0.7305, -0.5859, -0.9258,  ..., -0.3574, -1.3125,  0.3242],
          [-1.1641, -2.5781, -0.1338,  ..., -...0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]]], device='cuda:0',
       dtype=torch.bfloat16), None, None, None, ...)
kwargs = {}

    def __call__(self, /, *args, **kwargs):
        # overloading __call__ to ensure torch.ops.foo.bar()
        # is still callable from JIT
        # We save the function ptr as the `op` attribute on
        # OpOverloadPacket to access it here.
    
        # Directly calling OverloadPacket goes into C++, which will check
        # the schema and cause an error for torchbind op when inputs consist of FakeScriptObject so we
        # intercept it here and call TorchBindOpverload instead.
        if self._has_torchbind_op_overload and _must_dispatch_in_python(args, kwargs):
            return _call_overload_packet_from_python(self, args, kwargs)
>       return self._op(*args, **(kwargs or {}))
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E       RuntimeError: _vllm_fa3_C::fwd() is missing value for argument 's_aux'. Declaration: _vllm_fa3_C::fwd(Tensor($0! -> ) q, Tensor k, Tensor v, Tensor? k_new, Tensor? v_new, Tensor? q_v, Tensor($1! -> )? out, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, Tensor? cu_seqlens_k_new, Tensor? seqused_q, Tensor? seqused_k, int? max_seqlen_q, int? max_seqlen_k, Tensor? page_table, Tensor? kv_batch_idx, Tensor? leftpad_k, Tensor? rotary_cos, Tensor? rotary_sin, Tensor? seqlens_rotary, Tensor? q_descale, Tensor? k_descale, Tensor? v_descale, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, float softcap, bool is_rotary_interleaved, Tensor? scheduler_metadata, int num_splits, bool? pack_gqa, int sm_margin, Tensor? s_aux) -> Tensor[]

../../miniforge3/envs/new-env/lib/python3.12/site-packages/torch/_ops.py:1158: RuntimeError
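The failure is a schema mismatch: the compiled FA3 op _vllm_fa3_C::fwd now declares a trailing Tensor? s_aux argument (visible in the declaration above), but the Python wrapper in vllm/vllm_flash_attn/flash_attn_interface.py was not passing it, which is why this change forwards s_aux through flash_attn_with_kvcache. As a quick way to confirm what the compiled op expects, a minimal sketch (assuming vLLM was built with the FA3 extension and that importing the interface module registers it):

import torch
import vllm.vllm_flash_attn.flash_attn_interface  # noqa: F401  (registers _vllm_fa3_C if it was built)

# Print the registered schema of the compiled op; the trailing `Tensor? s_aux`
# is the argument the Python wrapper must now forward so that its call matches
# the declaration shown in the RuntimeError above.
print(torch.ops._vllm_fa3_C.fwd.default._schema)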

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
@tdoublep (Member, Author) commented on Aug 8, 2025

cc @LucasWilkinson

@LucasWilkinson (Collaborator) left a comment

For context: fixes FA3 + V0; LGTM

@LucasWilkinson merged commit 93cf5a0 into vllm-project:main on Aug 8, 2025 (1 check passed)
@Jialin commented on Aug 11, 2025

> For context: fixes FA3 + V0; LGTM

Internally, we're patching this PR as a fix-forward to mitigate the broken trunk.
@LucasWilkinson do you think we should add a unit test to avoid similar breakage in the future? Thanks in advance.
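
One hypothetical shape such a regression test could take is sketched below; it is illustration only, assumes upstream-style argument names (q, k_cache, v_cache, cache_seqlens) for the vllm_flash_attn wrapper, and would need to be adapted to the actual signature and placed next to the existing kernel tests:

import pytest
import torch

def test_flash_attn_with_kvcache_decode_smoke():
    # Hypothetical test, not part of this PR: exercise the single-token decode
    # path that previously failed with "missing value for argument 's_aux'"
    # once the wrapper reached the compiled _vllm_fa3_C::fwd op.
    from vllm.vllm_flash_attn import flash_attn_with_kvcache

    if not torch.cuda.is_available():
        pytest.skip("requires a CUDA device")

    torch.manual_seed(0)
    batch, heads, head_dim, cache_len = 2, 4, 64, 16
    q = torch.randn(batch, 1, heads, head_dim, device="cuda", dtype=torch.bfloat16)
    k_cache = torch.randn(batch, cache_len, heads, head_dim,
                          device="cuda", dtype=torch.bfloat16)
    v_cache = torch.randn_like(k_cache)
    cache_seqlens = torch.full((batch,), cache_len, device="cuda", dtype=torch.int32)

    out = flash_attn_with_kvcache(
        q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=True
    )
    assert out.shape == q.shape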
