
Commit 62ae26c

cyang49 and RishiAstra authored and committed
[Model] Mamba2 varlen refactor (#21467)
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Co-authored-by: RishiAstra <40644327+RishiAstra@users.noreply.github.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
1 parent 87ee853 commit 62ae26c

File tree

10 files changed: +723, -865 lines changed

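At a glance, the refactor replaces the batched mamba_chunk_scan_combined entry point with the varlen-only mamba_chunk_scan_combined_varlen: the implicit batch=1 dimension is dropped, the cu_seqlens/seq_idx metadata becomes mandatory, and the return_varlen_states / return_final_states flags disappear. A minimal sketch of the layout change (the shapes and sizes below are illustrative, not taken from the diff):

import torch

# Illustrative shapes only: the varlen layout drops the old batch=1 dim and
# flattens all sequences along a single token axis of length T.
T, nheads, headdim = 8, 4, 16
X_batched = torch.randn(1, T, nheads, headdim)  # old layout: (1, T, ...)
X_varlen = X_batched.squeeze(0)                 # new layout: (T, ...)
assert X_varlen.shape == (T, nheads, headdim)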

tests/kernels/mamba/test_mamba_ssm_ssd.py

Lines changed: 71 additions & 54 deletions
@@ -7,7 +7,7 @@
 from einops import rearrange, repeat
 
 from vllm.model_executor.layers.mamba.ops.ssd_combined import (
-    mamba_chunk_scan_combined)
+    mamba_chunk_scan_combined_varlen)
 from vllm.platforms import current_platform
 from vllm.v1.attention.backends.mamba2_attn import (
     _query_start_loc_to_chunk_indices_offsets)
@@ -185,9 +185,14 @@ def end_boundary(n: int):
     IND_S = [x % full_length for x in IND_E]
     IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)]
 
+    # varlen has implicit batch=1
+    dt2 = dt2.squeeze(0)
+    X2 = X2.squeeze(0)
+    B2 = B2.squeeze(0)
+    C2 = C2.squeeze(0)
     yield ([Y_min[s, IND_S[s]:IND_E[s]]
             for s in range(num_examples)] if return_naive_ref else None,
-           cu_seqlens, seq_idx.unsqueeze(0), (A, dt2, X2, B2, C2))
+           cu_seqlens, seq_idx, (A, dt2, X2, B2, C2))
 
 
@@ -198,7 +203,7 @@ def end_boundary(n: int):
 def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
                                          itype):
 
-    # this tests the kernels on a single example (no batching)
+    # this tests the kernels on a single example (bs=1)
 
     # TODO: the bfloat16 case requires higher thresholds. To be investigated
 
@@ -219,23 +224,40 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
 
     Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), A * dt,
                                                   B, C, chunk_size)
+
+    cu_seqlens = torch.tensor((0, seqlen), device='cuda').cumsum(dim=0)
+    seq_idx = torch.zeros(seqlen, dtype=torch.int32, device=cu_seqlens.device)
+
+    chunk_indices, chunk_offsets = \
+        _query_start_loc_to_chunk_indices_offsets(
+            cu_seqlens, chunk_size, cu_seqlens[-1])
+
+    # varlen has implicit batch=1
+    X = X.squeeze(0)
+    dt = dt.squeeze(0)
+    A = A.squeeze(0)
+    B = B.squeeze(0)
+    C = C.squeeze(0)
     Y = torch.empty_like(X)
-    final_state = mamba_chunk_scan_combined(X,
-                                            dt,
-                                            A,
-                                            B,
-                                            C,
-                                            chunk_size,
-                                            D=None,
-                                            return_final_states=True,
-                                            out=Y)
+    final_state = mamba_chunk_scan_combined_varlen(X,
+                                                   dt,
+                                                   A,
+                                                   B,
+                                                   C,
+                                                   chunk_size,
+                                                   D=None,
+                                                   cu_seqlens=cu_seqlens,
+                                                   seq_idx=seq_idx,
+                                                   chunk_indices=chunk_indices,
+                                                   chunk_offsets=chunk_offsets,
+                                                   out=Y)
 
     # just test the last in sequence
-    torch.testing.assert_close(Y[:, -1], Y_min[:, -1], atol=atol, rtol=rtol)
+    torch.testing.assert_close(Y[-1], Y_min[0, -1], atol=atol, rtol=rtol)
 
     # just test the last head
     # NOTE, in the kernel we always cast states to fp32
-    torch.testing.assert_close(final_state[:, -1],
+    torch.testing.assert_close(final_state[:, -1].to(torch.float32),
                                final_state_min[:, -1].to(torch.float32),
                                atol=atol,
                                rtol=rtol)
@@ -300,7 +322,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
             cu_seqlens, chunk_size, cu_seqlens[-1])
 
     Y = torch.empty_like(X)
-    new_states = mamba_chunk_scan_combined(
+    new_states = mamba_chunk_scan_combined_varlen(
         X,
         dt,
         A,
@@ -312,7 +334,6 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
         seq_idx=seq_idx,
         chunk_indices=chunk_indices,
         chunk_offsets=chunk_offsets,
-        return_varlen_states=True,
         initial_states=states,
         out=Y,
     )
@@ -321,7 +342,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
     for i in range(num_examples):
 
         # just test one dim and dstate
-        Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0]
+        Y_eg = Y[cu_seqlens[i]:cu_seqlens[i + 1], 0, 0]
         Y_min_eg = Y_min[i][:, 0, 0]
         torch.testing.assert_close(Y_eg, Y_min_eg, atol=atol, rtol=rtol)
 
@@ -386,7 +407,7 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
         _query_start_loc_to_chunk_indices_offsets(
             cu_seqlens, chunk_size, cu_seqlens[-1])
     Y_ref = torch.empty_like(X)
-    state_ref = mamba_chunk_scan_combined(
+    state_ref = mamba_chunk_scan_combined_varlen(
         X,
         dt,
         A,
@@ -398,7 +419,6 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
         seq_idx=seq_idx,
         chunk_indices=chunk_indices,
         chunk_offsets=chunk_offsets,
-        return_varlen_states=True,
         initial_states=None,
         out=Y_ref,
     )
@@ -414,27 +434,27 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
     chunked_seq_idx = torch.repeat_interleave(
         torch.arange(len(chunked_seqlens), device=device),
         chunked_seqlens,
-        output_size=chunked_cu_seqlens[-1]).unsqueeze(0).to(torch.int32)
+        output_size=chunked_cu_seqlens[-1]).to(torch.int32)
     chunked_input_seq_len = chunked_cu_seqlens[-1]
-    X_chunked = torch.zeros_like(X)[:, :chunked_input_seq_len, ...]
-    dt_chunked = torch.zeros_like(dt)[:, :chunked_input_seq_len, ...]
-    B_chunked = torch.zeros_like(B)[:, :chunked_input_seq_len, ...]
-    C_chunked = torch.zeros_like(C)[:, :chunked_input_seq_len, ...]
+    X_chunked = torch.zeros_like(X)[:chunked_input_seq_len, ...]
+    dt_chunked = torch.zeros_like(dt)[:chunked_input_seq_len, ...]
+    B_chunked = torch.zeros_like(B)[:chunked_input_seq_len, ...]
+    C_chunked = torch.zeros_like(C)[:chunked_input_seq_len, ...]
     for i in range(num_sequences):
         # fmt: off
-        chunk_f = lambda x, i: x[:, cu_seqlens[i]:cu_seqlens[i] + chunked_seqlens[i], ...]  # noqa: E501
+        chunk_f = lambda x, i: x[cu_seqlens[i]:cu_seqlens[i] + chunked_seqlens[i], ...]  # noqa: E501
 
-        X_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(X, i)  # noqa: E501
-        dt_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(dt, i)  # noqa: E501
-        B_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(B, i)  # noqa: E501
-        C_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(C, i)  # noqa: E501
+        X_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(X, i)  # noqa: E501
+        dt_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(dt, i)  # noqa: E501
+        B_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(B, i)  # noqa: E501
+        C_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(C, i)  # noqa: E501
         # fmt: on
 
     chunk_indices, chunk_offsets = \
         _query_start_loc_to_chunk_indices_offsets(
             chunked_cu_seqlens, chunk_size, chunked_cu_seqlens[-1])
     Y_partial = torch.empty_like(X_chunked)
-    partial_state = mamba_chunk_scan_combined(
+    partial_state = mamba_chunk_scan_combined_varlen(
         X_chunked,
         dt_chunked,
         A,
@@ -446,7 +466,6 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
         seq_idx=chunked_seq_idx,
         chunk_indices=chunk_indices,
         chunk_offsets=chunk_offsets,
-        return_varlen_states=True,
         initial_states=None,
         out=Y_partial,
     )
@@ -461,29 +480,28 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
     remaining_chunked_seq_idx = torch.repeat_interleave(
         torch.arange(len(remaining_chunked_seqlens), device=device),
         remaining_chunked_seqlens,
-        output_size=remaining_chunked_cu_seqlens[-1]).unsqueeze(0).to(
-            torch.int32)
+        output_size=remaining_chunked_cu_seqlens[-1]).to(torch.int32)
     remaining_chunked_input_seq_len = remaining_chunked_cu_seqlens[-1]
     # fmt: off
-    remaining_X_chunked = torch.zeros_like(X)[:, :remaining_chunked_input_seq_len, ...]  # noqa: E501
-    remaining_dt_chunked = torch.zeros_like(dt)[:, :remaining_chunked_input_seq_len, ...]  # noqa: E501
-    remaining_B_chunked = torch.zeros_like(B)[:, :remaining_chunked_input_seq_len, ...]  # noqa: E501
-    remaining_C_chunked = torch.zeros_like(C)[:, :remaining_chunked_input_seq_len, ...]  # noqa: E501
+    remaining_X_chunked = torch.zeros_like(X)[:remaining_chunked_input_seq_len, ...]  # noqa: E501
+    remaining_dt_chunked = torch.zeros_like(dt)[:remaining_chunked_input_seq_len, ...]  # noqa: E501
+    remaining_B_chunked = torch.zeros_like(B)[:remaining_chunked_input_seq_len, ...]  # noqa: E501
+    remaining_C_chunked = torch.zeros_like(C)[:remaining_chunked_input_seq_len, ...]  # noqa: E501
     for i in range(num_sequences):
-        remaining_chunk_f = lambda x, i: x[:, cu_seqlens[i] + chunked_seqlens[i]:cu_seqlens[i+1], ...]  # noqa: E501
+        remaining_chunk_f = lambda x, i: x[cu_seqlens[i] + chunked_seqlens[i]:cu_seqlens[i+1], ...]  # noqa: E501
 
-        remaining_X_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(X, i)  # noqa: E501
-        remaining_dt_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(dt, i)  # noqa: E501
-        remaining_B_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(B, i)  # noqa: E501
-        remaining_C_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(C, i)  # noqa: E501
+        remaining_X_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(X, i)  # noqa: E501
+        remaining_dt_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(dt, i)  # noqa: E501
+        remaining_B_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(B, i)  # noqa: E501
+        remaining_C_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(C, i)  # noqa: E501
 
     # assert input chunking is correct
     concat_chunk_f = lambda pt1, pt2, i: torch.cat([
-        pt1[:,chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1],...],
-        pt2[:,remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1],...],
+        pt1[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1],...],
+        pt2[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1],...],
     ],
-    dim=1)
-    concat_batch_f = lambda pt1, pt2: torch.cat([concat_chunk_f(pt1, pt2, i) for i in range(num_sequences)], dim=1)  # noqa: E501
+    dim=0)
+    concat_batch_f = lambda pt1, pt2: torch.cat([concat_chunk_f(pt1, pt2, i) for i in range(num_sequences)], dim=0)  # noqa: E501
     # fmt: on
 
     assert concat_batch_f(X_chunked, remaining_X_chunked).equal(X)
@@ -498,7 +516,7 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
             remaining_chunked_cu_seqlens[-1])
 
     Y_chunked = torch.empty_like(remaining_X_chunked)
-    state_chunked = mamba_chunk_scan_combined(
+    state_chunked = mamba_chunk_scan_combined_varlen(
         remaining_X_chunked,
         remaining_dt_chunked,
         A,
@@ -510,25 +528,24 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
         seq_idx=remaining_chunked_seq_idx,
         chunk_indices=chunk_indices,
         chunk_offsets=chunk_offsets,
-        return_varlen_states=True,
        initial_states=partial_state,
         out=Y_chunked,
     )
     Y = concat_batch_f(Y_partial, Y_chunked)
 
     # kernel chunked is same as kernel overall
     for i in range(num_sequences):
-        Y_seq = Y[:, cu_seqlens[i]:cu_seqlens[i + 1], ...]
-        Y_ref_seq = Y_ref[:, cu_seqlens[i]:cu_seqlens[i + 1], ...]
+        Y_seq = Y[cu_seqlens[i]:cu_seqlens[i + 1], ...]
+        Y_ref_seq = Y_ref[cu_seqlens[i]:cu_seqlens[i + 1], ...]
         torch.testing.assert_close(
-            Y_seq[:, :chunked_seqlens[i], ...],
-            Y_ref_seq[:, :chunked_seqlens[i], ...],
+            Y_seq[:chunked_seqlens[i], ...],
+            Y_ref_seq[:chunked_seqlens[i], ...],
            atol=atol,
             rtol=rtol,
             msg=lambda x: f"seq{i} output part1 " + x)  # noqa: B023
         torch.testing.assert_close(
-            Y_seq[:, chunked_seqlens[i]:, ...],
-            Y_ref_seq[:, chunked_seqlens[i]:, ...],
+            Y_seq[chunked_seqlens[i]:, ...],
+            Y_ref_seq[chunked_seqlens[i]:, ...],
             atol=atol,
             rtol=rtol,
             msg=lambda x: f"seq{i} output part2 " + x)  # noqa: B023
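For readers following the test changes: all varlen metadata is derived from per-sequence lengths, as the tests above do. A minimal, self-contained sketch of that construction, using the same helper the test file imports (assuming a CUDA device; the two sequence lengths and chunk_size here are illustrative):

import torch

from vllm.v1.attention.backends.mamba2_attn import (
    _query_start_loc_to_chunk_indices_offsets)

# Two sequences of lengths 5 and 3, flattened into T = 8 tokens.
seqlens = torch.tensor([5, 3], device='cuda')
# Cumulative boundaries into the flattened token axis: [0, 5, 8].
cu_seqlens = torch.cat([seqlens.new_zeros(1), seqlens.cumsum(dim=0)])
# seq_idx maps each flattened token to its sequence: [0,0,0,0,0,1,1,1].
seq_idx = torch.repeat_interleave(
    torch.arange(len(seqlens), device='cuda'),
    seqlens, output_size=cu_seqlens[-1]).to(torch.int32)
# Bookkeeping for chunk boundaries that straddle sequence boundaries.
chunk_size = 8
chunk_indices, chunk_offsets = \
    _query_start_loc_to_chunk_indices_offsets(
        cu_seqlens, chunk_size, cu_seqlens[-1])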

vllm/model_executor/layers/mamba/mamba_mixer2.py

Lines changed: 10 additions & 13 deletions
@@ -29,7 +29,7 @@
 from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
     selective_state_update)
 from vllm.model_executor.layers.mamba.ops.ssd_combined import (
-    mamba_chunk_scan_combined)
+    mamba_chunk_scan_combined_varlen)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import (
     LoaderFunction, composed_weight_loader, sharded_weight_loader)
@@ -504,6 +504,7 @@ def forward_cuda(
         seq_idx_p = attn_metadata.seq_idx_p
         chunk_indices_p = attn_metadata.chunk_indices_p
         chunk_offsets_p = attn_metadata.chunk_offsets_p
+        query_start_loc_p = attn_metadata.query_start_loc_p
 
         # 1. Gated MLP's linear projection
         projected_states, _ = self.in_proj(hidden_states)
@@ -545,6 +546,7 @@ def forward_cuda(
             out, _ = self.out_proj(hidden_states)
             return out
 
+        # NOTE: V0 put prefill before decode, v1 puts decode before prefill
         num_prefills = attn_metadata.num_prefills  # request count
         num_decodes = attn_metadata.num_decode_tokens  # token count (=request)
         num_prefill_tokens = attn_metadata.num_prefill_tokens  # token count
@@ -570,9 +572,6 @@ def forward_cuda(
             [num_decodes, num_prefills],
             dim=0,
         )
-        query_start_loc_p = (
-            attn_metadata.query_start_loc[-num_prefills - 1:] -
-            num_decodes if has_prefill else None)
 
         # Preallocate output tensor to avoid memcpy cost for merging prefill
         # and decode outputs
@@ -620,15 +619,15 @@ def forward_cuda(
                     ssm_state[state_indices_tensor_p], 0)
 
             # NOTE: final output is an in-place update of out tensor
-            varlen_state = mamba_chunk_scan_combined(
-                hidden_states_p.view(1, num_prefill_tokens,
+            varlen_states = mamba_chunk_scan_combined_varlen(
+                hidden_states_p.view(num_prefill_tokens,
                                      self.num_heads // self.tp_size,
                                      self.head_dim),
-                dt_p.unsqueeze(0),
+                dt_p,
                 self.A,
-                B_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size,
+                B_p.view(num_prefill_tokens, self.n_groups // self.tp_size,
                          -1),
-                C_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size,
+                C_p.view(num_prefill_tokens, self.n_groups // self.tp_size,
                          -1),
                 chunk_size=chunk_size,
                 D=self.D,
@@ -639,17 +638,15 @@ def forward_cuda(
                 chunk_offsets=chunk_offsets_p,
                 cu_seqlens=query_start_loc_p,
                 initial_states=initial_states,
-                return_varlen_states=True,
-                return_final_states=False,
                 dt_softplus=True,
                 dt_limit=(0.0, float("inf")),
-                out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1,
+                out=preallocated_ssm_out_p.view(num_prefill_tokens, -1,
                                                 self.head_dim),
                 state_dtype=ssm_state.dtype)
 
             # update ssm states
             # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
-            ssm_state[state_indices_tensor_p] = varlen_state
+            ssm_state[state_indices_tensor_p] = varlen_states
 
             # Process decode requests
             if has_decode:
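The removed query_start_loc_p arithmetic deserves a note: in the v1 layout decode tokens precede prefill tokens, so the prefill slice of query_start_loc had to be rebased by the decode-token count before it could serve as cu_seqlens. A small sketch of that now-superseded computation (the mixer now reads it from attn_metadata.query_start_loc_p); the request mix and numbers below are hypothetical:

import torch

# v1 layout: 3 decode requests (1 token each) come before 2 prefill
# requests of 4 and 2 tokens; boundaries are cumulative over all requests.
num_decodes, num_prefills = 3, 2
query_start_loc = torch.tensor([0, 1, 2, 3, 7, 9])

# Keep the prefill boundaries and rebase them to start at token 0:
query_start_loc_p = query_start_loc[-num_prefills - 1:] - num_decodes
assert query_start_loc_p.tolist() == [0, 4, 6]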

vllm/model_executor/layers/mamba/ops/causal_conv1d.py

Lines changed: 1 addition & 1 deletion
@@ -427,7 +427,7 @@ def causal_conv1d_fn(
         batch_ptr = metadata.batch_ptr
         token_chunk_offset_ptr = metadata.token_chunk_offset_ptr
     else:
-        seqlens = np.diff(query_start_loc.to('cpu'))
+        seqlens = query_start_loc.diff().to('cpu')
         args = seqlens
         MAX_NUM_PROGRAMS = 1024
 
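The causal_conv1d change is small but in keeping with the refactor: per-sequence lengths now come from torch.Tensor.diff on query_start_loc (yielding a tensor) rather than np.diff on a CPU copy (yielding a NumPy array). With illustrative values:

import torch

query_start_loc = torch.tensor([0, 5, 8, 14])
seqlens = query_start_loc.diff().to('cpu')  # tensor([5, 3, 6])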