Skip to content

[Kernel] Triton implementation of causal-conv1d for Mamba-based models #18218

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 41 commits into
base: main
Choose a base branch
from

Conversation

thoangtrvn
Copy link

@thoangtrvn thoangtrvn commented May 15, 2025

This PR adds Triton-based causal-conv1d, making Mamba-based models in vLLM

  1. fully Triton-only backend.
  2. one step closer to be compatible with vLLM v1 design, i.e. without splitting the batch into prefill-only and decode-only for CUDA-split processing.

There are two kernels implemented

  • causal_conv1d_update_triton: which outperforms the corresponding CUDA kernel in handling decode-only requests
image [data benchmarking two kernels runtime by increasing the number of decode-only requests in a batch]
  • causal_conv1d_fn_triton: which outperform CUDA kernel in batch of mixed prefill/decode requests, e.g. 27x faster in the below microbenchmark with the same batch of mixed prefill/decode requests.
image

It also performs better than the CUDA-split pathway which was merged as PR #17146.
image
[data benchmarking runtime processing the same batch of mixed requests, first send the batch to the single Triton kernel, and then using CUDA-split pathway where requests are first separated, with prefill-only requests are sent to one kernel, and decode-only requests are sent to the second kernel]

ALGORITHMIC CHOICE OF TRITON KERNEL: Unlike CUDA kernel which is implemented with parallelism in 2D, i.e. along feature-dimension, and batch size; Triton kernel is implemented with parallelism in 3D, i.e. along also sequence-dimension. Also, the Triton kernels don't make any changes to the layout of the input data which is contiguous along the feature-dimension. Another key difference is that Triton kernels expect the conv-state to be contiguous along the feature-dimension, while in existing CUDA implementation, it expects the conv-state cache to be contiguous along the kernel-width (i.e. sequence-length) axis. Nevertheless, the two CUDA kernels are not compatible with the layout of conv-state cache, and therefore prevents the efficient processing in decode-only requests or mixed prefill/decode-requests.

Also, some other improvement in reducing overhead is incorporated.
Even though binary code generated from Triton is faster, the launch overhead is a known issue and is therefore need further optimization to get the E2E Triton-only Mamba models in vLLM performant. Here, we also incorporate such improvements by using a metadata that can be reused across layers.

In our benchmark on ShareGPT dataset which has short input prompt (a few hundreds of tokens)

  • default setting: generates short number of tokens (i.e. 256 tokens) CUDA-backed Bamba (ibm-ai-platform/Bamba-9B) is still faster; 10% slower (total token throughput) yet only 2% in output token throughput and 2% in TTFT.

  • generating 1024 tokens: Triton-backed Bamba is now faster with 5% faster on token throughput; and with 11% faster on TTFT. The benefit of faster Triton kernels now exceeds the overall costs of Triton launch overhead.

image

In the longer context length and/or longer number of generated tokens, Triton-only Mamba-based model is expected to be better than CUDA-split approach. However, the PR maintains the existing CUDA pathway as the default one until it is adopted by vLLM maintainers. Currently, the code is added as an optional pathway to the CUDA-split via VLLM_USE_TRITON_CONV1D environment variable set to 1.

This is one step closer to be compatible with vLLM v1 design, i.e. without splitting the batch into prefill-only and decode-only for CUDA-split processing.

Test code is also added

tmhoangt added 3 commits May 15, 2025 09:32
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

tmhoangt added 7 commits May 15, 2025 17:52
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Copy link

mergify bot commented May 27, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @thoangtrvn.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label May 27, 2025
Copy link
Collaborator

@tlrmchlsmth tlrmchlsmth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you report some lm_eval scores running gsm8k, as well as make sure it runs correctly without --enforce-eager?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like the changes to bamba.py, granitemoehybrid.py, mamba2.py and zamba2.py are pretty spurious. Could you revert those?

Copy link
Author

@thoangtrvn thoangtrvn Jun 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the Mamba2's metadata, there are setting that get updated once (at model initialization), and there are settings that get updated at every input (while keeping the same across layers).
Adding self.mamba2_metadata provides a solution to reuse updated-once data. If you don't like this level of optimization, please let me know @tlrmchlsmth. This is optional. I can revert the changes to the other models, and keep the change only on bamba.py.

Comment on lines 39 to 40
path = os.environ.get("VLLM_USE_TRITON_CONV1D", None)
if path is not None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we add an environment variable, this should be handled in vllm/envs.py.

