@@ -168,13 +168,12 @@ class EagleVerifyInput:
168168 custom_mask : torch .Tensor
169169 positions : torch .Tensor
170170 retrive_index : torch .Tensor
171+ retrive_next_token : torch .Tensor
172+ retrive_next_sibling : torch .Tensor
171173 retrive_cum_len : torch .Tensor
172174 draft_token_num : int
175+ spec_steps : int
173176 capture_hidden_mode : CaptureHiddenMode
174- spec_steps : int = 0
175- retrive_next_token : torch .Tensor = None
176- retrive_next_sibling : torch .Tensor = None
177- non_greedy_retrive_index : torch .Tensor = None
178177
179178 @classmethod
180179 def create (
@@ -187,23 +186,45 @@ def create(
187186 seq_lens_sum : int ,
188187 topk : int ,
189188 spec_steps : int ,
190- num_verify_token : int ,
189+ num_verify_tokens : int ,
190+ is_all_greedy : bool ,
191191 ):
192- tree_mask , position , retrive_index , retrive_cum_len , draft_tokens = (
193- build_tree_kernel (
194- verified_id ,
195- score_list ,
196- token_list ,
197- parents_list ,
198- seq_lens ,
199- seq_lens_sum ,
200- topk ,
192+ if is_all_greedy :
193+ tree_mask , position , retrive_index , retrive_cum_len , draft_tokens = (
194+ build_tree_kernel (
195+ verified_id ,
196+ score_list , # b, n, topk; n= 1 + (num_steps-1) * self.topk
197+ token_list ,
198+ parents_list ,
199+ seq_lens ,
200+ seq_lens_sum ,
201+ topk ,
202+ spec_steps ,
203+ num_verify_tokens ,
204+ )
205+ )
206+
207+ return cls (
208+ draft_tokens ,
209+ tree_mask ,
210+ position ,
211+ retrive_index ,
212+ None ,
213+ None ,
214+ retrive_cum_len ,
215+ num_verify_tokens ,
201216 spec_steps ,
202- num_verify_token ,
217+ CaptureHiddenMode . FULL ,
203218 )
204- )
205- _ , _ , non_greedy_retrive_index , retrive_next_token , retrive_next_sibling , _ = (
206- build_tree_kernel_efficient (
219+ else :
220+ (
221+ tree_mask ,
222+ position ,
223+ retrive_index ,
224+ retrive_next_token ,
225+ retrive_next_sibling ,
226+ draft_tokens ,
227+ ) = build_tree_kernel_efficient (
207228 verified_id ,
208229 score_list ,
209230 token_list ,
@@ -212,22 +233,21 @@ def create(
212233 seq_lens_sum ,
213234 topk ,
214235 spec_steps ,
215- num_verify_token ,
236+ num_verify_tokens ,
237+ )
238+
239+ return cls (
240+ draft_tokens ,
241+ tree_mask ,
242+ position ,
243+ retrive_index ,
244+ retrive_next_token ,
245+ retrive_next_sibling ,
246+ None ,
247+ num_verify_tokens ,
248+ spec_steps ,
249+ CaptureHiddenMode .FULL ,
216250 )
217- )
218- return cls (
219- draft_tokens ,
220- tree_mask ,
221- position ,
222- retrive_index ,
223- retrive_cum_len ,
224- num_verify_token ,
225- CaptureHiddenMode .FULL ,
226- spec_steps ,
227- retrive_next_token ,
228- retrive_next_sibling ,
229- non_greedy_retrive_index ,
230- )
231251
232252 def prepare_for_verify (self , batch : ScheduleBatch ):
233253 batch .input_ids = self .draft_token
@@ -283,9 +303,9 @@ def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Ten
283303 [self .draft_token , torch .full ([1 ], - 1 , dtype = torch .int32 , device = "cuda" )],
284304 dim = - 1 ,
285305 )
306+ candidates = draft_token [self .retrive_index ]
286307 if batch .sampling_info .is_all_greedy :
287308 # temp == 0
288- candidates = draft_token [self .retrive_index ]
289309 bs = self .retrive_cum_len .numel () - 1
290310 predict = torch .argmax (logits_output .next_token_logits , dim = - 1 )
291311 predict = torch .cat (
@@ -316,13 +336,10 @@ def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Ten
316336 )
317337 else :
318338 # temp > 0
319- candidates = draft_token [self .non_greedy_retrive_index ]
320- bs = self .non_greedy_retrive_index .shape [0 ]
339+ bs = self .retrive_index .shape [0 ]
321340 predict_shape = list (logits_output .next_token_logits .shape )[:- 1 ]
322341 predict_shape [- 1 ] += 1
323- target_logits = logits_output .next_token_logits [
324- self .non_greedy_retrive_index
325- ]
342+ target_logits = logits_output .next_token_logits [self .retrive_index ]
326343 predict = torch .full (predict_shape , - 1 , dtype = torch .int32 , device = "cuda" )
327344 accept_index = torch .full (
328345 (bs , self .spec_steps + 1 ), - 1 , dtype = torch .int32 , device = "cuda"
@@ -339,7 +356,7 @@ def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Ten
339356 accept_index = accept_index , # mutable
340357 accept_token_num = accept_length , # mutable
341358 candidates = candidates .to (torch .int32 ),
342- retrive_index = self .non_greedy_retrive_index .to (torch .int32 ),
359+ retrive_index = self .retrive_index .to (torch .int32 ),
343360 retrive_next_token = self .retrive_next_token .to (torch .int32 ),
344361 retrive_next_sibling = self .retrive_next_sibling .to (torch .int32 ),
345362 uniform_samples = coins ,
0 commit comments