Commit 5b0ceea

noooop authored and kylesayrs committed
[Performance] V1 Pooling Models E2E Performance Optimization (vllm-project#23162)
Signed-off-by: wang.yuqi <noooop@126.com>
1 parent c51acf8 commit 5b0ceea

File tree: 8 files changed (+161 −167 lines)

vllm/model_executor/layers/pooler.py (+48 −83)

```diff
@@ -19,7 +19,8 @@
 from vllm.pooling_params import PoolingParams
 from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
 from vllm.tasks import PoolingTask
-from vllm.utils import resolve_obj_by_qualname
+from vllm.utils import current_stream, resolve_obj_by_qualname
+from vllm.v1.pool.metadata import PoolingCursor
 from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata
 
 PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]
```
```diff
@@ -205,6 +206,13 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
 
 def build_output(
         all_data: Union[torch.Tensor, list[torch.Tensor]], ) -> PoolerOutput:
+    # Pooling models D2H & synchronize occurs here
+    if isinstance(all_data, list):
+        all_data = [d.to("cpu", non_blocking=True) for d in all_data]
+    else:
+        all_data = all_data.to("cpu", non_blocking=True)
+    current_stream().synchronize()
+
     all_outputs = [PoolingSequenceGroupOutput(data) for data in all_data]
     return PoolerOutput(outputs=all_outputs)
```
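The new `build_output` gathers all device-to-host traffic in one place: every copy is queued with `non_blocking=True`, then a single `current_stream().synchronize()` waits for the whole batch instead of paying one implicit sync per tensor. A minimal standalone sketch of the pattern (using `torch.cuda.current_stream` directly rather than vLLM's `current_stream` wrapper):

```python
import torch

def gather_to_cpu(tensors: list[torch.Tensor]) -> list[torch.Tensor]:
    # Queue every D2H copy asynchronously; the copies can overlap on the
    # current CUDA stream instead of each one blocking the host.
    out = [t.to("cpu", non_blocking=True) for t in tensors]
    # One synchronization point for the entire batch.
    torch.cuda.current_stream().synchronize()
    return out
```

Note that a `non_blocking=True` D2H copy is only fully asynchronous when the destination is pinned memory; otherwise PyTorch may fall back to an effectively synchronous copy, so the single explicit sync is the upper bound on host-side waiting.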
```diff
@@ -231,141 +239,96 @@ def get_supported_tasks(self) -> Set[PoolingTask]:
     def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
         return PoolingParamsUpdate()
 
-    @abstractmethod
-    def forward_one(
-        self,
-        hidden_states: torch.Tensor,
-        prompt_len: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        """
-        Note:
-            `prompt_len=None` means `prompt_len=len(hidden_states)`.
-        """
-        raise NotImplementedError
-
     @abstractmethod
     def forward_all(
         self,
         hidden_states: torch.Tensor,
-        prompt_lens: torch.Tensor,
+        pooling_cursor: PoolingCursor,
     ) -> Union[list[torch.Tensor], torch.Tensor]:
         raise NotImplementedError
 
     def forward(
         self,
-        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
+        hidden_states: torch.Tensor,
         pooling_metadata: PoolingMetadata,
     ) -> Union[list[torch.Tensor], torch.Tensor]:
-        prompt_lens = get_prompt_lens(hidden_states, pooling_metadata)
-
-        if isinstance(hidden_states, list):
-            return [
-                self.forward_one(h, prompt_len)
-                for h, prompt_len in zip(hidden_states, prompt_lens)
-            ]
-
-        return self.forward_all(hidden_states, prompt_lens)
+        pooling_cursor = pooling_metadata.pooling_cursor
+        return self.forward_all(hidden_states, pooling_cursor)
```
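With `forward_one` gone, poolers no longer loop over per-request tensors on the host; they index one packed `[total_tokens, hidden]` tensor through indices computed once per batch. A minimal sketch of that idea, using a hypothetical `SimpleCursor` standing in for vLLM's `PoolingCursor`:

```python
from dataclasses import dataclass

import torch

@dataclass
class SimpleCursor:
    # Flat indices of each prompt's first and last token in the packed
    # hidden-states tensor, precomputed once per batch.
    first_token_indices_gpu: torch.Tensor
    last_token_indices_gpu: torch.Tensor

def make_cursor(prompt_lens: torch.Tensor, device: str = "cpu") -> SimpleCursor:
    cumsum = torch.cumsum(prompt_lens, dim=0)
    first = torch.cat([cumsum.new_zeros(1), cumsum[:-1]]).to(device)
    last = (cumsum - 1).to(device)
    return SimpleCursor(first, last)
```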
```diff
 class CLSPool(PoolingMethod):
 
     def get_supported_tasks(self) -> Set[PoolingTask]:
         return {"encode", "embed", "classify", "score"}
 
-    def forward_one(
-        self,
-        hidden_states: torch.Tensor,
-        prompt_len: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        assert prompt_len is None or prompt_len == hidden_states.shape[0], \
-            "partial prefill not supported with CLS pooling"
-
-        return hidden_states[0]
-
     def forward_all(
         self,
         hidden_states: torch.Tensor,
-        prompt_lens: torch.Tensor,
+        pooling_cursor: PoolingCursor,
     ) -> Union[list[torch.Tensor], torch.Tensor]:
-        first_token_flat_indices = torch.zeros_like(prompt_lens)
-        first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]
-        return hidden_states[first_token_flat_indices]
+        assert not pooling_cursor.is_partial_prefill(), \
+            "partial prefill not supported with CLS pooling"
+
+        return hidden_states[pooling_cursor.first_token_indices_gpu]
 
 
 class LastPool(PoolingMethod):
 
     def get_supported_tasks(self) -> Set[PoolingTask]:
         return {"encode", "embed", "classify", "score"}
 
-    def forward_one(
-        self,
-        hidden_states: torch.Tensor,
-        prompt_len: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        return hidden_states[-1]
-
     def forward_all(
         self,
         hidden_states: torch.Tensor,
-        prompt_lens: torch.Tensor,
+        pooling_cursor: PoolingCursor,
     ) -> Union[list[torch.Tensor], torch.Tensor]:
-        last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
-        return hidden_states[last_token_flat_indices]
+        return hidden_states[pooling_cursor.last_token_indices_gpu]
```
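Continuing the sketch above, CLS and last-token pooling then become a single gather each over the packed batch:

```python
# Three prompts of lengths 2, 3 and 4 packed into one [9, hidden] tensor.
hidden = torch.randn(9, 4)
cursor = make_cursor(torch.tensor([2, 3, 4]))
cls_embeds = hidden[cursor.first_token_indices_gpu]   # rows 0, 2, 5
last_embeds = hidden[cursor.last_token_indices_gpu]   # rows 1, 4, 8
assert cls_embeds.shape == last_embeds.shape == (3, 4)
```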
```diff
 class AllPool(PoolingMethod):
 
     def get_supported_tasks(self) -> Set[PoolingTask]:
         return {"encode"}
 
-    def forward_one(
-        self,
-        hidden_states: torch.Tensor,
-        prompt_len: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        assert prompt_len is None or prompt_len == hidden_states.shape[0], \
-            "partial prefill not supported with ALL pooling"
-
-        return hidden_states
-
     def forward_all(
         self,
         hidden_states: torch.Tensor,
-        prompt_lens: torch.Tensor,
+        pooling_cursor: PoolingCursor,
     ) -> Union[list[torch.Tensor], torch.Tensor]:
-        return list(hidden_states.split_with_sizes(prompt_lens.tolist()))
+
+        assert not pooling_cursor.is_partial_prefill(), \
+            "partial prefill not supported with ALL pooling"
+
+        hidden_states_lst = list(
+            hidden_states.split(
+                pooling_cursor.num_scheduled_tokens_cpu.tolist()))
+        return [hidden_states_lst[i] for i in pooling_cursor.index]
 
 
 class MeanPool(PoolingMethod):
 
     def get_supported_tasks(self) -> Set[PoolingTask]:
         return {"encode", "embed", "classify", "score"}
 
-    def forward_one(
+    def forward_all(
         self,
         hidden_states: torch.Tensor,
-        prompt_len: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        assert prompt_len is None or prompt_len == hidden_states.shape[0], \
+        pooling_cursor: PoolingCursor,
+    ) -> Union[list[torch.Tensor], torch.Tensor]:
+
+        assert not pooling_cursor.is_partial_prefill(), \
             "partial prefill not supported with MEAN pooling"
 
-        return hidden_states.mean(dim=0, dtype=torch.float32)
+        prompt_lens = pooling_cursor.prompt_lens_cpu.to(hidden_states.device,
+                                                        non_blocking=True)
 
-    def forward_all(
-        self,
-        hidden_states: torch.Tensor,
-        prompt_lens: torch.Tensor,
-    ) -> Union[list[torch.Tensor], torch.Tensor]:
         # Use float32 for torch.cumsum in MeanPool,
         # otherwise precision will be lost significantly.
         cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)
 
-        start_indices = torch.cat([
-            torch.tensor([0], device=hidden_states.device),
-            torch.cumsum(prompt_lens[:-1], dim=0)
-        ])
-        end_indices = torch.cumsum(prompt_lens, dim=0)
-        return (cumsum[end_indices - 1] - cumsum[start_indices] +
+        start_indices = pooling_cursor.first_token_indices_gpu
+        end_indices = pooling_cursor.last_token_indices_gpu
+        return (cumsum[end_indices] - cumsum[start_indices] +
                 hidden_states[start_indices]) / prompt_lens.unsqueeze(1)
```
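`MeanPool` now derives per-prompt sums from a single inclusive prefix sum: for a prompt occupying flat rows `s..e`, the sum is `cumsum[e] - cumsum[s] + hidden[s]` (subtracting `cumsum[s]` also removes row `s` itself, so it is added back). A self-contained check of the identity:

```python
import torch

# Mean-pool packed sequences via an inclusive prefix sum.
def mean_pool(hidden: torch.Tensor, prompt_lens: torch.Tensor) -> torch.Tensor:
    cumsum = torch.cumsum(hidden, dim=0, dtype=torch.float32)
    ends = torch.cumsum(prompt_lens, dim=0) - 1      # last row per prompt
    starts = ends - prompt_lens + 1                  # first row per prompt
    sums = cumsum[ends] - cumsum[starts] + hidden[starts]
    return sums / prompt_lens.unsqueeze(1)

hidden = torch.arange(12, dtype=torch.float32).reshape(6, 2)
out = mean_pool(hidden, torch.tensor([2, 4]))
assert torch.allclose(out[0], hidden[:2].mean(dim=0))
assert torch.allclose(out[1], hidden[2:].mean(dim=0))
```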
371334

```diff
@@ -477,6 +440,10 @@ def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
 
         pooling_params = get_pooling_params(pooling_metadata)
 
+        if isinstance(pooled_data, list):
+            pooled_data = torch.stack(pooled_data)
+        # pooled_data shape: [batchsize, embedding_dimension]
+
         # for matryoshka representation
         dimensions_list = [
             pooling_param.dimensions for pooling_param in pooling_params
```
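Stacking the list into one `[batchsize, dim]` tensor up front lets the downstream steps (matryoshka slicing here, the classifier matmul below) run once over the whole batch rather than per row; a trivial illustration:

```python
import torch

rows = [torch.randn(8) for _ in range(3)]   # per-request pooled vectors
batch = torch.stack(rows)                   # [3, 8]: one op serves all rows
```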
```diff
@@ -667,6 +634,10 @@ def forward(
     ) -> PoolerOutput:
         pooled_data = self.pooling(hidden_states, pooling_metadata)
 
+        if isinstance(pooled_data, list):
+            pooled_data = torch.stack(pooled_data)
+        # pooled_data shape: [batchsize, hidden_size]
+
         if self.classifier is not None:
             # apply classifier once on the full batch if possible
             if isinstance(pooled_data, torch.Tensor):
@@ -717,12 +688,6 @@ def forward(
     ) -> PoolerOutput:
         poolers_by_task = self.poolers_by_task
 
-        if isinstance(hidden_states, list):
-            hidden_states_lst = hidden_states
-        else:
-            prompt_lens = get_prompt_lens(hidden_states, pooling_metadata)
-            hidden_states_lst = list(hidden_states.split(prompt_lens.tolist()))
-
         outputs = list[PoolingSequenceGroupOutput]()
         offset = 0
         for task, group in groupby(get_tasks(pooling_metadata)):
@@ -733,7 +698,7 @@ def forward(
 
             num_items = len(list(group))
             group_output: PoolerOutput = pooler(
-                hidden_states_lst[offset:offset + num_items],
+                hidden_states,
                 pooling_metadata[offset:offset + num_items],
             )
```
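The dispatcher no longer splits `hidden_states` per request on the host; each task's pooler now receives the full packed tensor plus a sliced view of the metadata. The offset bookkeeping follows the usual `itertools.groupby` pattern over consecutive tasks, sketched here with illustrative task names:

```python
from itertools import groupby

# Consecutive requests with the same task form one group; each pooler
# runs once per contiguous group, covering [offset, offset + num_items).
tasks = ["embed", "embed", "classify", "embed"]
offset = 0
for task, group in groupby(tasks):
    num_items = len(list(group))
    print(task, "-> items", list(range(offset, offset + num_items)))
    offset += num_items
# embed -> items [0, 1]
# classify -> items [2]
# embed -> items [3]
```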

vllm/model_executor/models/bert.py (+3 −3)

```diff
@@ -528,9 +528,9 @@ def _encode_token_type_ids(input_ids: torch.Tensor,
 
 def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor:
 
-    ids_mask = torch.ones(input_ids.shape,
-                          dtype=torch.int32,
-                          device=input_ids.device) << TOKEN_TYPE_SHIFT
+    ids_mask = torch.ones_like(input_ids,
+                               dtype=torch.int32,
+                               device=input_ids.device) << TOKEN_TYPE_SHIFT
     tokens_mask = ids_mask.bitwise_not()
 
     token_type_ids = input_ids.bitwise_and(ids_mask) >> TOKEN_TYPE_SHIFT
```
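For context on `_decode_token_type_ids`: BERT-style token type ids are bit-packed into the high bits of `input_ids`, and the change swaps `torch.ones(input_ids.shape, ...)` for `torch.ones_like(...)`. A round-trip sketch of the packing scheme (the shift value here is an assumption for illustration, not vLLM's constant):

```python
import torch

TOKEN_TYPE_SHIFT = 30  # assumed value for this sketch

def encode(input_ids: torch.Tensor, token_type_ids: torch.Tensor) -> torch.Tensor:
    # Stash the token type id in the bits above TOKEN_TYPE_SHIFT.
    return input_ids | (token_type_ids << TOKEN_TYPE_SHIFT)

def decode(packed: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    ids_mask = torch.ones_like(packed, dtype=torch.int32) << TOKEN_TYPE_SHIFT
    token_type_ids = (packed & ids_mask) >> TOKEN_TYPE_SHIFT
    input_ids = packed & ids_mask.bitwise_not()
    return input_ids, token_type_ids

packed = encode(torch.tensor([5, 7], dtype=torch.int32),
                torch.tensor([0, 1], dtype=torch.int32))
ids, types = decode(packed)
assert ids.tolist() == [5, 7] and types.tolist() == [0, 1]
```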

vllm/model_executor/models/roberta.py (+10 −55)

```diff
@@ -9,7 +9,6 @@
 from transformers import RobertaConfig
 
 from vllm.config import VllmConfig
-from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool,
                                                DispatchPooler, Pooler)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -100,7 +99,7 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__(vllm_config=vllm_config, prefix=prefix)
-        self.padding_idx = vllm_config.model_config.hf_config.pad_token_id
+        self.padding_idx: int = vllm_config.model_config.hf_config.pad_token_id
 
     def forward(
         self,
@@ -178,7 +177,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
-        self.padding_idx = vllm_config.model_config.hf_config.pad_token_id
+        self.padding_idx: int = vllm_config.model_config.hf_config.pad_token_id
 
         self.num_labels = config.num_labels
         self.roberta = BertModel(vllm_config=vllm_config,
@@ -233,58 +232,14 @@ def forward(
                              intermediate_tensors=intermediate_tensors)
 
 
-# Adapted from transformers
-def create_position_ids_from_input_ids(input_ids,
-                                       padding_idx,
-                                       past_key_values_length=0):
-    """
-    Replace non-padding symbols with their position numbers.
-    Position numbers begin at padding_idx+1. Padding symbols
-    are ignored. This is modified from fairseq's `utils.make_positions`.
-
-    Args:
-        x: torch.Tensor x:
-
-    Returns: torch.Tensor
-    """
-    # The series of casts and type-conversions here are carefully
-    # balanced to both work with ONNX export and XLA.
-    mask = input_ids.ne(padding_idx).int()
-
-    incremental_indices = (torch.cumsum(mask, dim=0).type_as(mask) +
-                           past_key_values_length) * mask
-
-    return incremental_indices.long() + padding_idx
-
-
 def replace_roberta_positions(input_ids: torch.Tensor,
                               position_ids: torch.Tensor,
                               padding_idx: int) -> None:
-
-    seq_lens: Optional[torch.Tensor] = None
-    attn_metadata = get_forward_context().attn_metadata
-    if attn_metadata is not None:  # can be None during warmup
-        if isinstance(attn_metadata, dict):
-            attn_metadata = next(iter(attn_metadata.values()))
-        # TODO: remove "seq_lens_tensor" after V0 is removed
-        seq_lens = getattr(attn_metadata, "seq_lens_tensor",
-                           getattr(attn_metadata, "seq_lens", None))
-
-    if seq_lens is not None:
-        assert isinstance(seq_lens, torch.Tensor)
-
-        # Replace position ids because in RoBERTa models
-        # they have to start at padding_idx + 1 and ignore
-        # existing padding tokens
-        # References:
-        # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
-        # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
-        token_list = torch.split(input_ids[:torch.sum(seq_lens)],
-                                 seq_lens.tolist())
-
-        offset = 0
-        for tokens in token_list:
-            length = tokens.shape[0]
-            position_ids[offset:offset+length] = \
-                create_position_ids_from_input_ids(tokens, padding_idx)
-            offset = offset + length
+    # Replace position ids because in RoBERTa models
+    # they have to start at padding_idx + 1 and ignore
+    # existing padding tokens
+    # References:
+    # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
+    # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
+    # vllm does not use padding tokens, let's make things simpler
+    position_ids += padding_idx + 1
```
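The rewrite drops the per-sequence loop entirely: since vLLM never feeds padding tokens inside a sequence, HF's `create_position_ids_from_input_ids` collapses to the default 0-based positions shifted by `padding_idx + 1`. A quick check of that equivalence:

```python
import torch

# Without padding tokens, cumsum(mask) * mask + padding_idx reduces to
# arange(1, n + 1) + padding_idx, i.e. positions shifted by padding_idx + 1.
padding_idx = 1
input_ids = torch.tensor([10, 11, 12, 13])   # no padding tokens present
mask = input_ids.ne(padding_idx).int()
hf_positions = torch.cumsum(mask, dim=0) * mask + padding_idx
vllm_positions = torch.arange(len(input_ids)) + padding_idx + 1
assert torch.equal(hf_positions.long(), vllm_positions)
```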

vllm/model_executor/pooling_metadata.py (+17 −6)

```diff
@@ -2,12 +2,13 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from dataclasses import dataclass
-from typing import Any
+from typing import Any, Optional
 
 import torch
 
 from vllm.pooling_params import PoolingParams
 from vllm.utils import is_pin_memory_available
+from vllm.v1.pool.metadata import PoolingCursor, build_pooling_cursor
 
 
 class PoolingMetadata:
@@ -23,14 +24,15 @@ class PoolingMetadata:
     """
 
     def __init__(
-        self,
-        seq_groups: list[tuple[list[int], PoolingParams]],
-        seq_data: dict[int, Any],  # Specific data related to sequences
-        prompt_lens: list[int],
-    ) -> None:
+            self,
+            seq_groups: list[tuple[list[int], PoolingParams]],
+            seq_data: dict[int, Any],  # Specific data related to sequences
+            prompt_lens: list[int],
+            pooling_cursor: Optional[PoolingCursor] = None) -> None:
         self.seq_groups = seq_groups
         self.seq_data = seq_data
         self.prompt_lens = prompt_lens
+        self.pooling_cursor: Optional[PoolingCursor] = pooling_cursor
 
     def __repr__(self) -> str:
         return ("PoolingMetadata("
@@ -43,8 +45,17 @@ def __getitem__(self, indices: slice):
             seq_groups=self.seq_groups[indices],
             seq_data=dict(list(self.seq_data.items())[indices]),
             prompt_lens=self.prompt_lens[indices],
+            pooling_cursor=None
+            if self.pooling_cursor is None else self.pooling_cursor[indices],
         )
 
+    def build_pooling_cursor(self, num_scheduled_tokens: list[int],
+                             device: torch.device):
+        prompt_lens = torch.tensor(self.prompt_lens, device="cpu")
+        self.pooling_cursor = build_pooling_cursor(num_scheduled_tokens,
+                                                   prompt_lens,
+                                                   device=device)
+
 
 @dataclass
 class PoolingTensors:
```
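`build_pooling_cursor` receives both the per-request `num_scheduled_tokens` and the full `prompt_lens`, which is what lets poolers check `is_partial_prefill()`. A small sketch of the assumed semantics: a step is a partial prefill when fewer tokens were scheduled for some prompt than its full length, exactly the case where first/last-token pooling would read the wrong rows:

```python
import torch

prompt_lens = torch.tensor([2, 3])
num_scheduled_tokens = torch.tensor([2, 3])   # full prompts this step
is_partial_prefill = not torch.equal(prompt_lens, num_scheduled_tokens)
assert not is_partial_prefill                 # safe for CLS/ALL/MEAN pooling
```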
