19
19
from vllm .pooling_params import PoolingParams
20
20
from vllm .sequence import PoolerOutput , PoolingSequenceGroupOutput
21
21
from vllm .tasks import PoolingTask
22
- from vllm .utils import resolve_obj_by_qualname
22
+ from vllm .utils import current_stream , resolve_obj_by_qualname
23
+ from vllm .v1 .pool .metadata import PoolingCursor
23
24
from vllm .v1 .pool .metadata import PoolingMetadata as V1PoolingMetadata
24
25
25
26
PoolingMetadata = Union [V0PoolingMetadata , V1PoolingMetadata ]
@@ -205,6 +206,13 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
205
206
206
207
def build_output (
207
208
all_data : Union [torch .Tensor , list [torch .Tensor ]], ) -> PoolerOutput :
209
+ # Pooling models D2H & synchronize occurs here
210
+ if isinstance (all_data , list ):
211
+ all_data = [d .to ("cpu" , non_blocking = True ) for d in all_data ]
212
+ else :
213
+ all_data = all_data .to ("cpu" , non_blocking = True )
214
+ current_stream ().synchronize ()
215
+
208
216
all_outputs = [PoolingSequenceGroupOutput (data ) for data in all_data ]
209
217
return PoolerOutput (outputs = all_outputs )
210
218
@@ -231,141 +239,96 @@ def get_supported_tasks(self) -> Set[PoolingTask]:
231
239
def get_pooling_updates (self , task : PoolingTask ) -> PoolingParamsUpdate :
232
240
return PoolingParamsUpdate ()
233
241
234
- @abstractmethod
235
- def forward_one (
236
- self ,
237
- hidden_states : torch .Tensor ,
238
- prompt_len : Optional [torch .Tensor ] = None ,
239
- ) -> torch .Tensor :
240
- """
241
- Note:
242
- `prompt_len=None` means `prompt_len=len(hidden_states)`.
243
- """
244
- raise NotImplementedError
245
-
246
242
@abstractmethod
247
243
def forward_all (
248
244
self ,
249
245
hidden_states : torch .Tensor ,
250
- prompt_lens : torch . Tensor ,
246
+ pooling_cursor : PoolingCursor ,
251
247
) -> Union [list [torch .Tensor ], torch .Tensor ]:
252
248
raise NotImplementedError
253
249
254
250
def forward (
255
251
self ,
256
- hidden_states : Union [ torch .Tensor , list [ torch . Tensor ]] ,
252
+ hidden_states : torch .Tensor ,
257
253
pooling_metadata : PoolingMetadata ,
258
254
) -> Union [list [torch .Tensor ], torch .Tensor ]:
259
- prompt_lens = get_prompt_lens (hidden_states , pooling_metadata )
260
-
261
- if isinstance (hidden_states , list ):
262
- return [
263
- self .forward_one (h , prompt_len )
264
- for h , prompt_len in zip (hidden_states , prompt_lens )
265
- ]
266
-
267
- return self .forward_all (hidden_states , prompt_lens )
255
+ pooling_cursor = pooling_metadata .pooling_cursor
256
+ return self .forward_all (hidden_states , pooling_cursor )
268
257
269
258
270
259
class CLSPool (PoolingMethod ):
271
260
272
261
def get_supported_tasks (self ) -> Set [PoolingTask ]:
273
262
return {"encode" , "embed" , "classify" , "score" }
274
263
275
- def forward_one (
276
- self ,
277
- hidden_states : torch .Tensor ,
278
- prompt_len : Optional [torch .Tensor ] = None ,
279
- ) -> torch .Tensor :
280
- assert prompt_len is None or prompt_len == hidden_states .shape [0 ], \
281
- "partial prefill not supported with CLS pooling"
282
-
283
- return hidden_states [0 ]
284
-
285
264
def forward_all (
286
265
self ,
287
266
hidden_states : torch .Tensor ,
288
- prompt_lens : torch . Tensor ,
267
+ pooling_cursor : PoolingCursor ,
289
268
) -> Union [list [torch .Tensor ], torch .Tensor ]:
290
- first_token_flat_indices = torch .zeros_like (prompt_lens )
291
- first_token_flat_indices [1 :] += torch .cumsum (prompt_lens , dim = 0 )[:- 1 ]
292
- return hidden_states [first_token_flat_indices ]
269
+ assert not pooling_cursor .is_partial_prefill (), \
270
+ "partial prefill not supported with CLS pooling"
271
+
272
+ return hidden_states [pooling_cursor .first_token_indices_gpu ]
293
273
294
274
295
275
class LastPool (PoolingMethod ):
296
276
297
277
def get_supported_tasks (self ) -> Set [PoolingTask ]:
298
278
return {"encode" , "embed" , "classify" , "score" }
299
279
300
- def forward_one (
301
- self ,
302
- hidden_states : torch .Tensor ,
303
- prompt_len : Optional [torch .Tensor ] = None ,
304
- ) -> torch .Tensor :
305
- return hidden_states [- 1 ]
306
-
307
280
def forward_all (
308
281
self ,
309
282
hidden_states : torch .Tensor ,
310
- prompt_lens : torch . Tensor ,
283
+ pooling_cursor : PoolingCursor ,
311
284
) -> Union [list [torch .Tensor ], torch .Tensor ]:
312
- last_token_flat_indices = torch .cumsum (prompt_lens , dim = 0 ) - 1
313
- return hidden_states [last_token_flat_indices ]
285
+ return hidden_states [pooling_cursor .last_token_indices_gpu ]
314
286
315
287
316
288
class AllPool (PoolingMethod ):
317
289
318
290
def get_supported_tasks (self ) -> Set [PoolingTask ]:
319
291
return {"encode" }
320
292
321
- def forward_one (
322
- self ,
323
- hidden_states : torch .Tensor ,
324
- prompt_len : Optional [torch .Tensor ] = None ,
325
- ) -> torch .Tensor :
326
- assert prompt_len is None or prompt_len == hidden_states .shape [0 ], \
327
- "partial prefill not supported with ALL pooling"
328
-
329
- return hidden_states
330
-
331
293
def forward_all (
332
294
self ,
333
295
hidden_states : torch .Tensor ,
334
- prompt_lens : torch . Tensor ,
296
+ pooling_cursor : PoolingCursor ,
335
297
) -> Union [list [torch .Tensor ], torch .Tensor ]:
336
- return list (hidden_states .split_with_sizes (prompt_lens .tolist ()))
298
+
299
+ assert not pooling_cursor .is_partial_prefill (), \
300
+ "partial prefill not supported with ALL pooling"
301
+
302
+ hidden_states_lst = list (
303
+ hidden_states .split (
304
+ pooling_cursor .num_scheduled_tokens_cpu .tolist ()))
305
+ return [hidden_states_lst [i ] for i in pooling_cursor .index ]
337
306
338
307
339
308
class MeanPool (PoolingMethod ):
340
309
341
310
def get_supported_tasks (self ) -> Set [PoolingTask ]:
342
311
return {"encode" , "embed" , "classify" , "score" }
343
312
344
- def forward_one (
313
+ def forward_all (
345
314
self ,
346
315
hidden_states : torch .Tensor ,
347
- prompt_len : Optional [torch .Tensor ] = None ,
348
- ) -> torch .Tensor :
349
- assert prompt_len is None or prompt_len == hidden_states .shape [0 ], \
316
+ pooling_cursor : PoolingCursor ,
317
+ ) -> Union [list [torch .Tensor ], torch .Tensor ]:
318
+
319
+ assert not pooling_cursor .is_partial_prefill (), \
350
320
"partial prefill not supported with MEAN pooling"
351
321
352
- return hidden_states .mean (dim = 0 , dtype = torch .float32 )
322
+ prompt_lens = pooling_cursor .prompt_lens_cpu .to (hidden_states .device ,
323
+ non_blocking = True )
353
324
354
- def forward_all (
355
- self ,
356
- hidden_states : torch .Tensor ,
357
- prompt_lens : torch .Tensor ,
358
- ) -> Union [list [torch .Tensor ], torch .Tensor ]:
359
325
# Use float32 for torch.cumsum in MeanPool,
360
326
# otherwise precision will be lost significantly.
361
327
cumsum = torch .cumsum (hidden_states , dim = 0 , dtype = torch .float32 )
362
328
363
- start_indices = torch .cat ([
364
- torch .tensor ([0 ], device = hidden_states .device ),
365
- torch .cumsum (prompt_lens [:- 1 ], dim = 0 )
366
- ])
367
- end_indices = torch .cumsum (prompt_lens , dim = 0 )
368
- return (cumsum [end_indices - 1 ] - cumsum [start_indices ] +
329
+ start_indices = pooling_cursor .first_token_indices_gpu
330
+ end_indices = pooling_cursor .last_token_indices_gpu
331
+ return (cumsum [end_indices ] - cumsum [start_indices ] +
369
332
hidden_states [start_indices ]) / prompt_lens .unsqueeze (1 )
370
333
371
334
@@ -477,6 +440,10 @@ def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
477
440
478
441
pooling_params = get_pooling_params (pooling_metadata )
479
442
443
+ if isinstance (pooled_data , list ):
444
+ pooled_data = torch .stack (pooled_data )
445
+ # pooled_data shape: [batchsize, embedding_dimension]
446
+
480
447
# for matryoshka representation
481
448
dimensions_list = [
482
449
pooling_param .dimensions for pooling_param in pooling_params
@@ -667,6 +634,10 @@ def forward(
667
634
) -> PoolerOutput :
668
635
pooled_data = self .pooling (hidden_states , pooling_metadata )
669
636
637
+ if isinstance (pooled_data , list ):
638
+ pooled_data = torch .stack (pooled_data )
639
+ # pooled_data shape: [batchsize, hidden_size]
640
+
670
641
if self .classifier is not None :
671
642
# apply classifier once on the full batch if possible
672
643
if isinstance (pooled_data , torch .Tensor ):
@@ -717,12 +688,6 @@ def forward(
717
688
) -> PoolerOutput :
718
689
poolers_by_task = self .poolers_by_task
719
690
720
- if isinstance (hidden_states , list ):
721
- hidden_states_lst = hidden_states
722
- else :
723
- prompt_lens = get_prompt_lens (hidden_states , pooling_metadata )
724
- hidden_states_lst = list (hidden_states .split (prompt_lens .tolist ()))
725
-
726
691
outputs = list [PoolingSequenceGroupOutput ]()
727
692
offset = 0
728
693
for task , group in groupby (get_tasks (pooling_metadata )):
@@ -733,7 +698,7 @@ def forward(
733
698
734
699
num_items = len (list (group ))
735
700
group_output : PoolerOutput = pooler (
736
- hidden_states_lst [ offset : offset + num_items ] ,
701
+ hidden_states ,
737
702
pooling_metadata [offset :offset + num_items ],
738
703
)
739
704
0 commit comments