@@ -775,10 +775,10 @@ def execute_model(
775
775
sampling_metadata = sampling_metadata ,
776
776
)
777
777
778
- sampled_token_ids = sampler_output .sampled_token_ids
779
778
# TODO(woosuk): The following loop can be slow since it iterates over
780
779
# the requests one by one. Optimize.
781
780
num_reqs = self .input_batch .num_reqs
781
+ request_seq_lens : List [Tuple [int , CachedRequestState , int ]] = []
782
782
for i , req_id in enumerate (self .input_batch .req_ids [:num_reqs ]):
783
783
assert req_id is not None
784
784
req_state = self .requests [req_id ]
@@ -787,10 +787,10 @@ def execute_model(
787
787
assert seq_len <= req_state .num_tokens
788
788
if seq_len == req_state .num_tokens :
789
789
# Append the sampled token to the output token ids.
790
- token_id = sampled_token_ids [i ]
791
- self .input_batch .token_ids_cpu [i , seq_len ] = token_id
792
790
self .input_batch .num_tokens [i ] += 1
793
- req_state .output_token_ids .append (token_id )
791
+ # OPTIMIZATION: Priming the state updates for later updates.
792
+ req_state .output_token_ids .append (0 )
793
+ request_seq_lens .append ((i , req_state , seq_len ))
794
794
else :
795
795
# Ignore the sampled token from the partial request.
796
796
# Rewind the generator state as if the token was not sampled.
@@ -799,6 +799,21 @@ def execute_model(
799
799
# This relies on cuda-specific torch-internal impl details
800
800
generator .set_offset (generator .get_offset () - 4 )
801
801
802
+ # num_reqs entries should be non-None
803
+ assert all (
804
+ req_id is not None for req_id in
805
+ self .input_batch .req_ids [:num_reqs ]), "req_ids contains None"
806
+ req_ids = cast (List [str ], self .input_batch .req_ids [:num_reqs ])
807
+
808
+ # NOTE: GPU -> CPU Sync happens here.
809
+ # Move as many CPU operations as possible before this sync point.
810
+ sampled_token_ids = sampler_output .sampled_token_ids .tolist ()
811
+ # Update with the actual token ids
812
+ for i , req_state , seq_len in request_seq_lens :
813
+ token_id = sampled_token_ids [i ]
814
+ self .input_batch .token_ids_cpu [i , seq_len ] = token_id
815
+ req_state .output_token_ids [- 1 ] = token_id
816
+
802
817
if sampler_output .logprob_token_ids is None :
803
818
logprob_token_ids = None
804
819
else :
@@ -808,12 +823,6 @@ def execute_model(
808
823
else :
809
824
logprobs = sampler_output .logprobs .cpu ()
810
825
811
- # num_reqs entries should be non-None
812
- assert all (
813
- req_id is not None for req_id in
814
- self .input_batch .req_ids [:num_reqs ]), "req_ids contains None"
815
- req_ids = cast (List [str ], self .input_batch .req_ids [:num_reqs ])
816
-
817
826
model_runner_output = ModelRunnerOutput (
818
827
req_ids = req_ids ,
819
828
req_id_to_index = self .input_batch .req_id_to_index ,
0 commit comments