However, generally we should avoid adding an environment variable where possible. Which cases should we be using the triton conv1d kernels vs CUDA?

Copy link
Author

@thoangtrvn thoangtrvn Jun 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So far, the testing shows CUDA lower overhead is better at short context length (~300-400 tokens, and short output length); other than that, Triton-kernel is better. The e2e slower doesn't come from the kernel itself, but the overhead launch in overall. I just want to maintain the two pathways before a final decision is made, based on other testing outside what I have done.

I'll provide more details of test cases in the PR threads.

While in theory, some Triton launch overhead can be reduced using Triton JIT cache mechanism, it is not tested here.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done update the code

[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we test more than just seqlen 1?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah yes, I can add more to the test code

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Comment on lines 143 to 148
"""Create a weight loader for mamba v2. This ensures that the projections
are correctly sharded so that they can be split into x, B, C. It also
ensures that all the groups corresponding to a head shard is placed
"""Create a weight loader for mamba v2. This ensures that the projections
are correctly sharded so that they can be split into x, B, C. It also
ensures the the all the groups corresponding to a head shard is placed
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: please revert the the back to that

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Comment on lines 18 to 23
batch_ptr = torch.full(
(MAX_NUM_PROGRAMS, ), PAD_SLOT_ID, dtype=torch.int32,
device='cpu') # tracking which seq-idx the Triton program is handling
token_chunk_offset_ptr = torch.full(
(MAX_NUM_PROGRAMS, ), PAD_SLOT_ID, dtype=torch.int32, device='cpu'
) # tracking BLOCK_M-based index in the sequence the Triton program is handling
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these stateful global variables? Would it be better for these to go in the mamba2_metadata instead?

Copy link
Author

@thoangtrvn thoangtrvn Jun 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Explain: the parallelism of the causal-conv1d kernel (prefill stage) is 3D: batch, feature, and seqlen dimensions. It means that each Triton program can handle a group of token in a sequence. The information about token range (start/stop) at each sequence, is tracked by these two tensors. In vLLM, the tensors are generated as part of the metadata construct, which means at e2e runtime, it doesn't use these.
However, for kernel-level runtime, e.g. microbenchmarking or kernel testing; they needs to be provided. Here, we have 3 choices:

  • created a metadata object just like vLLM e2e setting
  • when metadata is not provided, either we declare inside the function that use it (which create the tensor each time a kernel get invoked), or declare at the module-level (created once at module loading - safe some overhead of tensor allocation).

Please let me know what you think it should be revised @tlrmchlsmth .

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

migrated code change based on your feedback @tlrmchlsmth

Comment on lines 27 to 50
num_cache_lines: Optional[int] = None
stride_istate_seq: Optional[int] = None
stride_istate_dim: Optional[int] = None
stride_istate_token: Optional[int] = None
seqlens: Optional[np.ndarray] = None
padded_batch: Optional[int] = None
nums_dict: Optional[dict] = None
is_channel_last: bool = True
stride_w_dim: Optional[int] = None
stride_w_width: Optional[int] = None
width: Optional[int] = None
np2_statelen: Optional[int] = None
stride_x_seq: Optional[int] = 0
stride_x_dim: Optional[int] = None
stride_x_token: Optional[int] = None
dim: Optional[int] = None
cu_seqlen: Optional[int] = None
out: Optional[torch.Tensor] = None
stride_o_seq: Optional[int] = 0
stride_o_dim: Optional[int] = None
stride_o_token: Optional[int] = None
MAX_NUM_PROGRAMS: int = 1024
batch_ptr: Optional[torch.tensor] = None
token_chunk_offset_ptr: Optional[torch.tensor] = None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a lot of stuff in here, and it's not really clear what most of it is for. At first glance it seems like most of this should be accessed on the fly instead of stored in the metadata here. Could you take a stab at cleaning this up?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The information here is reused across Mamba layers, even the stride call would trigger Torch calls, which triggers an unnecessary overhead. I can adds a description as needed. Please let me know what you think @tlrmchlsmth

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done updated code.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think most of this should be removed. For CPU overheads in decode, we can rely on CUDA graphs and for prefill they are amortized

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm running some test to revert this.

DONE

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tlrmchlsmth : done the update based on your feedback.

Comment on lines 110 to 115
# NOTE: currently it is assumed prefill requests come before decode requests -> we can use ':num_prefills' slicing
# TODO: maybe revert back to the original code (below) if above no longer holds
# has_initial_states = attn_metadata.context_lens_tensor > 0
# zero_init_indices = mamba_cache_params.state_indices_tensor[~has_initial_states]
# mamba_cache_params.ssm_state[zero_init_indices] = 0
# initial_states = mamba_cache_params.ssm_state[mamba_cache_params.state_indices_tensor]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can rely on batch reordering and require that it be used for this implementation.

Copy link
Author

@thoangtrvn thoangtrvn Jun 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The change from a previous PR (cuda-split) implies this assumption, this PR doesn't have this assumption which makes is more suitable for vLLM v1 design, the comment I added here is to clarify the code path from the previous PR. I can remove the comment as needed. Please let me know @tlrmchlsmth

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed comment

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, please remove the comment.

We can rely on batch reordering even in vLLM V1, so this is a non issue

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Comment on lines 388 to 394
def is_conv_in_Triton(self):
import os
path = os.environ.get("VLLM_USE_TRITON_CONV1D", None)
if path is not None:
print("mamba_mixer2 - VLLM_USE_TRITON_CONV1D")
return True
return False
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few things here:

  1. Remove the print before landing
  2. env variables should be defined in vllm/envs.py
  3. The function would be better-named use_triton_causal_conv_1d (more descriptive & proper capitalization)
  4. We should work hard to avoid proliferation of environment variables. Could you come up with a reliable heuristic to choose between this triton implementation and the CUDA kernel instead of adding an env that exposes complexity to the end-user?

Copy link
Author

@thoangtrvn thoangtrvn Jun 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I update the code accordingly. The Triton version use the layout of conv_state cache that is the same layout with input tensor hidden_states_B_C, i.e. contiguous along feature-dimension. The CUDA version in vLLM use a different layout of conv_state cache that is contiguous along sequence-dimension. So, there is no good way to switch from one to another in the same session. It's the choice to be made, based on the expected workload. The benefit of Mamba-based model is in long context length and long output response, which we shows that Triton version would take over. Also, the Triton kernel follows vLLM v1 design, allowing prefill requests and decode requests to be mixed.

RESULT: long-context latency measurement (compared with Llama)

python benchmarks/benchmark_latency.py --model /net/storage149/autofs/css22/nmg/models/hf/meta-llama/Llama-3.1-8B-Instruct/main --input-len=131072 --output-len=1 --batch-size=1 --max_num_batched_tokens=2048

Avg latency: 11.214325266804856 seconds
10% percentile latency: 11.202042526123114 seconds
25% percentile latency: 11.206939334078925 seconds
50% percentile latency: 11.212064623483457 seconds
75% percentile latency: 11.220630767958937 seconds
90% percentile latency: 11.2278619370074 seconds
99% percentile latency: 11.24528882814222 seconds
  • Main branch
# Default (max_num_batched_tokens=2048)
python benchmarks/benchmark_latency.py --model ibm-ai-platform/Bamba-9B-v2 --input-len=131072 --output-len=1 --batch-size=1
Avg latency: 6.231618080474436 seconds
10% percentile latency: 6.204561746446416 seconds
25% percentile latency: 6.216710253240308 seconds
50% percentile latency: 6.219352831016295 seconds
75% percentile latency: 6.223808606999228 seconds
90% percentile latency: 6.227424801979214 seconds
99% percentile latency: 6.519547982601217 seconds
  • Current PR:
export VLLM_USE_TRITON_CONV1D="1"
python benchmarks/benchmark_latency.py --model ibm-ai-platform/Bamba-9B-v2 --input-len=131072 --output-len=1 --batch-size=1

Avg latency: 5.757278195097267 seconds    
10% percentile latency: 5.734804188809358 seconds 
25% percentile latency: 5.739403567742556 seconds  
50% percentile latency: 5.743007940007374 seconds  
75% percentile latency: 5.748229099262971 seconds   
90% percentile latency: 5.751799210254103 seconds 
99% percentile latency: 6.068630096325651 seconds 

thoangtrvn and others added 3 commits June 2, 2025 12:05
Co-authored-by: Tyler Michael Smith <tysmith@redhat.com>
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
@thoangtrvn
Copy link
Author

thoangtrvn commented Jun 4, 2025

Could you report some lm_eval scores running gsm8k, as well as make sure it runs correctly without --enforce-eager?

GSM8K RESULT


#ibm-ai-platform/Bamba-9B
# (current) CUDA-SPLIT code
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.2335|±  |0.0117|
|     |       |strict-match    |     5|exact_match|↑  |0.3442|±  |0.0131|
# PR code
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.2456|±  |0.0119|
|     |       |strict-match    |     5|exact_match|↑  |0.3495|±  |0.0131|

#Zyphra/Zamba2-2.7B
# (current) CUDA-SPLIT code
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.5330|±  |0.0137|
|     |       |strict-match    |     5|exact_match|↑  |0.5466|±  |0.0137|
# PR code
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.5330|±  |0.0137|
|     |       |strict-match    |     5|exact_match|↑  |0.5466|±  |0.0137|



#mistralai/Mamba-Codestral-7B-v0.1
# (current) CUDA-SPLIT code
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.4647|±  |0.0137|
|     |       |strict-match    |     5|exact_match|↑  |0.4549|±  |0.0137|
# PR code
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.4655|±  |0.0137|
|     |       |strict-match    |     5|exact_match|↑  |0.4526|±  |0.0137|

COMMAND TO RUN:

echo 'ibm-ai-platform/Bamba-9B'
lm_eval --model vllm     --model_args pretrained=ibm-ai-platform/Bamba-9B,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.9 --batch_size auto  --cache_requests true --tasks gsm8k


echo "DONE RUN (CUDA-SPLIT)"

export VLLM_USE_TRITON_CONV1D="1"
lm_eval --model vllm     --model_args pretrained=ibm-ai-platform/Bamba-9B,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.9 --batch_size auto  --cache_requests true --tasks gsm8k


echo "DONE RUN (PR)"

echo 'Zyphra/Zamba2-2.7B'
lm_eval --model vllm     --model_args pretrained=Zyphra/Zamba2-2.7B,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.9 --batch_size auto --trust_remote_code  --cache_requests true --tasks gsm8k
echo "DONE RUN (CUDA-SPLIT)"

export VLLM_USE_TRITON_CONV1D="1"
lm_eval --model vllm     --model_args pretrained=Zyphra/Zamba2-2.7B,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.9 --batch_size auto --trust_remote_code  --cache_requests true --tasks gsm8k
echo "DONE RUN (PR)"

echo 'Mamba-Codestral-7B-v0.1'
lm_eval --model vllm     --model_args pretrained=mistralai/Mamba-Codestral-7B-v0.1,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.9 --batch_size auto --trust_remote_code  --cache_requests true --tasks gsm8k
echo "DONE RUN (CUDA-SPLIT)"

export VLLM_USE_TRITON_CONV1D="1"
lm_eval --model vllm     --model_args pretrained=mistralai/Mamba-Codestral-7B-v0.1,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.9 --batch_size auto --trust_remote_code  --cache_requests true --tasks gsm8k
echo "DONE RUN (PR)"

tmhoangt and others added 4 commits June 4, 2025 10:51
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
@mergify mergify bot removed the needs-rebase label Jun 4, 2025
tmhoangt added 2 commits June 5, 2025 11:03
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
@tlrmchlsmth tlrmchlsmth changed the title add causal-conv1d in Triton and integrate into vLLM with test code [Kernel] Triton implementation of causal-conv1d for Mamba-based models Jun 9, 2025
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [True, False])
@pytest.mark.parametrize("batch_size", [3])
def test_causal_conv1d_update_with_batch_gather_vllm(batch_size, with_padding,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you reduce the amount of time spend in this test? It's taking 102s, and we should try to keep unit test time under control

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also could you remove the _vllm suffix?

@tlrmchlsmth
Copy link
Collaborator

default setting: generates short number of tokens (i.e. 256 tokens) CUDA-backed Bamba (ibm-ai-platform/Bamba-9B) is still faster; 10% slower (total token throughput) yet only 2% in output token throughput and 2% in TTFT.

The table shown shows a drop from 3628 to 3557 in total tok/s, which is only a 2% slowdown. Am I missing something?

For simplicity, it would be better to remove the causal_conv1d CUDA kernel and use the triton kernel in this PR exclusively if the slowdowns aren't significant. If it's a 2% drop in the worst case I think this may be a reasonable thing to do.

Copy link
Collaborator

@tlrmchlsmth tlrmchlsmth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these changes compatible with CUDA graphs?

Comment on lines 549 to 550
def test_causal_conv1d_varlen_vllm(batch, with_padding, dim, seqlen, width,
has_bias, silu_activation, itype):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you rename this to test_causal_conv1d_varlen? We don't need to clarify that it's for vllm, since this is in the vllm codebase

# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [True, False])
@pytest.mark.parametrize("batch_size", [3])
def test_causal_conv1d_update_with_batch_gather_vllm(batch_size, with_padding,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also could you remove the _vllm suffix?

conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2))
if use_triton_causal_conv_1d:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function is too large and monolithic and should be refactored before landing. The if statements are too deeply nested and there is a lot of duplicated logic as well.

