[Feature] Metrics for Grammarless models (#817)
Start incorporating token-counting metrics into `Grammarless` (i.e. remote, non-guidance-aware) models.
riedgar-ms authored May 14, 2024
1 parent 1e02ffc commit 13270bf
Showing 5 changed files with 50 additions and 14 deletions.
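
The core of the change is the same in each file: count the tokens sent to the remote model and the tokens that come back, and accumulate them on the engine's metrics object. A minimal standalone sketch of that pattern, assuming only a tokenizer callable that returns a token list (the class below is illustrative, not the actual guidance implementation):

class MetricsSketch:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer          # callable: str -> list of token ids
        self.engine_input_tokens = 0
        self.engine_output_tokens = 0

    def record_input(self, text: str) -> None:
        # Tokens sent to the remote model.
        self.engine_input_tokens += len(self.tokenizer(text))

    def record_output(self, chunk: str) -> None:
        # Tokens received back from the remote model.
        self.engine_output_tokens += len(self.tokenizer(chunk))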
20 changes: 15 additions & 5 deletions guidance/models/_azureai_studio.py
@@ -73,6 +73,7 @@ def _generator(self, prompt, temperature: float):

# find the role tags
pos = 0
input_token_count = 0
role_end = b"<|im_end|>"
messages = []
found = True
@@ -95,9 +96,9 @@ def _generator(self, prompt, temperature: float):
break
btext = prompt[pos : pos + end_pos]
pos += end_pos + len(role_end)
messages.append(
{"role": role_name, "content": btext.decode("utf8")}
)
message_content = btext.decode("utf8")
input_token_count += len(self.tokenizer(message_content))
messages.append({"role": role_name, "content": message_content})
found = True
break

@@ -137,7 +138,13 @@ def _generator(self, prompt, temperature: float):
)

result = response.choices[0]
encoded_chunk = result.message.content.encode("utf8") # type: ignore[union-attr]
chunk = result.message.content
encoded_chunk = chunk.encode("utf8") # type: ignore[union-attr]

# Non-streaming OpenAI call, so we can just get the metrics directly
if response.usage is not None:
self.metrics.engine_input_tokens += response.usage.prompt_tokens
self.metrics.engine_output_tokens += response.usage.completion_tokens
else:
parameters = dict(temperature=temperature)
payload = dict(
@@ -157,7 +164,10 @@ def _generator(self, prompt, temperature: float):

result_score = response_score.json()

encoded_chunk = result_score["output"].encode("utf8")
chunk = result_score["output"]
encoded_chunk = chunk.encode("utf8")
self.metrics.engine_input_tokens += input_token_count
self.metrics.engine_output_tokens += len(self.tokenizer(chunk))

# Now back to OpenAIChatEngine, with slight modifications since
# this isn't a streaming API
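For context, a rough sketch of the non-streaming path this hunk relies on, assuming the `openai` 1.x Python client (the model name and message are illustrative): the API reports token usage on the response itself, so no local tokenization is needed.

from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment
response = client.chat.completions.create(
    model="gpt-3.5-turbo",  # illustrative model name
    messages=[{"role": "user", "content": "Pick a number: "}],
)
# Non-streaming call, so usage comes back directly with the response.
if response.usage is not None:
    input_tokens = response.usage.prompt_tokens
    output_tokens = response.usage.completion_tokens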
6 changes: 6 additions & 0 deletions guidance/models/_grammarless.py
@@ -104,8 +104,14 @@ def __init__(self, tokenizer):
else:
raise Exception("The tokenizer given was not of a recognized type!")

self._orig_tokenizer = tokenizer

super().__init__(byte_tokens, bos_token_id, eos_token_id)

def __call__(self, byte_string):
"""Returns a list of tokens that represent the given byte string."""
return self._orig_tokenizer.encode(byte_string)


class GrammarlessEngine(Engine):
def __init__(self, tokenizer, max_streaming_tokens, timeout, compute_log_probs):
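The new `__call__` hook simply defers to the wrapped tokenizer's `encode`, so callers can write `len(self.tokenizer(text))` to count tokens. A self-contained sketch of that idea (the wrapper class and the tiktoken encoding are assumptions for illustration, not the guidance implementation):

import tiktoken

class CountingTokenizer:
    def __init__(self, orig_tokenizer):
        self._orig_tokenizer = orig_tokenizer

    def __call__(self, text):
        # Return the token list for `text`; len() of the result is the token count.
        return self._orig_tokenizer.encode(text)

tokenizer = CountingTokenizer(tiktoken.get_encoding("cl100k_base"))
print(len(tokenizer("You are a math wiz.")))  # per-message input token count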
27 changes: 19 additions & 8 deletions guidance/models/_openai.py
@@ -93,11 +93,9 @@ def __init__(
# instruct
# elif "instruct" in model: # All current OpenAI instruct models behave as Completion models.
# found_subclass = OpenAIInstruct

found_subclass: typing.Type[OpenAI] = (
OpenAICompletion
if model.endswith("-instruct")
else OpenAIChat
OpenAICompletion if model.endswith("-instruct") else OpenAIChat
)

# convert to any found subclass
@@ -149,9 +147,13 @@ def _generator(self, prompt, temperature):
self._reset_shared_data(prompt, temperature) # update our shared data state

try:
# Ideally, for the metrics we would use those returned by the
# OpenAI API. Unfortunately, it appears that AzureAI hosted
# models do not support returning metrics when streaming yet
prompt_string = prompt.decode("utf8")
generator = self.client.completions.create(
model=self.model_name,
prompt=prompt.decode("utf8"),
prompt=prompt_string,
max_tokens=self.max_streaming_tokens,
n=1,
top_p=1.0, # TODO: this should be controllable like temp (from the grammar)
@@ -166,6 +168,8 @@ def _generator(self, prompt, temperature):
chunk = part.choices[0].text or ""
else:
chunk = ""
self.metrics.engine_input_tokens += len(self.tokenizer(prompt_string))
self.metrics.engine_output_tokens += len(self.tokenizer(chunk))
yield chunk.encode("utf8")


@@ -212,6 +216,7 @@ def _generator(self, prompt, temperature):
chunk = part.choices[0].text or ""
else:
chunk = ""

yield chunk.encode("utf8")


@@ -235,6 +240,7 @@ def _generator(self, prompt, temperature):
role_end = b"<|im_end|>"
messages = []
found = True
input_token_count = 0
while found:

# find the role text blocks
@@ -254,9 +260,9 @@ def _generator(self, prompt, temperature):
break
btext = prompt[pos : pos + end_pos]
pos += end_pos + len(role_end)
messages.append(
{"role": role_name, "content": btext.decode("utf8")}
)
message_content = btext.decode("utf8")
input_token_count += len(self.tokenizer(message_content))
messages.append({"role": role_name, "content": message_content})
found = True
break

@@ -284,6 +290,9 @@ def _generator(self, prompt, temperature):

# API call and response handling
try:
# Ideally, for the metrics we would use those returned by the
# OpenAI API. Unfortunately, it appears that AzureAI hosted
# models do not support returning metrics when streaming yet
generator = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
@@ -293,6 +302,7 @@ def _generator(self, prompt, temperature):
temperature=temperature,
stream=True,
)
self.metrics.engine_input_tokens += input_token_count

if temperature == 0:
cached_results = []
@@ -303,6 +313,7 @@ def _generator(self, prompt, temperature):
else:
chunk = ""
encoded_chunk = chunk.encode("utf8")
self.metrics.engine_output_tokens += len(self.tokenizer(chunk))
yield encoded_chunk

if temperature == 0:
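Because usage is not returned when streaming from AzureAI-hosted deployments (per the comments in the hunks above), the output side has to be counted chunk by chunk. A rough standalone sketch of that fallback; the client setup, model name, and tiktoken encoding are illustrative assumptions:

import tiktoken
from openai import OpenAI

client = OpenAI()
enc = tiktoken.get_encoding("cl100k_base")  # stand-in for the engine tokenizer
output_tokens = 0

stream = client.chat.completions.create(
    model="gpt-3.5-turbo",  # illustrative model name
    messages=[{"role": "user", "content": "Pick a number: "}],
    stream=True,
)
for part in stream:
    if part.choices and part.choices[0].delta.content:
        chunk = part.choices[0].delta.content
    else:
        chunk = ""
    # No per-chunk usage is reported, so tokenize locally and accumulate.
    output_tokens += len(enc.encode(chunk))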
4 changes: 4 additions & 0 deletions tests/models/common_chat_testing.py
@@ -2,6 +2,7 @@


def smoke_chat(lm: models.Chat, has_system_role: bool = True):
lm.engine.reset_metrics()
if has_system_role:
with system():
lm += "You are a math wiz."
@@ -14,8 +15,11 @@ def smoke_chat(lm: models.Chat, has_system_role: bool = True):
lm += "Pick a number: "

print(str(lm))
print(f"{lm.engine.metrics=}")
assert len(lm["text"]) > 0
assert str(lm).endswith("Pick a number: <|im_end|>")
assert lm.engine.metrics.engine_input_tokens > 2, "Expect some input tokens"
assert lm.engine.metrics.engine_output_tokens > 0, "Expect some output tokens"


def longer_chat_1(lm: models.Chat, has_system_role: bool = True):
7 changes: 6 additions & 1 deletion tests/models/test_azureai_openai.py
@@ -15,7 +15,6 @@
pytestmark = pytest.mark.needs_credentials



def test_azureai_openai_chat_smoke(rate_limiter):
azureai_endpoint = env_or_fail("AZUREAI_CHAT_ENDPOINT")
azureai_key = env_or_fail("AZUREAI_CHAT_KEY")
@@ -86,10 +85,13 @@ def test_azureai_openai_completion_smoke(rate_limiter):
model=model, azure_endpoint=azureai_endpoint, api_key=azureai_key
)
assert isinstance(lm, models.AzureOpenAICompletion)
assert isinstance(lm.engine, models._openai.OpenAICompletionEngine)

result = lm + "What is 2+2?" + gen(max_tokens=10, name="text")
print(f"result: {result['text']}")
assert len(result["text"]) > 0
assert lm.engine.metrics.engine_input_tokens > 0
assert lm.engine.metrics.engine_output_tokens > 0


def test_azureai_openai_completion_alt_args(rate_limiter):
@@ -111,10 +113,13 @@ def test_azureai_openai_completion_alt_args(rate_limiter):
azure_deployment=azureai_deployment,
)
assert isinstance(lm, models.AzureOpenAICompletion)
assert isinstance(lm.engine, models._openai.OpenAICompletionEngine)

result = lm + "What is 2+2?" + gen(max_tokens=10, name="text")
print(f"result: {result['text']}")
assert len(result["text"]) > 0
assert lm.engine.metrics.engine_input_tokens > 0
assert lm.engine.metrics.engine_output_tokens > 0


def test_azureai_openai_chat_loop(rate_limiter):
