Skip to content

Commit 4bd0b84

Browse files
committed
update
1 parent 1b2d78b commit 4bd0b84

File tree

3 files changed

+61
-40
lines changed

3 files changed

+61
-40
lines changed

python/sglang/srt/model_executor/cuda_graph_runner.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -462,8 +462,11 @@ def get_spec_info(self, num_tokens: int):
462462
),
463463
positions=None,
464464
retrive_index=None,
465+
retrive_next_token=None,
466+
retrive_next_sibling=None,
465467
retrive_cum_len=None,
466468
draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
469+
spec_steps=self.model_runner.server_args.speculative_num_steps,
467470
capture_hidden_mode=CaptureHiddenMode.FULL,
468471
)
469472

python/sglang/srt/speculative/eagle_utils.py

Lines changed: 57 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -168,13 +168,12 @@ class EagleVerifyInput:
168168
custom_mask: torch.Tensor
169169
positions: torch.Tensor
170170
retrive_index: torch.Tensor
171+
retrive_next_token: torch.Tensor
172+
retrive_next_sibling: torch.Tensor
171173
retrive_cum_len: torch.Tensor
172174
draft_token_num: int
175+
spec_steps: int
173176
capture_hidden_mode: CaptureHiddenMode
174-
spec_steps: int = 0
175-
retrive_next_token: torch.Tensor = None
176-
retrive_next_sibling: torch.Tensor = None
177-
non_greedy_retrive_index: torch.Tensor = None
178177

179178
@classmethod
180179
def create(
@@ -187,23 +186,45 @@ def create(
187186
seq_lens_sum: int,
188187
topk: int,
189188
spec_steps: int,
190-
num_verify_token: int,
189+
num_verify_tokens: int,
190+
is_all_greedy: bool,
191191
):
192-
tree_mask, position, retrive_index, retrive_cum_len, draft_tokens = (
193-
build_tree_kernel(
194-
verified_id,
195-
score_list,
196-
token_list,
197-
parents_list,
198-
seq_lens,
199-
seq_lens_sum,
200-
topk,
192+
if is_all_greedy:
193+
tree_mask, position, retrive_index, retrive_cum_len, draft_tokens = (
194+
build_tree_kernel(
195+
verified_id,
196+
score_list, # b, n, topk; n= 1 + (num_steps-1) * self.topk
197+
token_list,
198+
parents_list,
199+
seq_lens,
200+
seq_lens_sum,
201+
topk,
202+
spec_steps,
203+
num_verify_tokens,
204+
)
205+
)
206+
207+
return cls(
208+
draft_tokens,
209+
tree_mask,
210+
position,
211+
retrive_index,
212+
None,
213+
None,
214+
retrive_cum_len,
215+
num_verify_tokens,
201216
spec_steps,
202-
num_verify_token,
217+
CaptureHiddenMode.FULL,
203218
)
204-
)
205-
_, _, non_greedy_retrive_index, retrive_next_token, retrive_next_sibling, _ = (
206-
build_tree_kernel_efficient(
219+
else:
220+
(
221+
tree_mask,
222+
position,
223+
retrive_index,
224+
retrive_next_token,
225+
retrive_next_sibling,
226+
draft_tokens,
227+
) = build_tree_kernel_efficient(
207228
verified_id,
208229
score_list,
209230
token_list,
@@ -212,22 +233,21 @@ def create(
212233
seq_lens_sum,
213234
topk,
214235
spec_steps,
215-
num_verify_token,
236+
num_verify_tokens,
237+
)
238+
239+
return cls(
240+
draft_tokens,
241+
tree_mask,
242+
position,
243+
retrive_index,
244+
retrive_next_token,
245+
retrive_next_sibling,
246+
None,
247+
num_verify_tokens,
248+
spec_steps,
249+
CaptureHiddenMode.FULL,
216250
)
217-
)
218-
return cls(
219-
draft_tokens,
220-
tree_mask,
221-
position,
222-
retrive_index,
223-
retrive_cum_len,
224-
num_verify_token,
225-
CaptureHiddenMode.FULL,
226-
spec_steps,
227-
retrive_next_token,
228-
retrive_next_sibling,
229-
non_greedy_retrive_index,
230-
)
231251

232252
def prepare_for_verify(self, batch: ScheduleBatch):
233253
batch.input_ids = self.draft_token
@@ -283,9 +303,9 @@ def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Ten
283303
[self.draft_token, torch.full([1], -1, dtype=torch.int32, device="cuda")],
284304
dim=-1,
285305
)
306+
candidates = draft_token[self.retrive_index]
286307
if batch.sampling_info.is_all_greedy:
287308
# temp == 0
288-
candidates = draft_token[self.retrive_index]
289309
bs = self.retrive_cum_len.numel() - 1
290310
predict = torch.argmax(logits_output.next_token_logits, dim=-1)
291311
predict = torch.cat(
@@ -316,13 +336,10 @@ def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Ten
316336
)
317337
else:
318338
# temp > 0
319-
candidates = draft_token[self.non_greedy_retrive_index]
320-
bs = self.non_greedy_retrive_index.shape[0]
339+
bs = self.retrive_index.shape[0]
321340
predict_shape = list(logits_output.next_token_logits.shape)[:-1]
322341
predict_shape[-1] += 1
323-
target_logits = logits_output.next_token_logits[
324-
self.non_greedy_retrive_index
325-
]
342+
target_logits = logits_output.next_token_logits[self.retrive_index]
326343
predict = torch.full(predict_shape, -1, dtype=torch.int32, device="cuda")
327344
accept_index = torch.full(
328345
(bs, self.spec_steps + 1), -1, dtype=torch.int32, device="cuda"
@@ -339,7 +356,7 @@ def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Ten
339356
accept_index=accept_index, # mutable
340357
accept_token_num=accept_length, # mutable
341358
candidates=candidates.to(torch.int32),
342-
retrive_index=self.non_greedy_retrive_index.to(torch.int32),
359+
retrive_index=self.retrive_index.to(torch.int32),
343360
retrive_next_token=self.retrive_next_token.to(torch.int32),
344361
retrive_next_sibling=self.retrive_next_sibling.to(torch.int32),
345362
uniform_samples=coins,

python/sglang/srt/speculative/eagle_worker.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ def draft(self, batch: ScheduleBatch):
185185
self.topk,
186186
self.speculative_num_steps,
187187
self.server_args.speculative_num_draft_tokens,
188+
batch.sampling_info.is_all_greedy,
188189
)
189190

190191
# Free cache locations

0 commit comments

Comments (0)