(Removing the CUDA implementations of causal_conv_1d, would help with this)

@@ -22,6 +27,14 @@ class Mamba2Metadata:
chunk_indices: torch.Tensor
chunk_offsets: torch.Tensor

seqlens: Optional[np.ndarray] = None
nums_dict: Optional[dict] = None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is nums_dict? This should be documented.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in a batch of requests, a prefill request can be processed in parallel where each Triton program handles BLOCK_M tokens. Depending on the choice of BLOCK_M, the values in batch_ptr and token_chunk_offset_ptr can be different. The choice of BLOCK_M can be different at different inputs, or different hardware. Currently, BLOCK_M is chosen as 8 and is the same across all inputs which is a good choice to avoid the overhead in Triton autotune.

nums_dict[BLOCK_M] = {batch_ptr, token_chunk_offset_ptr}

I added the documents accordingly.

seqlens: Optional[np.ndarray] = None
nums_dict: Optional[dict] = None
cu_seqlen: Optional[int] = None
out: Optional[torch.Tensor] = None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain what this is?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Triton causal-conv1d (prefill or mixed prefill/decode) kernel process input x using conv_state. While conv_state update must be in-place, the output is written to out tensor rather than writing to x to avoid race condition - as each Triton program handles one segment of the request (unlike CUDA kernel where one thread block handles one full request). out is reused across all layers .

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since out isn't metadata, please remove it from Mamba2Metadata and treat it like a normal tensor

