@@ -1722,7 +1722,7 @@ def execute_model(
         # Compute prompt logprobs if needed.
         prompt_logprobs_dict = self._get_prompt_logprobs_dict(
             hidden_states[:num_scheduled_tokens],
-            scheduler_output,
+            scheduler_output.num_scheduled_tokens,
         )

         # Get the valid generated tokens.
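The call site now hands the helper only the per-request token counts: as the new signature in the next hunk shows, scheduler_output.num_scheduled_tokens is a dict[str, int] mapping each request ID to the number of tokens scheduled for it this step. A sketch of the shape (request IDs made up for illustration):

    num_scheduled_tokens = {
        "req-0": 128,  # chunked prefill: 128 prompt tokens this step
        "req-1": 1,    # decode: one new token
    }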
@@ -2064,7 +2064,7 @@ def save_tensorized_model(
     def _get_prompt_logprobs_dict(
         self,
         hidden_states: torch.Tensor,
-        scheduler_output: "SchedulerOutput",
+        num_scheduled_tokens: dict[str, int],
     ) -> dict[str, Optional[LogprobsTensors]]:
         num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
         if not num_prompt_logprobs_dict:
@@ -2077,8 +2077,7 @@ def _get_prompt_logprobs_dict(
         # maintainable loop over optimal performance.
         completed_prefill_reqs = []
         for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items():
-
-            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
+            num_tokens = num_scheduled_tokens[req_id]

             # Get metadata for this request.
             request = self.requests[req_id]
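A minimal sketch of what the narrowed signature buys; the helper name and arguments below are hypothetical, not vLLM's actual code. Because _get_prompt_logprobs_dict now only indexes the mapping by request ID, callers (and tests) can pass a plain dict without constructing a SchedulerOutput:

    def lookup_scheduled_tokens(
        num_scheduled_tokens: dict[str, int],
        prompt_logprob_req_ids: list[str],
    ) -> dict[str, int]:
        # Mirrors the loop above: fetch the per-request token count for
        # every request that asked for prompt logprobs.
        return {req_id: num_scheduled_tokens[req_id]
                for req_id in prompt_logprob_req_ids}

    # Usage: a bare dict suffices; no scheduler plumbing required.
    assert lookup_scheduled_tokens({"req-0": 128, "req-1": 1},
                                   ["req-0"]) == {"req-0": 128}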