Skip to content

Commit d7fa9a8

Browse files
committed
non_blocking torch.zeros in _build_encoder_only_attn_metadata
Signed-off-by: wang.yuqi <noooop@126.com>
1 parent f649899 commit d7fa9a8

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

vllm/model_executor/layers/pooler.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -617,8 +617,10 @@ def extract_states(
617617
self,
618618
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
619619
pooling_metadata: PoolingMetadata,
620+
num_scheduled_tokens: torch.Tensor,
620621
) -> Union[list[torch.Tensor], torch.Tensor]:
621-
pooled_data_lst = self.pooling(hidden_states, pooling_metadata)
622+
pooled_data_lst = self.pooling(hidden_states, pooling_metadata,
623+
num_scheduled_tokens)
622624
prompt_token_ids = get_prompt_token_ids(pooling_metadata)
623625

624626
pooled_data = list[torch.Tensor]()
@@ -652,7 +654,8 @@ def forward(
652654
pooling_metadata: PoolingMetadata,
653655
num_scheduled_tokens: torch.Tensor,
654656
) -> PoolerOutput:
655-
pooled_data = self.extract_states(hidden_states, pooling_metadata)
657+
pooled_data = self.extract_states(hidden_states, pooling_metadata,
658+
num_scheduled_tokens)
656659
pooled_data = self.head(pooled_data, pooling_metadata)
657660
return build_output(pooled_data)
658661

vllm/v1/worker/gpu_model_runner.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3324,10 +3324,14 @@ def _build_encoder_only_attn_metadata(
33243324

33253325
dummy_block_table = torch.zeros((num_reqs, 1),
33263326
dtype=torch.int32,
3327-
device=self.device)
3327+
pin_memory=self.pin_memory,
3328+
device="cpu").to(self.device,
3329+
non_blocking=True)
33283330
dummy_slot_mapping = torch.zeros((total_num_scheduled_tokens, ),
33293331
dtype=torch.int32,
3330-
device=self.device)
3332+
pin_memory=self.pin_memory,
3333+
device="cpu").to(self.device,
3334+
non_blocking=True)
33313335

33323336
group_metadata = dict[str, tuple[CommonAttentionMetadata, Any]]()
33333337

0 commit comments

Comments
 (0)