prompty: allow overriding the openai model for token count
This is needed for Azure OpenAI models, where the deployment name is not a valid OpenAI model name.
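For background (not part of this commit): token counting resolves a tokenizer from the configured model name, typically via tiktoken, and tiktoken only recognizes real OpenAI model names. A minimal sketch of the failure mode, assuming tiktoken's standard lookup behavior; the deployment name below is hypothetical:

```python
import tiktoken

# A known OpenAI model name resolves to a tokenizer.
enc = tiktoken.encoding_for_model("gpt-4")
print(len(enc.encode("How many tokens is this?")))  # prints the token count

# An Azure OpenAI *deployment* name is user-chosen and usually not a valid
# OpenAI model name, so the lookup raises KeyError.
try:
    tiktoken.encoding_for_model("my-gpt4-deployment")  # hypothetical name
except KeyError as err:
    print(f"tokenizer lookup failed: {err}")
```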
ianchi committed Oct 19, 2024
1 parent e6b576f commit 9f047c7
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions src/promptflow-core/promptflow/core/_flow.py
@@ -485,11 +485,13 @@ def render(self, *args, **kwargs):
         # For chat mode, the message generated is list type. Convert to string type and return to user.
         return str(prompt)
 
-    def estimate_token_count(self, *args, **kwargs):
+    def estimate_token_count(self, model: Union[str, None] = None, *args, **kwargs):
         """Estimate the token count.
         LLM will reject the request when prompt token + response token is greater than the maximum number of
         tokens supported by the model. It is used to estimate the number of total tokens in this round of chat.
 
+        :param model: optional OpenAI model used to determine the tokenizer.
+            Use it when the Azure OpenAI deployment name is not a valid OpenAI model name.
         :param args: positional arguments are not supported.
         :param kwargs: prompty inputs with key word arguments.
         :return: Estimate total token count
@@ -510,12 +512,12 @@ def estimate_token_count(self, *args, **kwargs):
         elif response_max_token <= 1:
             raise UserErrorException(f"{response_max_token} is less than the minimum of max_tokens.")
 
-        total_token = num_tokens_from_messages(prompt, self._model._model, working_dir=self.path.parent) + (
+        total_token = num_tokens_from_messages(prompt, model or self._model._model, working_dir=self.path.parent) + (
             response_max_token or 0
         )
 
         if self._model.parameters.get("tools", None):
-            total_token += num_tokens_for_tools(self._model.parameters["tools"], self._model._model)
+            total_token += num_tokens_for_tools(self._model.parameters["tools"], model or self._model._model)
         return total_token
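A hedged usage sketch of the new parameter; `Prompty.load`, the file path, and the `question` input are illustrative assumptions rather than part of this commit:

```python
from promptflow.core import Prompty

prompty = Prompty.load(source="path/to/chat.prompty")  # hypothetical prompty file

# The prompty may point at an Azure OpenAI deployment whose name is not a valid
# OpenAI model; override it so the estimate can resolve a tokenizer.
total = prompty.estimate_token_count(model="gpt-4", question="What is promptflow?")
print(total)
```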