Comment on lines 113 to 114
# keeping flags for both prefill and decode causal_conv1d varlen
# [batch,]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you clarify this comment? It's not clear to me what [batch,] means in this context.

Comment on lines 115 to 125
if use_triton_causal_conv_1d:
has_initial_states = attn_metadata.context_lens_tensor > 0
prep_initial_states = torch.any(
has_initial_states[:num_prefills]).item()
else:
has_initial_states = (
attn_metadata.context_lens_tensor[:num_prefills] > 0)
# precompute flag to avoid device syncs later in mamba2 layer
# forwards
# prep is only needed for mamba2 ssd prefill processing
prep_initial_states = torch.any(has_initial_states).item()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do these need to be different?

tmhoangt added 3 commits June 9, 2025 22:18
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
@mergify mergify bot added the ci/build label Jun 10, 2025
Copy link
Author

@thoangtrvn thoangtrvn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added my comments.

Comment on lines 39 to 40
path = os.environ.get("VLLM_USE_TRITON_CONV1D", None)
if path is not None:
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done update the code

Comment on lines 27 to 50
num_cache_lines: Optional[int] = None
stride_istate_seq: Optional[int] = None
stride_istate_dim: Optional[int] = None
stride_istate_token: Optional[int] = None
seqlens: Optional[np.ndarray] = None
padded_batch: Optional[int] = None
nums_dict: Optional[dict] = None
is_channel_last: bool = True
stride_w_dim: Optional[int] = None
stride_w_width: Optional[int] = None
width: Optional[int] = None
np2_statelen: Optional[int] = None
stride_x_seq: Optional[int] = 0
stride_x_dim: Optional[int] = None
stride_x_token: Optional[int] = None
dim: Optional[int] = None
cu_seqlen: Optional[int] = None
out: Optional[torch.Tensor] = None
stride_o_seq: Optional[int] = 0
stride_o_dim: Optional[int] = None
stride_o_token: Optional[int] = None
MAX_NUM_PROGRAMS: int = 1024
batch_ptr: Optional[torch.tensor] = None
token_chunk_offset_ptr: Optional[torch.tensor] = None
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done updated code.

