@@ -31,6 +31,7 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel(
3131 const int * __restrict__ batch_id_per_token, // [num_tokens]
3232 const int * __restrict__ cu_seqlens_q,
3333 const int * __restrict__ seq_lens_decoder, // [bsz]
34+ const int * __restrict__ seq_lens_encoder, // [bsz]
3435 const float * __restrict__ cos_emb,
3536 const float * __restrict__ sin_emb,
3637 const float *
@@ -75,7 +76,7 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel(
7576
7677 const int ori_bi = batch_id_per_token[token_id];
7778 if (ori_bi == -1 ) continue ; // NOTE(gongshaotian): For CUDAGraph padding
78- if (seq_lens_decoder [ori_bi] == 0 ) continue ;
79+ if (seq_lens_encoder [ori_bi] > 0 ) continue ;
7980 const int bias = linear_index % hidden_size;
8081 const int hi = bias / head_size; // q + k + v
8182 const int h_bias = bias % head_size;
@@ -87,7 +88,7 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel(
8788 const int * block_table_now = block_tables + ori_bi * max_blocks_per_seq;
8889 const int block_idx = block_table_now[write_seq_id / block_size];
8990 if (block_idx < 0 ) {
90- return ; // NOTE(gongshaotian): For CUDAGraph padding
91+ continue ; // NOTE(gongshaotian): For CUDAGraph padding
9192 }
9293 const int block_offset = write_seq_id % block_size;
9394
@@ -343,6 +344,7 @@ __global__ void append_speculate_cache_rope_kernel(
343344 const int * __restrict__ batch_id_per_token, // [num_tokens]
344345 const int * __restrict__ cu_seqlens_q,
345346 const int * __restrict__ seq_lens_decoder, // [bsz]
347+ const int * __restrict__ seq_lens_encoder, // [bsz]
346348 const float * __restrict__ cos_emb,
347349 const float * __restrict__ sin_emb,
348350 const float *
@@ -380,7 +382,7 @@ __global__ void append_speculate_cache_rope_kernel(
380382 const int ori_bi = batch_id_per_token[token_id];
381383 if (ori_bi == -1 ) continue ; // NOTE(gongshaotian): For CUDAGraph padding
382384
383- if (seq_lens_decoder [ori_bi] == 0 ) continue ;
385+ if (seq_lens_encoder [ori_bi] > 0 ) continue ;
384386 const int bias = linear_index % hidden_size;
385387 const int hi = bias / head_size; // q + k + v
386388 const int h_bias = bias % head_size;
@@ -392,7 +394,7 @@ __global__ void append_speculate_cache_rope_kernel(
392394 const int * block_table_now = block_tables + ori_bi * max_blocks_per_seq;
393395 const int block_idx = block_table_now[write_seq_id / block_size];
394396 if (block_idx < 0 ) {
395- return ; // NOTE(gongshaotian): For CUDAGraph padding
397+ continue ; // NOTE(gongshaotian): For CUDAGraph padding
396398 }
397399 const int block_offset = write_seq_id % block_size;
398400
@@ -473,6 +475,7 @@ __global__ void append_speculate_cache_neox_rope_kernel(
473475 const int * __restrict__ batch_id_per_token, // [num_tokens]
474476 const int * __restrict__ cu_seqlens_q,
475477 const int * __restrict__ seq_lens_decoder, // [bsz]
478+ const int * __restrict__ seq_lens_encoder, // [bsz]
476479 const float * __restrict__ cos_emb,
477480 const float * __restrict__ sin_emb,
478481 const float *
@@ -509,7 +512,7 @@ __global__ void append_speculate_cache_neox_rope_kernel(
509512 const int token_id = linear_index / half_hidden_size;
510513 const int ori_bi = batch_id_per_token[token_id];
511514 if (ori_bi == -1 ) continue ; // NOTE(gongshaotian): For CUDAGraph padding
512- if (seq_lens_decoder [ori_bi] == 0 ) continue ;
515+ if (seq_lens_encoder [ori_bi] > 0 ) continue ;
513516 const int bias = linear_index % half_hidden_size;
514517 const int hi = bias / half_head_size; // q + k + v
515518 const int h_bias = bias % half_head_size;
@@ -521,7 +524,7 @@ __global__ void append_speculate_cache_neox_rope_kernel(
521524 const int * block_table_now = block_tables + ori_bi * max_blocks_per_seq;
522525 const int block_idx = block_table_now[write_seq_id / block_size];
523526 if (block_idx < 0 ) {
524- return ; // NOTE(gongshaotian): For CUDAGraph padding
527+ continue ; // NOTE(gongshaotian): For CUDAGraph padding
525528 }
526529 const int block_offset = write_seq_id % block_size;
527530
0 commit comments