File tree Expand file tree Collapse file tree 3 files changed +10
-2
lines changed Expand file tree Collapse file tree 3 files changed +10
-2
lines changed Original file line number Diff line number Diff line change @@ -196,6 +196,13 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
196
196
197
197
def build_output (
198
198
all_data : Union [torch .Tensor , list [torch .Tensor ]], ) -> PoolerOutput :
199
+ # Pooling models D2H occurs only here
200
+ if isinstance (all_data , list ):
201
+ all_data = [d .to ("cpu" , non_blocking = True ) for d in all_data ]
202
+ else :
203
+ all_data = all_data .to ("cpu" , non_blocking = True )
204
+ torch .cuda .synchronize ()
205
+
199
206
all_outputs = [PoolingSequenceGroupOutput (data ) for data in all_data ]
200
207
return PoolerOutput (outputs = all_outputs )
201
208
@@ -706,6 +713,7 @@ def forward(
706
713
hidden_states_lst = hidden_states
707
714
else :
708
715
prompt_lens = get_prompt_lens (hidden_states , pooling_metadata )
716
+
709
717
hidden_states_lst = list (hidden_states .split (prompt_lens .tolist ()))
710
718
711
719
outputs = list [PoolingSequenceGroupOutput ]()
Original file line number Diff line number Diff line change @@ -713,7 +713,7 @@ def pooling_metadata(self) -> PoolingMetadata:
713
713
714
714
return PoolingMetadata (
715
715
prompt_lens = torch .from_numpy (
716
- self .num_prompt_tokens [:self .num_reqs ]). to ( self . device ) ,
716
+ self .num_prompt_tokens [:self .num_reqs ]),
717
717
prompt_token_ids = self .sampling_metadata .prompt_token_ids ,
718
718
pooling_params = pooling_params ,
719
719
)
Original file line number Diff line number Diff line change @@ -1489,7 +1489,7 @@ def _pool(
1489
1489
raw_pooler_output , seq_lens , pooling_metadata .prompt_lens ):
1490
1490
1491
1491
if seq_len == prompt_len :
1492
- pooler_output .append (raw_output .data . cpu () )
1492
+ pooler_output .append (raw_output .data )
1493
1493
else :
1494
1494
pooler_output .append (None )
1495
1495
You can’t perform that action at this time.
0 commit comments