@@ -325,24 +325,29 @@ def steps_spec_dec(self) -> List[Sequence]:
             List[Sequence]: finished sequences generated by one step.
         """
         batch = self.request_handler.schedule()  # prefill batch
-
         assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
-        input_ids = batch.get_1D_inputs()  # bsz 1 for drafter model
+
+        input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
+
+        if input_meta_data.use_cuda_graph:
+            model_executable = self.graph_runners[input_meta_data.batch_size]
+        else:
+            model_executable = self.model

         # 1. Prefill small model (Drafter) - fill past kv cache for drafter model
         # NOTE For glide drafter models, we won't actually apply glide during prefill stage
-        drafter_out = self.drafter.speculate(input_ids, 1, None)
+        drafter_out = self.drafter.speculate(input_token_ids, 1, None)
         next_token_ids_spec = drafter_out.next_tokens
         drafter_past_key_values = drafter_out.past_key_values

         # 2. Prefill main model (Verifier) - fill past kv cache for main model
-        logits = self.model(batch, self.k_cahce, self.v_cache)
+        logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
         next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
         # append new inputs to the batch, temporarily
         batch.append_batch_tokens(next_tokens)
         self.request_handler.allocate_batch_spec_dec(batch, 1)
         already_allocated_kv_len = batch.seq_lengths[0].item()
-        input_ids = batch.get_1D_inputs_spec_dec(1)
+        input_token_ids = batch.get_1D_inputs_spec_dec(1)

         finished_sequences = self.request_handler.update()

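Note on the dispatch introduced above: when a CUDA graph has been captured for the current batch size, the captured runner replaces the eager model for the forward pass. A minimal sketch of that capture-or-fallback pattern, assuming a graph_runners dict keyed by batch size as in the diff (pick_executable is a hypothetical helper, not part of this PR):

    def pick_executable(graph_runners: dict, model, batch_size: int, use_cuda_graph: bool):
        # Prefer the CUDA graph captured for this batch size; otherwise fall back
        # to the eager model (hypothetical helper illustrating the pattern above).
        if use_cuda_graph and batch_size in graph_runners:
            return graph_runners[batch_size]
        return model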
@@ -357,13 +362,13 @@ def steps_spec_dec(self) -> List[Sequence]:
             if self.use_glide:
                 glide_input = GlideInput(
                     batch.get_block_table_tensor(),
-                    self.k_cahce[-1],  # use kv cahces of the last layer
+                    self.k_cache[-1],  # use kv caches of the last layer
                     self.v_cache[-1],
                     batch.get_sequence_lengths(),
                 )

             drafter_out = self.drafter.speculate(
-                input_ids,
+                input_token_ids,
                 self.n_spec_tokens,
                 drafter_past_key_values,
                 glide_input=glide_input,
@@ -382,7 +387,9 @@ def steps_spec_dec(self) -> List[Sequence]:
             # 4. Decoding - Main model verifies `n` tokens in parallel
             if drafter_spec_length < batch.num_tokens_to_verify:
                 batch.set_use_spec_dec(num_tokens_to_verify=drafter_spec_length)
-            logits = self.model(batch, self.k_cahce, self.v_cache)
+            input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
+            logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
+
             next_tokens = self.request_handler.search_tokens(self.generation_config, logits)

             # 5. Compare and process the results
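A generic illustration of the greedy acceptance rule behind step 5, sketched under the assumption that both inputs are 1-D token-id tensors; this mirrors the standard speculative-decoding comparison, not the repository's actual implementation (count_accepted is a hypothetical name):

    import torch

    def count_accepted(drafter_tokens: torch.Tensor, verifier_tokens: torch.Tensor) -> int:
        # Accept drafter tokens up to the first position where the verifier disagrees;
        # the verifier's own token is then used at the first mismatching position.
        n_matches = 0
        for draft, verified in zip(drafter_tokens.tolist(), verifier_tokens.tolist()):
            if draft != verified:
                break
            n_matches += 1
        return n_matches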
@@ -402,7 +409,7 @@ def steps_spec_dec(self) -> List[Sequence]:

             # prepare inputs for the next round of speculation
             n = 1 if n_matches < drafter_spec_length else 2
-            input_ids = batch.get_1D_inputs_spec_dec(n)
+            input_token_ids = batch.get_1D_inputs_spec_dec(n)

             self.request_handler.update_batch_finished(batch, generation_config=self.generation_config)
             finished_sequences = self.request_handler.update()
@@ -564,18 +571,19 @@ def add_request(

     def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor, InputMetaData]:
         input_ids = batch.get_1D_inputs()
-
         sequence_lengths = batch.get_sequence_lengths()
+
         if batch.is_prompts:
-            output_tensor = torch.zeros(
-                (sequence_lengths.sum().item(), batch.num_heads * batch.head_dim),
-                dtype=batch.dtype,
-                device=batch.device,
-            )
+            n_tokens = sequence_lengths.sum().item()
         else:
-            output_tensor = torch.zeros(
-                (batch.current_batch_size, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
-            )
+            n_tokens = batch.current_batch_size
+            if batch.use_spec_dec:
+                n_tokens = batch.num_tokens_to_verify + 1
+                assert n_tokens == input_ids.size(0)
+                n_tokens = n_tokens * batch.current_batch_size
+        output_tensor = torch.zeros(
+            (n_tokens, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
+        )

         # only when we have the graph for specific decoding batch size can we use the cuda graph for inference
         use_cuda_graph = False
@@ -594,6 +602,8 @@ def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor,
             kv_seq_len=sequence_lengths.max().item(),
             head_dim=batch.head_dim,
             dtype=batch.dtype,
+            use_spec_dec=batch.use_spec_dec,
+            num_tokens_to_verify=batch.num_tokens_to_verify,
         )

         return input_ids, output_tensor, input_meta_data
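The refactor above folds the output-tensor sizing into one n_tokens computation: one row per prompt token during prefill, one row per sequence for plain decoding, and (num_tokens_to_verify + 1) * batch_size rows when speculative decoding is enabled. A standalone sketch of that rule, with a hypothetical helper name:

    import torch

    def output_rows(is_prompts: bool, seq_lengths: torch.Tensor, batch_size: int,
                    use_spec_dec: bool, num_tokens_to_verify: int) -> int:
        # Mirrors the n_tokens logic in prepare_input above (hypothetical helper).
        if is_prompts:
            return int(seq_lengths.sum().item())
        if use_spec_dec:
            return (num_tokens_to_verify + 1) * batch_size
        return batch_size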