77from einops import rearrange , repeat
88
99from vllm .model_executor .layers .mamba .ops .ssd_combined import (
10- mamba_chunk_scan_combined )
10+ mamba_chunk_scan_combined_varlen )
1111from vllm .platforms import current_platform
1212from vllm .v1 .attention .backends .mamba2_attn import (
1313 _query_start_loc_to_chunk_indices_offsets )
@@ -185,9 +185,14 @@ def end_boundary(n: int):
185185 IND_S = [x % full_length for x in IND_E ]
186186 IND_E = [end_boundary (x + y ) for x , y in zip (IND_S , spec )]
187187
188+ # varlen has implicit batch=1
189+ dt2 = dt2 .squeeze (0 )
190+ X2 = X2 .squeeze (0 )
191+ B2 = B2 .squeeze (0 )
192+ C2 = C2 .squeeze (0 )
188193 yield ([Y_min [s , IND_S [s ]:IND_E [s ]]
189194 for s in range (num_examples )] if return_naive_ref else None ,
190- cu_seqlens , seq_idx . unsqueeze ( 0 ) , (A , dt2 , X2 , B2 , C2 ))
195+ cu_seqlens , seq_idx , (A , dt2 , X2 , B2 , C2 ))
191196
192197
193198@pytest .mark .parametrize ("itype" ,
@@ -198,7 +203,7 @@ def end_boundary(n: int):
198203def test_mamba_chunk_scan_single_example (d_head , n_heads , seq_len_chunk_size ,
199204 itype ):
200205
201- # this tests the kernels on a single example (no batching )
206+ # this tests the kernels on a single example (bs=1 )
202207
203208 # TODO: the bfloat16 case requires higher thresholds. To be investigated
204209
@@ -219,23 +224,40 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
219224
220225 Y_min , final_state_min = ssd_minimal_discrete (X * dt .unsqueeze (- 1 ), A * dt ,
221226 B , C , chunk_size )
227+
228+ cu_seqlens = torch .tensor ((0 , seqlen ), device = 'cuda' ).cumsum (dim = 0 )
229+ seq_idx = torch .zeros (seqlen , dtype = torch .int32 , device = cu_seqlens .device )
230+
231+ chunk_indices , chunk_offsets = \
232+ _query_start_loc_to_chunk_indices_offsets (
233+ cu_seqlens , chunk_size , cu_seqlens [- 1 ])
234+
235+ # varlen has implicit batch=1
236+ X = X .squeeze (0 )
237+ dt = dt .squeeze (0 )
238+ A = A .squeeze (0 )
239+ B = B .squeeze (0 )
240+ C = C .squeeze (0 )
222241 Y = torch .empty_like (X )
223- final_state = mamba_chunk_scan_combined (X ,
224- dt ,
225- A ,
226- B ,
227- C ,
228- chunk_size ,
229- D = None ,
230- return_final_states = True ,
231- out = Y )
242+ final_state = mamba_chunk_scan_combined_varlen (X ,
243+ dt ,
244+ A ,
245+ B ,
246+ C ,
247+ chunk_size ,
248+ D = None ,
249+ cu_seqlens = cu_seqlens ,
250+ seq_idx = seq_idx ,
251+ chunk_indices = chunk_indices ,
252+ chunk_offsets = chunk_offsets ,
253+ out = Y )
232254
233255 # just test the last in sequence
234- torch .testing .assert_close (Y [:, - 1 ], Y_min [: , - 1 ], atol = atol , rtol = rtol )
256+ torch .testing .assert_close (Y [- 1 ], Y_min [0 , - 1 ], atol = atol , rtol = rtol )
235257
236258 # just test the last head
237259 # NOTE, in the kernel we always cast states to fp32
238- torch .testing .assert_close (final_state [:, - 1 ],
260+ torch .testing .assert_close (final_state [:, - 1 ]. to ( torch . float32 ) ,
239261 final_state_min [:, - 1 ].to (torch .float32 ),
240262 atol = atol ,
241263 rtol = rtol )
@@ -300,7 +322,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
300322 cu_seqlens , chunk_size , cu_seqlens [- 1 ])
301323
302324 Y = torch .empty_like (X )
303- new_states = mamba_chunk_scan_combined (
325+ new_states = mamba_chunk_scan_combined_varlen (
304326 X ,
305327 dt ,
306328 A ,
@@ -312,7 +334,6 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
312334 seq_idx = seq_idx ,
313335 chunk_indices = chunk_indices ,
314336 chunk_offsets = chunk_offsets ,
315- return_varlen_states = True ,
316337 initial_states = states ,
317338 out = Y ,
318339 )
@@ -321,7 +342,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
321342 for i in range (num_examples ):
322343
323344 # just test one dim and dstate
324- Y_eg = Y [0 , cu_seqlens [i ]:cu_seqlens [i + 1 ], 0 , 0 ]
345+ Y_eg = Y [cu_seqlens [i ]:cu_seqlens [i + 1 ], 0 , 0 ]
325346 Y_min_eg = Y_min [i ][:, 0 , 0 ]
326347 torch .testing .assert_close (Y_eg , Y_min_eg , atol = atol , rtol = rtol )
327348
@@ -386,7 +407,7 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
386407 _query_start_loc_to_chunk_indices_offsets (
387408 cu_seqlens , chunk_size , cu_seqlens [- 1 ])
388409 Y_ref = torch .empty_like (X )
389- state_ref = mamba_chunk_scan_combined (
410+ state_ref = mamba_chunk_scan_combined_varlen (
390411 X ,
391412 dt ,
392413 A ,
@@ -398,7 +419,6 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
398419 seq_idx = seq_idx ,
399420 chunk_indices = chunk_indices ,
400421 chunk_offsets = chunk_offsets ,
401- return_varlen_states = True ,
402422 initial_states = None ,
403423 out = Y_ref ,
404424 )
@@ -414,27 +434,27 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
414434 chunked_seq_idx = torch .repeat_interleave (
415435 torch .arange (len (chunked_seqlens ), device = device ),
416436 chunked_seqlens ,
417- output_size = chunked_cu_seqlens [- 1 ]).unsqueeze ( 0 ). to (torch .int32 )
437+ output_size = chunked_cu_seqlens [- 1 ]).to (torch .int32 )
418438 chunked_input_seq_len = chunked_cu_seqlens [- 1 ]
419- X_chunked = torch .zeros_like (X )[:, : chunked_input_seq_len , ...]
420- dt_chunked = torch .zeros_like (dt )[:, : chunked_input_seq_len , ...]
421- B_chunked = torch .zeros_like (B )[:, : chunked_input_seq_len , ...]
422- C_chunked = torch .zeros_like (C )[:, : chunked_input_seq_len , ...]
439+ X_chunked = torch .zeros_like (X )[:chunked_input_seq_len , ...]
440+ dt_chunked = torch .zeros_like (dt )[:chunked_input_seq_len , ...]
441+ B_chunked = torch .zeros_like (B )[:chunked_input_seq_len , ...]
442+ C_chunked = torch .zeros_like (C )[:chunked_input_seq_len , ...]
423443 for i in range (num_sequences ):
424444 # fmt: off
425- chunk_f = lambda x , i : x [:, cu_seqlens [i ]:cu_seqlens [i ] + chunked_seqlens [i ], ...] # noqa: E501
445+ chunk_f = lambda x , i : x [cu_seqlens [i ]:cu_seqlens [i ] + chunked_seqlens [i ], ...] # noqa: E501
426446
427- X_chunked [:, chunked_cu_seqlens [i ]:chunked_cu_seqlens [i + 1 ], ...] = chunk_f (X , i ) # noqa: E501
428- dt_chunked [:, chunked_cu_seqlens [i ]:chunked_cu_seqlens [i + 1 ], ...] = chunk_f (dt , i ) # noqa: E501
429- B_chunked [:, chunked_cu_seqlens [i ]:chunked_cu_seqlens [i + 1 ], ...] = chunk_f (B , i ) # noqa: E501
430- C_chunked [:, chunked_cu_seqlens [i ]:chunked_cu_seqlens [i + 1 ], ...] = chunk_f (C , i ) # noqa: E501
447+ X_chunked [chunked_cu_seqlens [i ]:chunked_cu_seqlens [i + 1 ], ...] = chunk_f (X , i ) # noqa: E501
448+ dt_chunked [chunked_cu_seqlens [i ]:chunked_cu_seqlens [i + 1 ], ...] = chunk_f (dt , i ) # noqa: E501
449+ B_chunked [chunked_cu_seqlens [i ]:chunked_cu_seqlens [i + 1 ], ...] = chunk_f (B , i ) # noqa: E501
450+ C_chunked [chunked_cu_seqlens [i ]:chunked_cu_seqlens [i + 1 ], ...] = chunk_f (C , i ) # noqa: E501
431451 # fmt: on
432452
433453 chunk_indices , chunk_offsets = \
434454 _query_start_loc_to_chunk_indices_offsets (
435455 chunked_cu_seqlens , chunk_size , chunked_cu_seqlens [- 1 ])
436456 Y_partial = torch .empty_like (X_chunked )
437- partial_state = mamba_chunk_scan_combined (
457+ partial_state = mamba_chunk_scan_combined_varlen (
438458 X_chunked ,
439459 dt_chunked ,
440460 A ,
@@ -446,7 +466,6 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
446466 seq_idx = chunked_seq_idx ,
447467 chunk_indices = chunk_indices ,
448468 chunk_offsets = chunk_offsets ,
449- return_varlen_states = True ,
450469 initial_states = None ,
451470 out = Y_partial ,
452471 )
@@ -461,29 +480,28 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
461480 remaining_chunked_seq_idx = torch .repeat_interleave (
462481 torch .arange (len (remaining_chunked_seqlens ), device = device ),
463482 remaining_chunked_seqlens ,
464- output_size = remaining_chunked_cu_seqlens [- 1 ]).unsqueeze (0 ).to (
465- torch .int32 )
483+ output_size = remaining_chunked_cu_seqlens [- 1 ]).to (torch .int32 )
466484 remaining_chunked_input_seq_len = remaining_chunked_cu_seqlens [- 1 ]
467485 # fmt: off
468- remaining_X_chunked = torch .zeros_like (X )[:, : remaining_chunked_input_seq_len , ...] # noqa: E501
469- remaining_dt_chunked = torch .zeros_like (dt )[:, : remaining_chunked_input_seq_len , ...] # noqa: E501
470- remaining_B_chunked = torch .zeros_like (B )[:, : remaining_chunked_input_seq_len , ...] # noqa: E501
471- remaining_C_chunked = torch .zeros_like (C )[:, : remaining_chunked_input_seq_len , ...] # noqa: E501
486+ remaining_X_chunked = torch .zeros_like (X )[:remaining_chunked_input_seq_len , ...] # noqa: E501
487+ remaining_dt_chunked = torch .zeros_like (dt )[:remaining_chunked_input_seq_len , ...] # noqa: E501
488+ remaining_B_chunked = torch .zeros_like (B )[:remaining_chunked_input_seq_len , ...] # noqa: E501
489+ remaining_C_chunked = torch .zeros_like (C )[:remaining_chunked_input_seq_len , ...] # noqa: E501
472490 for i in range (num_sequences ):
473- remaining_chunk_f = lambda x , i : x [:, cu_seqlens [i ] + chunked_seqlens [i ]:cu_seqlens [i + 1 ], ...] # noqa: E501
491+ remaining_chunk_f = lambda x , i : x [cu_seqlens [i ] + chunked_seqlens [i ]:cu_seqlens [i + 1 ], ...] # noqa: E501
474492
475- remaining_X_chunked [:, remaining_chunked_cu_seqlens [i ]:remaining_chunked_cu_seqlens [i + 1 ], ...] = remaining_chunk_f (X , i ) # noqa: E501
476- remaining_dt_chunked [:, remaining_chunked_cu_seqlens [i ]:remaining_chunked_cu_seqlens [i + 1 ], ...] = remaining_chunk_f (dt , i ) # noqa: E501
477- remaining_B_chunked [:, remaining_chunked_cu_seqlens [i ]:remaining_chunked_cu_seqlens [i + 1 ], ...] = remaining_chunk_f (B , i ) # noqa: E501
478- remaining_C_chunked [:, remaining_chunked_cu_seqlens [i ]:remaining_chunked_cu_seqlens [i + 1 ], ...] = remaining_chunk_f (C , i ) # noqa: E501
493+ remaining_X_chunked [remaining_chunked_cu_seqlens [i ]:remaining_chunked_cu_seqlens [i + 1 ], ...] = remaining_chunk_f (X , i ) # noqa: E501
494+ remaining_dt_chunked [remaining_chunked_cu_seqlens [i ]:remaining_chunked_cu_seqlens [i + 1 ], ...] = remaining_chunk_f (dt , i ) # noqa: E501
495+ remaining_B_chunked [remaining_chunked_cu_seqlens [i ]:remaining_chunked_cu_seqlens [i + 1 ], ...] = remaining_chunk_f (B , i ) # noqa: E501
496+ remaining_C_chunked [remaining_chunked_cu_seqlens [i ]:remaining_chunked_cu_seqlens [i + 1 ], ...] = remaining_chunk_f (C , i ) # noqa: E501
479497
480498 # assert input chunking is correct
481499 concat_chunk_f = lambda pt1 , pt2 , i : torch .cat ([
482- pt1 [:, chunked_cu_seqlens [i ]:chunked_cu_seqlens [i + 1 ],...],
483- pt2 [:, remaining_chunked_cu_seqlens [i ]:remaining_chunked_cu_seqlens [i + 1 ],...],
500+ pt1 [chunked_cu_seqlens [i ]:chunked_cu_seqlens [i + 1 ],...],
501+ pt2 [remaining_chunked_cu_seqlens [i ]:remaining_chunked_cu_seqlens [i + 1 ],...],
484502 ],
485- dim = 1 )
486- concat_batch_f = lambda pt1 , pt2 : torch .cat ([concat_chunk_f (pt1 , pt2 , i ) for i in range (num_sequences )], dim = 1 ) # noqa: E501
503+ dim = 0 )
504+ concat_batch_f = lambda pt1 , pt2 : torch .cat ([concat_chunk_f (pt1 , pt2 , i ) for i in range (num_sequences )], dim = 0 ) # noqa: E501
487505 # fmt: on
488506
489507 assert concat_batch_f (X_chunked , remaining_X_chunked ).equal (X )
@@ -498,7 +516,7 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
498516 remaining_chunked_cu_seqlens [- 1 ])
499517
500518 Y_chunked = torch .empty_like (remaining_X_chunked )
501- state_chunked = mamba_chunk_scan_combined (
519+ state_chunked = mamba_chunk_scan_combined_varlen (
502520 remaining_X_chunked ,
503521 remaining_dt_chunked ,
504522 A ,
@@ -510,25 +528,24 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
510528 seq_idx = remaining_chunked_seq_idx ,
511529 chunk_indices = chunk_indices ,
512530 chunk_offsets = chunk_offsets ,
513- return_varlen_states = True ,
514531 initial_states = partial_state ,
515532 out = Y_chunked ,
516533 )
517534 Y = concat_batch_f (Y_partial , Y_chunked )
518535
519536 # kernel chunked is same as kernel overall
520537 for i in range (num_sequences ):
521- Y_seq = Y [:, cu_seqlens [i ]:cu_seqlens [i + 1 ], ...]
522- Y_ref_seq = Y_ref [:, cu_seqlens [i ]:cu_seqlens [i + 1 ], ...]
538+ Y_seq = Y [cu_seqlens [i ]:cu_seqlens [i + 1 ], ...]
539+ Y_ref_seq = Y_ref [cu_seqlens [i ]:cu_seqlens [i + 1 ], ...]
523540 torch .testing .assert_close (
524- Y_seq [:, : chunked_seqlens [i ], ...],
525- Y_ref_seq [:, : chunked_seqlens [i ], ...],
541+ Y_seq [:chunked_seqlens [i ], ...],
542+ Y_ref_seq [:chunked_seqlens [i ], ...],
526543 atol = atol ,
527544 rtol = rtol ,
528545 msg = lambda x : f"seq{ i } output part1 " + x ) # noqa: B023
529546 torch .testing .assert_close (
530- Y_seq [:, chunked_seqlens [i ]:, ...],
531- Y_ref_seq [:, chunked_seqlens [i ]:, ...],
547+ Y_seq [chunked_seqlens [i ]:, ...],
548+ Y_ref_seq [chunked_seqlens [i ]:, ...],
532549 atol = atol ,
533550 rtol = rtol ,
534551 msg = lambda x : f"seq{ i } output part2 " + x ) # noqa: B023
0 commit comments