@@ -47,9 +47,9 @@ class GDNAttentionMetadata:
4747 None # shape: [batch - num_spec_decodes,]
4848 )
4949 spec_sequence_masks : torch .Tensor | None = None # shape: [batch,]
50- spec_token_masks : torch .Tensor | None = (
51- None # shape: [num_prefill_tokens + num_decode_tokens,]
52- )
50+ spec_token_indx : torch .Tensor | None = None
51+ non_spec_token_indx : torch . Tensor | None = None
52+
5353 num_accepted_tokens : torch .Tensor | None = None # shape: [batch,]
5454
5555 # The following attributes are for triton implementation of causal_conv1d
@@ -105,9 +105,14 @@ def __init__(
105105 dtype = torch .bool ,
106106 device = device ,
107107 )
108- self .spec_token_masks = torch .empty (
108+ self .spec_token_indx = torch .empty (
109109 (self .decode_cudagraph_max_bs * (self .num_spec + 1 ),),
110- dtype = torch .bool ,
110+ dtype = torch .int32 ,
111+ device = device ,
112+ )
113+ self .non_spec_token_indx = torch .empty (
114+ (self .decode_cudagraph_max_bs * (self .num_spec + 1 ),),
115+ dtype = torch .int32 ,
111116 device = device ,
112117 )
113118 self .spec_query_start_loc = torch .empty (
@@ -166,7 +171,8 @@ def build( # type: ignore[override]
166171 split_decodes_and_prefills (m , decode_threshold = 1 )
167172 )
168173 num_spec_decode_tokens = 0
169- spec_token_masks = None
174+ spec_token_indx = None
175+ non_spec_token_indx = None
170176 spec_state_indices_tensor = None
171177 non_spec_state_indices_tensor = m .block_table_tensor [:, 0 ]
172178 spec_query_start_loc = None
@@ -180,18 +186,23 @@ def build( # type: ignore[override]
180186 num_prefills = non_spec_query_lens .size (0 ) - num_decodes
181187 num_decode_tokens = num_decodes
182188 num_prefill_tokens = non_spec_query_lens .sum ().item () - num_decode_tokens
189+ num_spec_decode_tokens = (
190+ query_lens .sum ().item () - num_prefill_tokens - num_decode_tokens
191+ )
183192
184193 if num_prefills == 0 and num_decodes == 0 :
185- spec_token_masks = torch .ones (
186- (
187- min (
188- num_spec_decodes * (self .num_spec + 1 ),
189- query_start_loc [- 1 ].item (),
190- )
191- ),
192- dtype = torch .bool ,
194+ spec_token_size = min (
195+ num_spec_decodes * (self .num_spec + 1 ),
196+ query_start_loc [- 1 ].item (),
197+ )
198+ spec_token_indx = torch .arange (
199+ spec_token_size ,
200+ dtype = torch .int32 ,
193201 device = query_start_loc .device ,
194202 )
203+ non_spec_token_indx = torch .empty (
204+ 0 , dtype = torch .int32 , device = query_start_loc .device
205+ )
195206 spec_state_indices_tensor = m .block_table_tensor [:, : self .num_spec + 1 ]
196207 non_spec_state_indices_tensor = None
197208 spec_query_start_loc = query_start_loc
@@ -200,6 +211,11 @@ def build( # type: ignore[override]
200211 spec_token_masks = torch .repeat_interleave (
201212 spec_sequence_masks , query_lens
202213 )
214+ index = torch .argsort (spec_token_masks )
215+ num_non_spec_tokens = num_prefill_tokens + num_decode_tokens
216+ non_spec_token_indx = index [:num_non_spec_tokens ]
217+ spec_token_indx = index [num_non_spec_tokens :]
218+
203219 spec_state_indices_tensor = m .block_table_tensor [
204220 spec_sequence_masks , : self .num_spec + 1
205221 ]
@@ -226,9 +242,6 @@ def build( # type: ignore[override]
226242 out = non_spec_query_start_loc [1 :],
227243 )
228244
229- num_spec_decode_tokens = (
230- query_lens .sum ().item () - num_prefill_tokens - num_decode_tokens
231- )
232245 assert num_accepted_tokens is not None
233246 num_accepted_tokens = num_accepted_tokens [spec_sequence_masks ]
234247
@@ -274,12 +287,18 @@ def build( # type: ignore[override]
274287 spec_sequence_masks = self .spec_sequence_masks [:batch_size ]
275288 spec_sequence_masks [num_spec_decodes :].fill_ (False )
276289
277- assert spec_token_masks is not None
278- self .spec_token_masks [: spec_token_masks .size (0 )].copy_ (
279- spec_token_masks , non_blocking = True
290+ assert non_spec_token_indx is not None and spec_token_indx is not None
291+ self .non_spec_token_indx [: non_spec_token_indx .size (0 )].copy_ (
292+ non_spec_token_indx , non_blocking = True
293+ )
294+ non_spec_token_indx = self .non_spec_token_indx [
295+ : non_spec_token_indx .size (0 )
296+ ]
297+
298+ self .spec_token_indx [: spec_token_indx .size (0 )].copy_ (
299+ spec_token_indx , non_blocking = True
280300 )
281- spec_token_masks = self .spec_token_masks [:num_actual_tokens ]
282- spec_token_masks [spec_token_masks .size (0 ) :].fill_ (False )
301+ spec_token_indx = self .spec_token_indx [: spec_token_indx .size (0 )]
283302
284303 self .spec_query_start_loc [: num_spec_decodes + 1 ].copy_ (
285304 spec_query_start_loc , non_blocking = True
@@ -332,7 +351,8 @@ def build( # type: ignore[override]
332351 spec_state_indices_tensor = spec_state_indices_tensor ,
333352 non_spec_state_indices_tensor = non_spec_state_indices_tensor ,
334353 spec_sequence_masks = spec_sequence_masks ,
335- spec_token_masks = spec_token_masks ,
354+ spec_token_indx = spec_token_indx ,
355+ non_spec_token_indx = non_spec_token_indx ,
336356 num_accepted_tokens = num_accepted_tokens ,
337357 nums_dict = nums_dict ,
338358 batch_ptr = batch_ptr ,
0 commit comments