[Kernel] Triton implementation of causal-conv1d for Mamba-based models #18218
Conversation
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
This pull request has merge conflicts that must be resolved before it can be merged.
Could you report some lm_eval scores running gsm8k, as well as make sure it runs correctly without --enforce-eager?
vllm/model_executor/models/bamba.py
Outdated
It looks like the changes to bamba.py, granitemoehybrid.py, mamba2.py, and zamba2.py are pretty spurious. Could you revert those?
In the Mamba2 metadata, there are settings that are updated only once (at model initialization) and settings that are updated for every input (while staying the same across layers). Adding self.mamba2_metadata provides a way to reuse the updated-once data. If you don't like this level of optimization, please let me know @tlrmchlsmth. This is optional; I can revert the changes to the other models and keep the change only in bamba.py.
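For illustration, a minimal hedged sketch of that split (hypothetical class and field names, not the PR's code): the metadata object is created once on the model, and only the per-input fields are refreshed each step.

```python
import torch
from dataclasses import dataclass
from typing import Optional

@dataclass
class Mamba2MetadataSketch:
    # updated once, at model initialization
    chunk_size: int
    # refreshed for every input, shared by all Mamba layers
    query_start_loc: Optional[torch.Tensor] = None
    has_initial_states: Optional[torch.Tensor] = None

class BambaSketch:
    def __init__(self, chunk_size: int = 256):
        # kept on the model so the updated-once part is reused across steps
        self.mamba2_metadata = Mamba2MetadataSketch(chunk_size=chunk_size)

    def prepare(self, query_start_loc: torch.Tensor,
                context_lens: torch.Tensor) -> Mamba2MetadataSketch:
        # only the per-input fields change; the object itself is reused
        self.mamba2_metadata.query_start_loc = query_start_loc
        self.mamba2_metadata.has_initial_states = context_lens > 0
        return self.mamba2_metadata
```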
path = os.environ.get("VLLM_USE_TRITON_CONV1D", None)
if path is not None:
If we add an environment variable, it should be handled in vllm/envs.py.
However, generally we should avoid adding an environment variable where possible. In which cases should we use the Triton conv1d kernels vs. CUDA?
So far, testing shows that the CUDA kernel's lower overhead makes it better at short context lengths (~300-400 tokens, with short output lengths); otherwise, the Triton kernel is better. The end-to-end slowdown doesn't come from the kernel itself, but from the overall launch overhead. I just want to maintain the two pathways until a final decision is made, based on testing beyond what I have done.
I'll provide more details of the test cases in the PR threads.
While in theory some of the Triton launch overhead can be reduced using the Triton JIT cache mechanism, that is not tested here.
Done: updated the code.
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1])
Should we test more than just seqlen 1?
ah yes, I can add more to the test code
done
"""Create a weight loader for mamba v2. This ensures that the projections | ||
are correctly sharded so that they can be split into x, B, C. It also | ||
ensures that all the groups corresponding to a head shard is placed | ||
"""Create a weight loader for mamba v2. This ensures that the projections | ||
are correctly sharded so that they can be split into x, B, C. It also | ||
ensures the the all the groups corresponding to a head shard is placed |
nit: please revert the the back to that
done
batch_ptr = torch.full(
    (MAX_NUM_PROGRAMS, ), PAD_SLOT_ID, dtype=torch.int32,
    device='cpu')  # tracking which seq-idx the Triton program is handling
token_chunk_offset_ptr = torch.full(
    (MAX_NUM_PROGRAMS, ), PAD_SLOT_ID, dtype=torch.int32, device='cpu'
)  # tracking BLOCK_M-based index in the sequence the Triton program is handling
Are these stateful global variables? Would it be better for these to go in the mamba2_metadata instead?
Explanation: the parallelism of the causal-conv1d kernel (prefill stage) is 3D: batch, feature, and sequence dimensions. This means each Triton program can handle a group of tokens within a sequence. The token range (start/stop) within each sequence is tracked by these two tensors. In vLLM, the tensors are generated as part of the metadata construction, so at end-to-end runtime these module-level tensors are not used.
However, for kernel-level runtime, e.g. microbenchmarking or kernel testing, they need to be provided. Here we have three choices:
- create a metadata object just like in the vLLM e2e setting
- when metadata is not provided, either declare the tensors inside the function that uses them (which creates the tensors each time the kernel is invoked), or declare them at module level (created once at module loading, saving some tensor-allocation overhead).
Please let me know how you think it should be revised @tlrmchlsmth.
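For concreteness, a hypothetical sketch (assuming PAD_SLOT_ID is the padding value and BLOCK_M-sized token chunks; not the PR's exact code) of how these two tensors map Triton programs to (sequence, chunk) pairs:

```python
import math
import torch

PAD_SLOT_ID = -1          # assumed padding value
MAX_NUM_PROGRAMS = 1024   # assumed module-level capacity

def fill_program_maps(seqlens: list[int], block_m: int):
    """batch_ptr[p] -> which request program p handles;
    token_chunk_offset_ptr[p] -> which block_m-sized chunk of that request."""
    batch_ptr = torch.full((MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32)
    token_chunk_offset_ptr = torch.full((MAX_NUM_PROGRAMS,), PAD_SLOT_ID,
                                        dtype=torch.int32)
    p = 0
    for seq_idx, seqlen in enumerate(seqlens):
        for chunk in range(math.ceil(seqlen / block_m)):
            batch_ptr[p] = seq_idx
            token_chunk_offset_ptr[p] = chunk
            p += 1
    return batch_ptr, token_chunk_offset_ptr

# e.g. seqlens = [5, 12] with block_m = 8 yields programs
# (seq 0, chunk 0), (seq 1, chunk 0), (seq 1, chunk 1)
```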
Migrated the code change based on your feedback, @tlrmchlsmth.
num_cache_lines: Optional[int] = None
stride_istate_seq: Optional[int] = None
stride_istate_dim: Optional[int] = None
stride_istate_token: Optional[int] = None
seqlens: Optional[np.ndarray] = None
padded_batch: Optional[int] = None
nums_dict: Optional[dict] = None
is_channel_last: bool = True
stride_w_dim: Optional[int] = None
stride_w_width: Optional[int] = None
width: Optional[int] = None
np2_statelen: Optional[int] = None
stride_x_seq: Optional[int] = 0
stride_x_dim: Optional[int] = None
stride_x_token: Optional[int] = None
dim: Optional[int] = None
cu_seqlen: Optional[int] = None
out: Optional[torch.Tensor] = None
stride_o_seq: Optional[int] = 0
stride_o_dim: Optional[int] = None
stride_o_token: Optional[int] = None
MAX_NUM_PROGRAMS: int = 1024
batch_ptr: Optional[torch.tensor] = None
token_chunk_offset_ptr: Optional[torch.tensor] = None
There is a lot of stuff in here, and it's not really clear what most of it is for. At first glance it seems like most of this should be accessed on the fly instead of stored in the metadata here. Could you take a stab at cleaning this up?
The information here is reused across Mamba layers; even a stride query triggers Torch calls, which adds unnecessary overhead. I can add a description as needed. Please let me know what you think @tlrmchlsmth.
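To illustrate the overhead point, a hypothetical sketch (field names borrowed from the dataclass quoted above; sizes assumed) of caching strides once per batch instead of querying them in every layer:

```python
import torch

dim, total_tokens = 1024, 64                        # assumed sizes
# channel-last activations viewed as (dim, total_tokens)
x = torch.randn(total_tokens, dim).transpose(0, 1)

# queried once per batch and cached in the metadata as plain ints ...
stride_x_dim, stride_x_token = x.stride()           # -> (1, dim)

# ... so every Mamba layer can pass the cached ints straight to the kernel
# instead of calling x.stride() again (one Torch dispatch per layer saved), e.g.
#   _conv1d_kernel[grid](x, ..., stride_x_dim, stride_x_token, ...)
```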
Done: updated the code.
I think most of this should be removed. For CPU overheads in decode, we can rely on CUDA graphs, and for prefill they are amortized.
I'm running some tests to revert this.
DONE
@tlrmchlsmth: done with the update based on your feedback.
# NOTE: currently it is assumed prefill requests come before decode requests -> we can use ':num_prefills' slicing
# TODO: maybe revert back to the original code (below) if above no longer holds
# has_initial_states = attn_metadata.context_lens_tensor > 0
# zero_init_indices = mamba_cache_params.state_indices_tensor[~has_initial_states]
# mamba_cache_params.ssm_state[zero_init_indices] = 0
# initial_states = mamba_cache_params.ssm_state[mamba_cache_params.state_indices_tensor]
We can rely on batch reordering and require that it be used for this implementation.
The change from a previous PR (CUDA-split) implies this assumption; this PR doesn't have this assumption, which makes it more suitable for the vLLM v1 design. The comment I added here is to clarify the code path from the previous PR. I can remove the comment as needed. Please let me know @tlrmchlsmth
removed comment
Yes, please remove the comment.
We can rely on batch reordering even in vLLM V1, so this is a non-issue.
done
def is_conv_in_Triton(self):
    import os
    path = os.environ.get("VLLM_USE_TRITON_CONV1D", None)
    if path is not None:
        print("mamba_mixer2 - VLLM_USE_TRITON_CONV1D")
        return True
    return False
A few things here:
- Remove the print before landing
- env variables should be defined in vllm/envs.py
- The function would be better-named use_triton_causal_conv_1d (more descriptive & proper capitalization)
- We should work hard to avoid proliferation of environment variables. Could you come up with a reliable heuristic to choose between this triton implementation and the CUDA kernel instead of adding an env that exposes complexity to the end-user?
I updated the code accordingly. The Triton version uses a conv_state cache layout that is the same as the layout of the input tensor hidden_states_B_C, i.e. contiguous along the feature dimension. The CUDA version in vLLM uses a different conv_state cache layout that is contiguous along the sequence dimension. So there is no good way to switch from one to the other within the same session; it's a choice to be made based on the expected workload. The benefit of Mamba-based models is in long context lengths and long output responses, where we show that the Triton version takes over. Also, the Triton kernel follows the vLLM v1 design, allowing prefill and decode requests to be mixed.
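To make the layout difference concrete, a small hedged sketch (assumed sizes, not the PR's code) contrasting the two conv-state cache layouts via their strides:

```python
import torch

num_cache_lines, dim, width = 4, 256, 4          # assumed sizes

# CUDA pathway: cache contiguous along the kernel-width (sequence) axis
conv_state_cuda = torch.zeros(num_cache_lines, dim, width - 1)
print(conv_state_cuda.stride())                  # (dim * (width-1), width-1, 1)

# Triton pathway: channel-last, contiguous along the feature dimension,
# matching the layout of the hidden_states_B_C activations
conv_state_triton = torch.zeros(num_cache_lines, width - 1, dim).transpose(1, 2)
print(conv_state_triton.stride())                # (dim * (width-1), 1, dim)
```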
RESULT: long-context latency measurement (compared with Llama)
python benchmarks/benchmark_latency.py --model /net/storage149/autofs/css22/nmg/models/hf/meta-llama/Llama-3.1-8B-Instruct/main --input-len=131072 --output-len=1 --batch-size=1 --max_num_batched_tokens=2048
Avg latency: 11.214325266804856 seconds
10% percentile latency: 11.202042526123114 seconds
25% percentile latency: 11.206939334078925 seconds
50% percentile latency: 11.212064623483457 seconds
75% percentile latency: 11.220630767958937 seconds
90% percentile latency: 11.2278619370074 seconds
99% percentile latency: 11.24528882814222 seconds
- Main branch
# Default (max_num_batched_tokens=2048)
python benchmarks/benchmark_latency.py --model ibm-ai-platform/Bamba-9B-v2 --input-len=131072 --output-len=1 --batch-size=1
Avg latency: 6.231618080474436 seconds
10% percentile latency: 6.204561746446416 seconds
25% percentile latency: 6.216710253240308 seconds
50% percentile latency: 6.219352831016295 seconds
75% percentile latency: 6.223808606999228 seconds
90% percentile latency: 6.227424801979214 seconds
99% percentile latency: 6.519547982601217 seconds
- Current PR:
export VLLM_USE_TRITON_CONV1D="1"
python benchmarks/benchmark_latency.py --model ibm-ai-platform/Bamba-9B-v2 --input-len=131072 --output-len=1 --batch-size=1
Avg latency: 5.757278195097267 seconds
10% percentile latency: 5.734804188809358 seconds
25% percentile latency: 5.739403567742556 seconds
50% percentile latency: 5.743007940007374 seconds
75% percentile latency: 5.748229099262971 seconds
90% percentile latency: 5.751799210254103 seconds
99% percentile latency: 6.068630096325651 seconds
Co-authored-by: Tyler Michael Smith <tysmith@redhat.com>
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
GSM8K RESULT
COMMAND TO RUN:
|
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [True, False])
@pytest.mark.parametrize("batch_size", [3])
def test_causal_conv1d_update_with_batch_gather_vllm(batch_size, with_padding,
Could you reduce the amount of time spent in this test? It's taking 102s, and we should try to keep unit test time under control.
Also could you remove the _vllm suffix?
The table shows a drop from 3628 to 3557 in total tok/s, which is only a 2% slowdown. Am I missing something? For simplicity, it would be better to remove the |
Are these changes compatible with CUDA graphs?
def test_causal_conv1d_varlen_vllm(batch, with_padding, dim, seqlen, width,
                                   has_bias, silu_activation, itype):
Could you rename this to test_causal_conv1d_varlen? We don't need to clarify that it's for vLLM, since this is in the vLLM codebase.
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                       self.conv1d.weight.size(2))
if use_triton_causal_conv_1d:
This function is too large and monolithic and should be refactored before landing. The if statements are too deeply nested and there is a lot of duplicated logic as well.
(Removing the CUDA implementations of causal_conv_1d would help with this)
@@ -22,6 +27,14 @@ class Mamba2Metadata:
chunk_indices: torch.Tensor
chunk_offsets: torch.Tensor

seqlens: Optional[np.ndarray] = None
nums_dict: Optional[dict] = None
What is nums_dict? This should be documented.
In a batch of requests, a prefill request can be processed in parallel, where each Triton program handles BLOCK_M tokens. Depending on the choice of BLOCK_M, the values in batch_ptr and token_chunk_offset_ptr can differ. The choice of BLOCK_M can also differ across inputs or hardware. Currently, BLOCK_M is fixed at 8 for all inputs, which is a good choice to avoid the overhead of Triton autotuning.
nums_dict[BLOCK_M] = {batch_ptr, token_chunk_offset_ptr}
I added the documentation accordingly.
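For clarity, a hypothetical illustration of the structure described above (assumed names and values, not the PR's exact code):

```python
import torch

PAD_SLOT_ID, MAX_NUM_PROGRAMS = -1, 1024     # assumed values
BLOCK_M = 8                                  # currently the only value used

batch_ptr = torch.full((MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32)
token_chunk_offset_ptr = torch.full((MAX_NUM_PROGRAMS,), PAD_SLOT_ID,
                                    dtype=torch.int32)

# one entry per candidate BLOCK_M, each holding the program-mapping tensors
nums_dict = {
    BLOCK_M: {
        "batch_ptr": batch_ptr,                            # program -> request index
        "token_chunk_offset_ptr": token_chunk_offset_ptr,  # program -> chunk index
    },
}
```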
seqlens: Optional[np.ndarray] = None
nums_dict: Optional[dict] = None
cu_seqlen: Optional[int] = None
out: Optional[torch.Tensor] = None
Could you explain what this is?
The Triton causal-conv1d kernel (prefill or mixed prefill/decode) processes the input x using conv_state. While the conv_state update must be in-place, the output is written to the out tensor rather than back into x to avoid a race condition, since each Triton program handles one segment of a request (unlike the CUDA kernel, where one thread block handles one full request). out is reused across all layers.
Since out isn't metadata, please remove it from Mamba2Metadata and treat it like a normal tensor.
# keeping flags for both prefill and decode causal_conv1d varlen
# [batch,]
Could you clarify this comment? It's not clear to me what [batch,]
means in this context.
if use_triton_causal_conv_1d:
    has_initial_states = attn_metadata.context_lens_tensor > 0
    prep_initial_states = torch.any(
        has_initial_states[:num_prefills]).item()
else:
    has_initial_states = (
        attn_metadata.context_lens_tensor[:num_prefills] > 0)
    # precompute flag to avoid device syncs later in mamba2 layer
    # forwards
    # prep is only needed for mamba2 ssd prefill processing
    prep_initial_states = torch.any(has_initial_states).item()
Do these need to be different?
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
I added my comments.
vllm/envs.py
Outdated
@@ -125,7 +125,6 @@
VLLM_ALL2ALL_BACKEND: str = "naive"
VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840
VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1
VLLM_USE_TRITON_CONV1D: bool = False
We no longer need this environment variable after removing the CUDA kernel.
metadata=mamba2_metadata,
query_start_loc=attn_metadata.query_start_loc).transpose(
    0, 1)[:seq_len]
# causal_conv1d_fn deals with both prefill and decode if input
This now follows the original code pathway. There is no need to split requests into prefill and decode requests for the causal-conv1d kernel; this reverts the change made by the previous PR #17146. However, the other SSD kernels are still not designed for mixed prefill/decode requests, and therefore the separation is still present (as you see at the end).
)
# Separate prefill and decode by splitting varlen input
# Split along token dimension
hidden_states_p, hidden_states_d = torch.split(
The SSD kernels are still not designed for mixed prefill/decode requests, and therefore the separation is still present here (a contribution from another PR). Hopefully, a better algorithmic design for vLLM can address this in the future.
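For context, a minimal sketch of that split (variable names borrowed from the surrounding diff, sizes assumed; not the exact PR code):

```python
import torch

num_prefill_tokens, num_decode_tokens, hidden = 96, 32, 1024   # assumed sizes
hidden_states = torch.randn(num_prefill_tokens + num_decode_tokens, hidden)

# with batch reordering, prefill tokens come first and decode tokens follow,
# so one split along the token dimension feeds the two SSD code paths
hidden_states_p, hidden_states_d = torch.split(
    hidden_states, [num_prefill_tokens, num_decode_tokens], dim=0)
```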
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
query_start_loc: Optional[torch.Tensor] = None,
cache_indices: Optional[torch.Tensor] = None,
Removed the wrapper to the CUDA kernel.
@triton.jit()
def _causal_conv1d_update_kernel(
    # Pointers to matrices
Nothing changed here; I moved the causal_conv1d_fwd kernel in front of causal_conv1d_update, which produces a messy diff.
Looks better! I left a few more comments.
I think this will land after #19327 so we'll need to handle the attn metadata introduced by that PR as well.
channel_last = True
if not channel_last:
    x = torch.randn(padded_batch_size,
                    dim,
                    seqlen,
                    device=device,
                    dtype=itype)
else:
Remove this?
if not channel_last:
    conv_state = torch.randn(total_entries,
                             dim,
                             width - 1,
                             device=device,
                             dtype=itype)
else:
ditto
channel_last = True
if not channel_last:
    x = torch.randn(1, 4096 + dim + 64, seqlen, device=device,
                    dtype=itype)[:, 4096:4096 + dim, :]
else:
ditto, remove
if not channel_last:
    final_states = torch.randn(total_entries,
                               dim,
                               width - 1,
                               device=x.device,
                               dtype=x.dtype)
else:
ditto
@pytest.mark.parametrize(
    'seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 2049, 4096])
@pytest.mark.parametrize('seqlen', [8, 2049, 4096])
This change looks good to me - the previous number of cases seems like overkill. Could you add a couple more non-power-of-two problem sizes to be safe.
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
@tlrmchlsmth: I merged the recent changes from main and adopted the contribution from PR #19327. Please let me know if there is something else I should revise as well. To remove completely |
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
basic-correctness-test errors are related |
circular import issue Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Fixed that circular import issue.
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
This PR adds Triton-based causal-conv1d, making Mamba-based models in vLLM
There are two kernels implemented: one for prefill or mixed prefill/decode batches (causal_conv1d_fn) and one for decode (causal_conv1d_update).
It also performs better than the CUDA-split pathway, which was merged in PR #17146.

[Data benchmarking the runtime of processing the same batch of mixed requests: first the batch is sent to the single Triton kernel, and then the CUDA-split pathway is used, where requests are first separated, prefill-only requests go to one kernel, and decode-only requests go to a second kernel.]
ALGORITHMIC CHOICE OF TRITON KERNEL: Unlike the CUDA kernel, which is parallelized in 2D, i.e. along the feature dimension and the batch size, the Triton kernel is parallelized in 3D, i.e. also along the sequence dimension. Also, the Triton kernels don't make any changes to the layout of the input data, which is contiguous along the feature dimension. Another key difference is that the Triton kernels expect the conv-state to be contiguous along the feature dimension, while the existing CUDA implementation expects the conv-state cache to be contiguous along the kernel-width (i.e. sequence-length) axis. As a result, the two CUDA kernels are not compatible with this conv-state cache layout, which prevents them from efficiently processing decode-only requests or mixed prefill/decode requests.
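As a rough illustration of that parallelization difference, a hedged sketch (assumed block sizes and grid shapes, not the PR's actual launch code):

```python
import triton

dim, seqlens = 4096, [700, 1, 1, 1300]      # assumed mixed prefill/decode batch
BLOCK_N = 256                               # features per program (assumed)
BLOCK_M = 8                                 # tokens per program (assumed)

# CUDA-style 2D parallelism: one block per (request, feature-block) pair,
# so a long prefill request is handled by a single block per feature-block
grid_2d = (len(seqlens), triton.cdiv(dim, BLOCK_N))

# Triton kernel's 3D-style parallelism: each request is additionally split
# along the sequence dimension into BLOCK_M-token chunks
num_token_chunks = sum(triton.cdiv(s, BLOCK_M) for s in seqlens)
grid_3d = (num_token_chunks, triton.cdiv(dim, BLOCK_N))
```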
Some other overhead-reducing improvements are also incorporated.
Even though the binary code generated by Triton is faster, the launch overhead is a known issue and therefore needs further optimization to make end-to-end Triton-only Mamba models in vLLM performant. Here, we incorporate such improvements by using metadata that can be reused across layers.
In our benchmark on the ShareGPT dataset, which has short input prompts (a few hundred tokens):
- default setting (generating 256 output tokens): CUDA-backed Bamba (ibm-ai-platform/Bamba-9B) is still faster; the Triton pathway is 10% slower in total token throughput, yet only 2% slower in output token throughput and 2% in TTFT.
- generating 1024 tokens: Triton-backed Bamba is now faster, with 5% higher token throughput and 11% faster TTFT. The benefit of the faster Triton kernels now exceeds the overall cost of the Triton launch overhead.
At longer context lengths and/or larger numbers of generated tokens, the Triton-only Mamba-based model is expected to be better than the CUDA-split approach. However, the PR keeps the existing CUDA pathway as the default until the Triton pathway is adopted by the vLLM maintainers; currently, the Triton code is added as an optional alternative to the CUDA-split pathway, enabled by setting the VLLM_USE_TRITON_CONV1D environment variable to 1. This is one step closer to the vLLM v1 design, i.e. without splitting the batch into prefill-only and decode-only requests for CUDA-split processing.
Test code is also added.