
Commit 4388fac

youngkent authored and tjtanaa committed
[V1][Perf] Reduce scheduling overhead in model runner after cuda sync (vllm-project#12094)
Signed-off-by: Keyun Tong <tongkeyun@gmail.com>
1 parent 2bc60ba commit 4388fac

File tree

3 files changed: +21 -13 lines changed

vllm/v1/outputs.py

Lines changed: 1 addition & 1 deletion

@@ -8,7 +8,7 @@
 class SamplerOutput:
 
     # [num_reqs]
-    sampled_token_ids: List[int]
+    sampled_token_ids: torch.Tensor
 
     # [num_reqs, max_num_logprobs + 1]
     logprob_token_ids: Optional[torch.Tensor]
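In effect, the dataclass now carries the sampled ids as a device tensor and leaves the host copy to the caller. A minimal sketch of just this field change (other SamplerOutput fields and surrounding vLLM code are omitted; illustrative, not the full class):

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class SamplerOutput:
    # [num_reqs]; kept as an int32 GPU tensor instead of List[int], so
    # constructing a SamplerOutput no longer forces a device-to-host copy.
    sampled_token_ids: torch.Tensor

    # [num_reqs, max_num_logprobs + 1]
    logprob_token_ids: Optional[torch.Tensor]
```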

vllm/v1/sample/sampler.py

Lines changed: 1 addition & 2 deletions

@@ -50,9 +50,8 @@ def forward(
         # Use int32 to reduce the tensor size.
         sampled = sampled.to(torch.int32)
 
-        # NOTE: CPU-GPU synchronization happens here.
         sampler_output = SamplerOutput(
-            sampled_token_ids=sampled.tolist(),
+            sampled_token_ids=sampled,
             logprob_token_ids=topk_indices,
             logprobs=topk_logprobs,
             prompt_logprob_token_ids=None,
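The reason this helps: calling `.tolist()` on a CUDA tensor blocks the host until the pending GPU work finishes, whereas returning the tensor lets the caller keep doing CPU-side work first. A rough timing sketch of the two orderings (illustrative only; assumes a CUDA device, `cpu_side_work` is a made-up stand-in for the runner's per-request bookkeeping, and actual numbers will vary):

```python
import time

import torch


def cpu_side_work() -> int:
    # Stand-in for the Python bookkeeping the model runner does per step.
    return sum(range(100_000))


if __name__ == "__main__":
    assert torch.cuda.is_available(), "this sketch assumes a CUDA device"
    logits = torch.randn(256, 32_000, device="cuda")

    # Eager sync: .tolist() waits for the argmax kernel, and only then
    # does the CPU bookkeeping run.
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    ids = torch.argmax(logits, dim=-1).to(torch.int32).tolist()
    cpu_side_work()
    t1 = time.perf_counter()

    # Deferred sync: keep the result on the GPU, overlap the CPU bookkeeping
    # with the still-running kernel, and sync once at the end.
    torch.cuda.synchronize()
    t2 = time.perf_counter()
    sampled = torch.argmax(logits, dim=-1).to(torch.int32)
    cpu_side_work()
    ids = sampled.tolist()
    t3 = time.perf_counter()

    print(f"eager sync:    {(t1 - t0) * 1e3:.2f} ms")
    print(f"deferred sync: {(t3 - t2) * 1e3:.2f} ms")
```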

vllm/v1/worker/gpu_model_runner.py

Lines changed: 19 additions & 10 deletions

@@ -775,10 +775,10 @@ def execute_model(
             sampling_metadata=sampling_metadata,
         )
 
-        sampled_token_ids = sampler_output.sampled_token_ids
         # TODO(woosuk): The following loop can be slow since it iterates over
         # the requests one by one. Optimize.
         num_reqs = self.input_batch.num_reqs
+        request_seq_lens: List[Tuple[int, CachedRequestState, int]] = []
         for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
             assert req_id is not None
             req_state = self.requests[req_id]
@@ -787,10 +787,10 @@ def execute_model(
             assert seq_len <= req_state.num_tokens
             if seq_len == req_state.num_tokens:
                 # Append the sampled token to the output token ids.
-                token_id = sampled_token_ids[i]
-                self.input_batch.token_ids_cpu[i, seq_len] = token_id
                 self.input_batch.num_tokens[i] += 1
-                req_state.output_token_ids.append(token_id)
+                # OPTIMIZATION: Priming the state updates for later updates.
+                req_state.output_token_ids.append(0)
+                request_seq_lens.append((i, req_state, seq_len))
             else:
                 # Ignore the sampled token from the partial request.
                 # Rewind the generator state as if the token was not sampled.
@@ -799,6 +799,21 @@ def execute_model(
             # This relies on cuda-specific torch-internal impl details
             generator.set_offset(generator.get_offset() - 4)
 
+        # num_reqs entries should be non-None
+        assert all(
+            req_id is not None for req_id in
+            self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
+        req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs])
+
+        # NOTE: GPU -> CPU Sync happens here.
+        # Move as many CPU operations as possible before this sync point.
+        sampled_token_ids = sampler_output.sampled_token_ids.tolist()
+        # Update with the actual token ids
+        for i, req_state, seq_len in request_seq_lens:
+            token_id = sampled_token_ids[i]
+            self.input_batch.token_ids_cpu[i, seq_len] = token_id
+            req_state.output_token_ids[-1] = token_id
+
         if sampler_output.logprob_token_ids is None:
             logprob_token_ids = None
         else:
@@ -808,12 +823,6 @@ def execute_model(
         else:
             logprobs = sampler_output.logprobs.cpu()
 
-        # num_reqs entries should be non-None
-        assert all(
-            req_id is not None for req_id in
-            self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
-        req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs])
-
         model_runner_output = ModelRunnerOutput(
             req_ids=req_ids,
             req_id_to_index=self.input_batch.req_id_to_index,
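Putting the runner change together, the pattern is: prime the per-request state with a placeholder token, finish every piece of CPU bookkeeping that does not need the token values, then pay for a single GPU-to-CPU sync and patch the placeholders. A self-contained sketch under simplified assumptions (`RequestState` and `postprocess` below are illustrative stand-ins, not vLLM's classes):

```python
from typing import List, Tuple

import torch


class RequestState:
    """Simplified stand-in for vLLM's per-request state."""

    def __init__(self) -> None:
        self.output_token_ids: List[int] = []


def postprocess(sampled_gpu: torch.Tensor,
                requests: List[RequestState]) -> None:
    # Phase 1: CPU-side bookkeeping that does NOT need the token values.
    # Append a placeholder that will be patched after the sync.
    pending: List[Tuple[int, RequestState]] = []
    for i, req_state in enumerate(requests):
        req_state.output_token_ids.append(0)  # placeholder
        pending.append((i, req_state))

    # Phase 2: one GPU -> CPU sync for the whole batch. The bookkeeping
    # above already ran, so the host did not sit idle waiting on the GPU.
    token_ids = sampled_gpu.tolist()

    # Phase 3: patch the placeholders with the real sampled tokens.
    for i, req_state in pending:
        req_state.output_token_ids[-1] = token_ids[i]


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    reqs = [RequestState() for _ in range(4)]
    sampled = torch.randint(0, 32_000, (4,), dtype=torch.int32, device=device)
    postprocess(sampled, reqs)
    print([r.output_token_ids for r in reqs])
```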