Comment on lines 143 to 148
"""Create a weight loader for mamba v2. This ensures that the projections
are correctly sharded so that they can be split into x, B, C. It also
ensures that all the groups corresponding to a head shard is placed
"""Create a weight loader for mamba v2. This ensures that the projections
are correctly sharded so that they can be split into x, B, C. It also
ensures the the all the groups corresponding to a head shard is placed
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Comment on lines 110 to 115
# NOTE: currently it is assumed prefill requests come before decode requests -> we can use ':num_prefills' slicing
# TODO: maybe revert back to the original code (below) if above no longer holds
# has_initial_states = attn_metadata.context_lens_tensor > 0
# zero_init_indices = mamba_cache_params.state_indices_tensor[~has_initial_states]
# mamba_cache_params.ssm_state[zero_init_indices] = 0
# initial_states = mamba_cache_params.ssm_state[mamba_cache_params.state_indices_tensor]
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed comment

[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1])
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

vllm/envs.py Outdated
@@ -125,7 +125,6 @@
VLLM_ALL2ALL_BACKEND: str = "naive"
VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840
VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1
VLLM_USE_TRITON_CONV1D: bool = False
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we no longer need this environment variable after removing CUDA kernel

metadata=mamba2_metadata,
query_start_loc=attn_metadata.query_start_loc).transpose(
0, 1)[:seq_len]
# causal_conv1d_fn deals with both prefill and decode if input
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This now follows the original code pathway. There is no need to split requests into prefill requests and decode requests for causal-conv1d kernel. This reverts the change made by a previous PR #17146. However, the other SSD kernels are still not designed for mixed prefill/decode-requests and therefore the separation is still present (as you see at the end)

)
# Separate prefill and decode by splitting varlen input
# Split along token dimension
hidden_states_p, hidden_states_d = torch.split(
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The SSD kernels are still not designed for mixed prefill/decode-requests and therefore the separation is still present here (a contribution from another PR). Hopefully, a better algorithmic design for vLLM can address this in the future.

weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
query_start_loc: Optional[torch.Tensor] = None,
cache_indices: Optional[torch.Tensor] = None,
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove the wrapper to CUDA kernel.

@triton.jit()
def _causal_conv1d_update_kernel(
# Pointers to matrices
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nothing change here, I move the causal_conv1d_fwd kernel to the front of the causal_conv1d_update, and it triggers a messy diff.

Copy link
Collaborator

@tlrmchlsmth tlrmchlsmth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks better! I left a few more comments.

I think this will land after #19327 so we'll need to handle the attn metadata introduced by that PR as well.

Comment on lines 220 to 227
channel_last = True
if not channel_last:
x = torch.randn(padded_batch_size,
dim,
seqlen,
device=device,
dtype=itype)
else:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove this?

Comment on lines 250 to 256
if not channel_last:
conv_state = torch.randn(total_entries,
dim,
width - 1,
device=device,
dtype=itype)
else:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

Comment on lines 328 to 332
channel_last = True
if not channel_last:
x = torch.randn(1, 4096 + dim + 64, seqlen, device=device,
dtype=itype)[:, 4096:4096 + dim, :]
else:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto, remove

Comment on lines 344 to 350
if not channel_last:
final_states = torch.randn(total_entries,
dim,
width - 1,
device=x.device,
dtype=x.dtype)
else:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

Comment on lines 341 to 295
@pytest.mark.parametrize(
'seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 2049, 4096])
@pytest.mark.parametrize('seqlen', [8, 2049, 4096])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change looks good to me - the previous number of cases seems like overkill. Could you add a couple more non-power-of-two problem sizes to be safe.

seqlens: Optional[np.ndarray] = None
nums_dict: Optional[dict] = None
cu_seqlen: Optional[int] = None
out: Optional[torch.Tensor] = None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since out isn't metadata, please remove it from Mamba2Metadata and treat it like a normal tensor

tmhoangt added 3 commits June 10, 2025 19:39
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
@thoangtrvn
Copy link
Author

@tlrmchlsmth : I merged the recent changes in main, and adopt the contribution from PR #19327. Please let me know if there is something else I should revise as well. To remove completely mamba2_metadata.py and use only mamba_attn.py, I guess it's better to be in a separate PR.

Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
@tlrmchlsmth
Copy link
Collaborator

basic-correctness-test errors are related

circular import issue

Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
@thoangtrvn
Copy link
Author

basic-correctness-test errors are related

fixed that circular import issue.

Copy link

mergify bot commented Jun 28, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @thoangtrvn.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jun 28, 2025
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
@mergify mergify bot removed the needs-rebase label Jun 28, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants