@@ -337,7 +337,7 @@ def __init__(
337 337
338 338         # Hidden states from target model to pass to proposer
339 339         # in the subsequent step.
340-            self.previous_hidden_states: Optional[HiddenStates] = None
340+            self.previous_hidden_states: Dict[int, Optional[HiddenStates]] = {}
341 341         self._disable_logprobs = disable_logprobs
342 342         self._disable_log_stats = disable_log_stats
343 343         self._num_spec_prefill_steps = num_spec_prefill_steps
@@ -374,11 +374,13 @@ def init_device(self) -> None:
374 374             self.proposer_worker.maybe_load_lm_head_weight(
375 375                 target_lm_head_weight)
376 376
377-        self._metrics.init_tensors(self.rank, device_type=self.device)
378 377         if model_parallel_is_initialized():
378+            self._metrics.init_tensors(get_tp_group().rank_in_group,
379+                                       device_type=self.device)
379 380             self.spec_decode_sampler.init_tensors(get_tp_group().local_rank,
380 381                                                   device_type=self.device)
381 382         else:
383+            self._metrics.init_tensors(self.rank, device_type=self.device)
382 384             self.spec_decode_sampler.init_tensors(self.rank,
383 385                                                   device_type=self.device)
384 386
@@ -467,7 +469,9 @@ def execute_model(
467 469     ) -> List[SamplerOutput]:
468 470         """Perform speculative decoding on the input batch.
469 471         """
470-        if self.rank % self.tensor_parallel_size != self._driver_rank:
472+        rank = get_tp_group().rank_in_group if model_parallel_is_initialized(
473+        ) else self.rank
474+        if rank != self._driver_rank:
471 475             self._run_non_driver_rank()
472 476             return []
473 477
@@ -721,14 +725,19 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
721 725                 hidden_states = hidden_states[
722 726                     torch.where(sampler_output.sampled_token_ids -
723 727                                 VLLM_INVALID_TOKEN_ID)[0]]
724-            if self.previous_hidden_states is None and len(
725-                    seq_group_meta_with_hidden):
726-                self.previous_hidden_states = HiddenStates(
727-                    hidden_states, seq_group_meta_with_hidden)
728-            elif self.previous_hidden_states and len(
729-                    seq_group_meta_with_hidden):
730-                self.previous_hidden_states.update(hidden_states,
731-                                                   seq_group_meta_with_hidden)
728+            if execute_model_req.virtual_engine not in \
729+                    self.previous_hidden_states and \
730+                    len(seq_group_meta_with_hidden):
731+                self.previous_hidden_states[
732+                    execute_model_req.virtual_engine] = HiddenStates(
733+                        hidden_states, seq_group_meta_with_hidden)
734+            elif execute_model_req.virtual_engine in \
735+                    self.previous_hidden_states and \
736+                    len(seq_group_meta_with_hidden):
737+                previous_hidden_states: HiddenStates = \
738+                    self.previous_hidden_states[execute_model_req.virtual_engine]
739+                previous_hidden_states.update(hidden_states,
740+                                              seq_group_meta_with_hidden)
732 741
733 742         if not skip_proposer:
734 743             # We prepare the prefill hidden states here so that there no
@@ -804,17 +813,15 @@ def _run_speculative_decoding_step(
804 813         Returns a list of SamplerOutput, each containing a single token per
805 814         sequence.
806 815         """
807-        if self.previous_hidden_states is not None:
808-            self.previous_hidden_states.seq_group_metadata_list = execute_model_req.seq_group_metadata_list
809 816         if get_pp_group().is_first_rank:
810 817             # With prefill chunking, expect requests to have prompts first
811 818             # so that backend gets prefill|decode.
812 819             assert num_lookahead_slots == execute_model_req.num_lookahead_slots
813 820
814 821             # Pass last hidden states from target model to proposer
815 822             execute_model_req.previous_hidden_states = \
816-                self.previous_hidden_states
817-            self.previous_hidden_states = None
823+                self.previous_hidden_states[execute_model_req.virtual_engine]
824+            self.previous_hidden_states.pop(execute_model_req.virtual_engine)
818 825
819 826         with Timer() as proposal_timer:
820 827             # Generate proposals using draft worker.
@@ -883,8 +890,8 @@ def _run_speculative_decoding_step(
883 890
884 891         with Timer() as verification_timer:
885 892             accepted_token_ids, target_logprobs = self._verify_tokens(
886-                execute_model_req.seq_group_metadata_list, proposal_scores,
887-                proposals, execute_model_req.num_lookahead_slots)
893+                execute_model_req, proposal_scores, proposals,
894+                execute_model_req.num_lookahead_slots)
888 895
889 896         stage_times = (proposal_execute_time, scoring_timer.elapsed_time_ms,
890 897                        verification_timer.elapsed_time_ms)
@@ -901,7 +908,7 @@ def _run_speculative_decoding_step(
901 908     @nvtx_range("spec_decode_worker._verify_tokens")
902 909     def _verify_tokens(
903 910         self,
904-        seq_group_metadata_list: List[SequenceGroupMetadata],
911+        execute_model_req: ExecuteModelRequest,
905 912         proposal_scores: SpeculativeScores,
906 913         proposals: SpeculativeProposals,
907 914         max_proposal_len: int,
@@ -912,6 +919,7 @@ def _verify_tokens(
912 919         Returns a tuple of Tensors, one for the accepted token ids and one for
913 920         the logprobs according to the scoring model.
914 921         """
922+        seq_group_metadata_list = execute_model_req.seq_group_metadata_list
915 923         proposal_lens_list = proposals.proposal_lens.tolist()
916 924
917 925         # vLLM currently only supports proposal lens equal to zero or the batch
@@ -991,9 +999,10 @@ def _verify_tokens(
991 999             second_last_token_hidden_states = hidden_states[:, -2]  # b x d
992 1000            hidden_states = hidden_states.gather(1, index).squeeze(1)  # b x d
993 1001            # Store hidden states from target model for subsequent decode step
994-            self.previous_hidden_states = HiddenStates(
995-                hidden_states, terminal_metadata,
996-                second_last_token_hidden_states)
1002+            self.previous_hidden_states[
1003+                execute_model_req.virtual_engine] = HiddenStates(
1004+                    hidden_states, terminal_metadata,
1005+                    second_last_token_hidden_states)
997 1006        return accepted_token_ids, logprobs
998 1007
999 1008    def _create_output_sampler_list(
0 commit comments