2 files changed: +11 -4 lines changed
Original file line number | Diff line number | Diff line change
@@ -617,8 +617,10 @@ def extract_states(
617
617
self ,
618
618
hidden_states : Union [torch .Tensor , list [torch .Tensor ]],
619
619
pooling_metadata : PoolingMetadata ,
620
+ num_scheduled_tokens : torch .Tensor ,
620
621
) -> Union [list [torch .Tensor ], torch .Tensor ]:
621
- pooled_data_lst = self .pooling (hidden_states , pooling_metadata )
622
+ pooled_data_lst = self .pooling (hidden_states , pooling_metadata ,
623
+ num_scheduled_tokens )
622
624
prompt_token_ids = get_prompt_token_ids (pooling_metadata )
623
625
624
626
pooled_data = list [torch .Tensor ]()
@@ -652,7 +654,8 @@ def forward(
652
654
pooling_metadata : PoolingMetadata ,
653
655
num_scheduled_tokens : torch .Tensor ,
654
656
) -> PoolerOutput :
655
- pooled_data = self .extract_states (hidden_states , pooling_metadata )
657
+ pooled_data = self .extract_states (hidden_states , pooling_metadata ,
658
+ num_scheduled_tokens )
656
659
pooled_data = self .head (pooled_data , pooling_metadata )
657
660
return build_output (pooled_data )
658
661
Original file line number | Diff line number | Diff line change
@@ -3324,10 +3324,14 @@ def _build_encoder_only_attn_metadata(
3324
3324
3325
3325
dummy_block_table = torch .zeros ((num_reqs , 1 ),
3326
3326
dtype = torch .int32 ,
3327
- device = self .device )
3327
+ pin_memory = self .pin_memory ,
3328
+ device = "cpu" ).to (self .device ,
3329
+ non_blocking = True )
3328
3330
dummy_slot_mapping = torch .zeros ((total_num_scheduled_tokens , ),
3329
3331
dtype = torch .int32 ,
3330
- device = self .device )
3332
+ pin_memory = self .pin_memory ,
3333
+ device = "cpu" ).to (self .device ,
3334
+ non_blocking = True )
3331
3335
3332
3336
group_metadata = dict [str , tuple [CommonAttentionMetadata , Any ]]()
3333
3337
You can’t perform that action at this time.
0 commit comments