Support infer n parameter (modelscope#2893)
tastelikefeet authored Jan 9, 2025
1 parent 374ab66 commit a0d0351
Showing 2 changed files with 37 additions and 21 deletions.
8 changes: 8 additions & 0 deletions swift/llm/infer/infer_engine/infer_engine.py
@@ -101,6 +101,14 @@ def _get_usage_info(num_prompt_tokens: int, num_generated_tokens: int) -> UsageInfo:
             total_tokens=num_prompt_tokens + num_generated_tokens,
         )
 
+    @staticmethod
+    def _update_usage_info(origin_use_info: UsageInfo, num_generated_tokens: int) -> UsageInfo:
+        return UsageInfo(
+            prompt_tokens=origin_use_info.prompt_tokens,
+            completion_tokens=origin_use_info.completion_tokens + num_generated_tokens,
+            total_tokens=origin_use_info.total_tokens + num_generated_tokens,
+        )
+
     @staticmethod
     def _update_metrics(result, metrics: Optional[List[Metric]] = None):
         if metrics is None:
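The new _update_usage_info helper folds each additional sequence's token count into a running UsageInfo, so a response carrying n choices reports the prompt once and the completions summed. A minimal sketch of that accumulation, using a stand-in dataclass rather than swift's own UsageInfo type:

from dataclasses import dataclass


@dataclass
class UsageInfo:  # stand-in for swift.llm's UsageInfo
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


def update_usage_info(origin: UsageInfo, num_generated_tokens: int) -> UsageInfo:
    # Prompt tokens are counted once; completion tokens accumulate per choice.
    return UsageInfo(
        prompt_tokens=origin.prompt_tokens,
        completion_tokens=origin.completion_tokens + num_generated_tokens,
        total_tokens=origin.total_tokens + num_generated_tokens,
    )


# n=3 choices of 5, 7, and 4 tokens against a 10-token prompt:
usage = UsageInfo(prompt_tokens=10, completion_tokens=0, total_tokens=10)
for generated in (5, 7, 4):
    usage = update_usage_info(usage, generated)
assert usage == UsageInfo(prompt_tokens=10, completion_tokens=16, total_tokens=26)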
50 changes: 29 additions & 21 deletions swift/llm/infer/infer_engine/pt_engine.py
@@ -104,6 +104,7 @@ def _prepare_generation_config(self, request_config: RequestConfig) -> _GenerationConfig:
         if request_config.logprobs:
             generation_config.output_logits = True
             generation_config.top_logprobs = request_config.top_logprobs
+        generation_config.num_return_sequences = request_config.n
         return _GenerationConfig(**generation_config.to_dict())
 
     def _add_stop_words(self, generation_config: _GenerationConfig, request_config: RequestConfig,
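Setting generation_config.num_return_sequences = request_config.n delegates multi-sample decoding to transformers' generate(): each prompt contributes n rows to the output batch, stacked prompt by prompt, which is why the decoding loop below indexes the flattened batch as i * num_return_sequences + j. A hedged sketch of that underlying behavior with plain transformers (the model id is illustrative):

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained('gpt2')  # illustrative model id
model = AutoModelForCausalLM.from_pretrained('gpt2')

inputs = tokenizer(['Hello there'], return_tensors='pt')
gen_config = GenerationConfig(max_new_tokens=8, do_sample=True, num_return_sequences=3)
out = model.generate(**inputs, generation_config=gen_config)

# One prompt with num_return_sequences=3 yields 3 rows in the output batch.
assert out.shape[0] == inputs['input_ids'].shape[0] * 3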
@@ -322,28 +323,35 @@ def _infer_full(self,
                 output.get('logits'), batched_generate_ids, generation_config.top_logprobs)
 
         res = []
-        for i in range(batched_generate_ids.shape[0]):
-            generate_ids = batched_generate_ids[i]
+        num_return_sequences = generation_config.num_return_sequences
+        for i in range(inputs['attention_mask'].shape[0]):
+            choices = []
+            usage_info = self._get_usage_info(num_prompt_tokens, 0)
+            for j in range(num_return_sequences):
+                batched_index = i * num_return_sequences + j
+                generate_ids = batched_generate_ids[batched_index]
 
-            # ignore pad_token
-            masks = generate_ids != self.tokenizer.pad_token_id
-            generate_ids = generate_ids[masks].tolist()
-            logprobs_list = None
-            if batched_logprobs is not None:
-                logprobs_list = [logprobs for m, logprobs in zip(masks, batched_logprobs[i]) if m.item()]
-
-            logprobs = self._get_logprobs(self.tokenizer, logprobs_list, generate_ids, generation_config.top_logprobs)
-            usage_info = self._get_usage_info(num_prompt_tokens, len(generate_ids))
-            response = template.decode(generate_ids, template_inputs=template_inputs[i])
-            finish_reason = self._get_finish_reason(generation_config.max_new_tokens, num_prompt_tokens, True)
-            toolcall = self._get_toolcall(response, template.tools_prompt)
-            choices = [
-                ChatCompletionResponseChoice(
-                    index=0,
-                    message=ChatMessage(role='assistant', content=response, tool_calls=toolcall),
-                    finish_reason=finish_reason,
-                    logprobs=logprobs)
-            ]
+                # ignore pad_token
+                masks = generate_ids != self.tokenizer.pad_token_id
+                generate_ids = generate_ids[masks].tolist()
+                logprobs_list = None
+                if batched_logprobs is not None:
+                    logprobs_list = [
+                        logprobs for m, logprobs in zip(masks, batched_logprobs[batched_index]) if m.item()
+                    ]
+
+                logprobs = self._get_logprobs(self.tokenizer, logprobs_list, generate_ids,
+                                              generation_config.top_logprobs)
+                usage_info = self._update_usage_info(usage_info, len(generate_ids))
+                response = template.decode(generate_ids, template_inputs=template_inputs[i])
+                finish_reason = self._get_finish_reason(generation_config.max_new_tokens, num_prompt_tokens, True)
+                toolcall = self._get_toolcall(response, template.tools_prompt)
+                choices.append(
+                    ChatCompletionResponseChoice(
+                        index=j,
+                        message=ChatMessage(role='assistant', content=response, tool_calls=toolcall),
+                        finish_reason=finish_reason,
+                        logprobs=logprobs))
             res.append(ChatCompletionResponse(model=self.model_name, choices=choices, usage=usage_info))
         return res
 
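Taken together, a single request can now ask for several sampled completions. A minimal usage sketch, assuming the ms-swift 3.x Python inference API (the model id is illustrative; n is the field this commit adds to RequestConfig):

from swift.llm import InferRequest, PtEngine, RequestConfig

engine = PtEngine('Qwen/Qwen2.5-0.5B-Instruct')  # illustrative model id
request = InferRequest(messages=[{'role': 'user', 'content': 'Name a color.'}])
request_config = RequestConfig(max_tokens=32, temperature=1.0, n=3)

resp = engine.infer([request], request_config)[0]
for choice in resp.choices:  # three sampled choices, index 0..2
    print(choice.index, choice.message.content)
print(resp.usage.completion_tokens)  # summed across all three choices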
