Skip to content

Commit 8ff2418

Browse files
committed
reduce cuda sync
Signed-off-by: wang.yuqi <noooop@126.com>
1 parent 6d8f55e commit 8ff2418

File tree

3 files changed

+10
-2
lines changed

3 files changed

+10
-2
lines changed

vllm/model_executor/layers/pooler.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,13 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
196196

197197
def build_output(
198198
all_data: Union[torch.Tensor, list[torch.Tensor]], ) -> PoolerOutput:
199+
# Pooling models D2H occurs only here
200+
if isinstance(all_data, list):
201+
all_data = [d.to("cpu", non_blocking=True) for d in all_data]
202+
else:
203+
all_data = all_data.to("cpu", non_blocking=True)
204+
torch.cuda.synchronize()
205+
199206
all_outputs = [PoolingSequenceGroupOutput(data) for data in all_data]
200207
return PoolerOutput(outputs=all_outputs)
201208

@@ -706,6 +713,7 @@ def forward(
706713
hidden_states_lst = hidden_states
707714
else:
708715
prompt_lens = get_prompt_lens(hidden_states, pooling_metadata)
716+
709717
hidden_states_lst = list(hidden_states.split(prompt_lens.tolist()))
710718

711719
outputs = list[PoolingSequenceGroupOutput]()

vllm/v1/worker/gpu_input_batch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -713,7 +713,7 @@ def pooling_metadata(self) -> PoolingMetadata:
713713

714714
return PoolingMetadata(
715715
prompt_lens=torch.from_numpy(
716-
self.num_prompt_tokens[:self.num_reqs]).to(self.device),
716+
self.num_prompt_tokens[:self.num_reqs]),
717717
prompt_token_ids=self.sampling_metadata.prompt_token_ids,
718718
pooling_params=pooling_params,
719719
)

vllm/v1/worker/gpu_model_runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1489,7 +1489,7 @@ def _pool(
14891489
raw_pooler_output, seq_lens, pooling_metadata.prompt_lens):
14901490

14911491
if seq_len == prompt_len:
1492-
pooler_output.append(raw_output.data.cpu())
1492+
pooler_output.append(raw_output.data)
14931493
else:
14941494
pooler_output.append(None)
14951495

0 commit comments

Comments
 (0)