Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
270 changes: 270 additions & 0 deletions tests/test_openai_chat_completions_token_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
from typing import Any, cast

import pytest

from verifiers.clients.openai_chat_completions_client import OpenAIChatCompletionsClient
from verifiers.clients.openai_chat_completions_token_client import (
OpenAIChatCompletionsTokenClient,
)
from verifiers.types import State


class _NoopClient:
base_url = "http://localhost:8000/v1"

def with_options(self, **kwargs): # noqa: ANN003
return self


class _RecordingClient(_NoopClient):
    """Client double that records raw POSTs instead of hitting a server."""

    def __init__(self) -> None:
        # One dict per post() call, in call order.
        self.calls: list[dict[str, Any]] = []

    async def post(self, path: str, body: dict[str, Any], cast_to: type) -> Any:
        call = {"path": path, "body": body, "cast_to": cast_to}
        self.calls.append(call)
        # Echo enough of the request back for assertions on the response.
        return {"ok": True, "path": path, "body": body}


class _PromptIdTestClient(OpenAIChatCompletionsTokenClient):
    """Token client stub whose ``tokenize`` returns canned ids per input shape.

    Lets ``get_prompt_ids`` run without a real tokenizer:

    - a plain-string prompt (must be ``"World!"``) tokenizes to ``[777]``
    - the two-message Hello/World chat tokenizes to ``[1, 777, 999]`` and is
      expected to be called with ``add_generation_prompt=False``
    - any other prompt tokenizes to the ids supplied at construction time
    """

    def __init__(self, full_prompt_ids: list[int]) -> None:
        super().__init__(_NoopClient())
        self._full_prompt_ids = full_prompt_ids

    async def to_native_prompt(self, messages):  # type: ignore[override]
        # Identity conversion: no provider-specific prompt rendering needed.
        return cast(Any, messages), {}

    async def tokenize(  # type: ignore[override]
        self,
        messages,
        tools,
        model,
        extra_kwargs: dict | None = None,  # None instead of a mutable {} default
        **kwargs,
    ) -> list[int]:
        extra_kwargs = {} if extra_kwargs is None else extra_kwargs
        if isinstance(messages, str):
            assert messages == "World!"
            return [777]

        if messages == [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "World!"},
        ]:
            assert extra_kwargs == {"add_generation_prompt": False}
            return [1, 777, 999]

        return self._full_prompt_ids


class _NoTokenizeClient(OpenAIChatCompletionsTokenClient):
    """Token client stub that fails the test if ``tokenize`` is ever invoked."""

    def __init__(self) -> None:
        super().__init__(_NoopClient())

    async def to_native_prompt(self, messages):  # type: ignore[override]
        # Identity conversion: no provider-specific prompt rendering needed.
        return cast(Any, messages), {}

    async def tokenize(  # type: ignore[override]
        self,
        messages,
        tools,
        model,
        extra_kwargs: dict | None = None,  # None instead of a mutable {} default
        **kwargs,
    ) -> list[int]:
        raise AssertionError("tokenize should not be called without a prefix match")


def _make_step(
prompt: list[dict[str, str]],
completion: list[dict[str, str]],
prompt_ids: list[int],
completion_ids: list[int],
) -> dict[str, Any]:
return {
"prompt": prompt,
"completion": completion,
"tokens": {
"prompt_ids": prompt_ids,
"completion_ids": completion_ids,
},
}


@pytest.mark.asyncio
async def test_get_prompt_ids_uses_largest_message_prefix_match():
    """The longest trajectory step prefixing the prompt supplies the token ids."""
    expected_ids = [1, 2, 3, 4, 999, 5]
    client = _PromptIdTestClient(full_prompt_ids=expected_ids)

    first_step = _make_step(
        prompt=[{"role": "user", "content": "u1"}],
        completion=[{"role": "assistant", "content": "a1"}],
        prompt_ids=[1],
        completion_ids=[2],
    )
    second_step = _make_step(
        prompt=[
            {"role": "user", "content": "u1"},
            {"role": "assistant", "content": "a1"},
            {"role": "user", "content": "u2"},
        ],
        completion=[{"role": "assistant", "content": "a2"}],
        prompt_ids=[1, 2, 3],
        completion_ids=[4],
    )
    state = cast(
        State,
        {"model": "test-model", "trajectory": [first_step, second_step]},
    )

    # Full conversation: both steps are prefixes; the second (longer) must win.
    prompt_messages = cast(
        Any,
        [
            {"role": "user", "content": "u1"},
            {"role": "assistant", "content": "a1"},
            {"role": "user", "content": "u2"},
            {"role": "assistant", "content": "a2"},
            {"role": "user", "content": "u3"},
        ],
    )

    prompt_ids = await client.get_prompt_ids(state, prompt_messages, oai_tools=None)

    assert prompt_ids == expected_ids


@pytest.mark.asyncio
async def test_get_prompt_ids_returns_none_when_no_prefix_match():
    """With no message-level prefix match, get_prompt_ids yields None."""
    client = _NoTokenizeClient()

    unrelated_step = _make_step(
        prompt=[{"role": "user", "content": "old"}],
        completion=[{"role": "assistant", "content": "reply"}],
        prompt_ids=[1],
        completion_ids=[2],
    )
    state = cast(
        State,
        {"model": "test-model", "trajectory": [unrelated_step]},
    )
    fresh_prompt = cast(Any, [{"role": "user", "content": "new"}])

    prompt_ids = await client.get_prompt_ids(state, fresh_prompt, oai_tools=None)

    assert prompt_ids is None


@pytest.mark.asyncio
async def test_get_native_response_falls_back_to_super_when_no_prefix_match(
    monkeypatch: pytest.MonkeyPatch,
):
    """When no prompt ids are available, delegate to the parent chat route."""
    client = OpenAIChatCompletionsTokenClient(_NoopClient())
    sentinel = {"source": "super"}
    recorded_calls: list[dict[str, Any]] = []

    async def stub_get_prompt_ids(self, state, prompt_messages, oai_tools):  # noqa: ANN001
        # Simulate "no trajectory step prefixes this prompt".
        return None

    async def stub_super_get_native_response(  # noqa: ANN001
        self,
        prompt,
        model,
        sampling_args,
        tools=None,
        **kwargs,
    ):
        recorded_calls.append(
            {
                "prompt": prompt,
                "model": model,
                "sampling_args": sampling_args,
                "tools": tools,
            }
        )
        return sentinel

    monkeypatch.setattr(
        OpenAIChatCompletionsTokenClient, "get_prompt_ids", stub_get_prompt_ids
    )
    monkeypatch.setattr(
        OpenAIChatCompletionsClient,
        "get_native_response",
        stub_super_get_native_response,
    )

    step = _make_step(
        prompt=[{"role": "user", "content": "u1"}],
        completion=[{"role": "assistant", "content": "a1"}],
        prompt_ids=[1],
        completion_ids=[2],
    )
    state = cast(State, {"model": "test-model", "trajectory": [step]})
    prompt = cast(Any, [{"role": "user", "content": "u2"}])

    response = await client.get_native_response(
        prompt=prompt,
        model="test-model",
        sampling_args={},
        tools=None,
        state=state,
    )

    # The parent implementation must be invoked exactly once with our prompt.
    assert response is sentinel
    assert len(recorded_calls) == 1
    assert recorded_calls[0]["prompt"] == prompt


@pytest.mark.asyncio
async def test_get_native_response_uses_token_route_when_prompt_ids_available(
    monkeypatch: pytest.MonkeyPatch,
):
    """When prompt ids resolve, the request goes to the token completions route."""
    recording_client = _RecordingClient()
    client = OpenAIChatCompletionsTokenClient(recording_client)
    token_ids = [10, 20]

    async def stub_get_prompt_ids(self, state, prompt_messages, oai_tools):  # noqa: ANN001
        # Simulate a successful prefix match producing these ids.
        return token_ids

    monkeypatch.setattr(
        OpenAIChatCompletionsTokenClient, "get_prompt_ids", stub_get_prompt_ids
    )

    step = _make_step(
        prompt=[{"role": "user", "content": "u1"}],
        completion=[{"role": "assistant", "content": "a1"}],
        prompt_ids=[1],
        completion_ids=[2],
    )
    state = cast(State, {"model": "test-model", "trajectory": [step]})
    prompt = cast(Any, [{"role": "user", "content": "u2"}])

    response = await client.get_native_response(
        prompt=prompt,
        model="test-model",
        sampling_args={},
        tools=None,
        state=state,
    )

    # Exactly one raw POST, aimed at the token endpoint with our ids attached.
    assert response["ok"] is True
    assert len(recording_client.calls) == 1
    sole_call = recording_client.calls[0]
    assert sole_call["path"] == "/chat/completions/tokens"
    assert sole_call["body"]["tokens"] == token_ids
62 changes: 55 additions & 7 deletions verifiers/clients/openai_chat_completions_token_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Optional, cast
from collections.abc import Mapping
from typing import Any, Optional, cast

from openai import AsyncOpenAI, BaseModel
from openai.types.chat import ChatCompletion
Expand Down Expand Up @@ -64,6 +65,10 @@ def normalize_sampling_args(sampling_args: SamplingArgs):
prompt, model, sampling_args, tools
)
prompt_ids = await self.get_prompt_ids(state, prompt, tools)
if prompt_ids is None:
return await super().get_native_response(
prompt, model, sampling_args, tools
)
extra_body = sampling_args.pop("extra_body", {})
body = dict(
model=model,
Expand All @@ -85,18 +90,61 @@ async def get_prompt_ids(
state: State,
prompt_messages: OpenAIChatMessages,
oai_tools: list[OpenAITool] | None,
) -> list[int]:
) -> list[int] | None:
"""
Build prompt_ids (token prompt) corresponding to prompt_messages. We assume
that this method is called *before* making the model response from
prompt_messages, i.e. the previous turn's prompt and completion do not yet
include the environment response and next turn's model response.

Returns None when no trajectory step has a message-level prefix match with
prompt_messages.
"""
prev_turn_tokens = state["trajectory"][-1]["tokens"]
assert prev_turn_tokens is not None
prev_turn_prompt_ids = prev_turn_tokens["prompt_ids"]
prev_turn_completion_ids = prev_turn_tokens["completion_ids"]
prev_turn_ids = prev_turn_prompt_ids + prev_turn_completion_ids

def normalize_for_comparison(value: Any) -> Any:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we make this a general message_util? seems useful in other places too? also, vaguely remember we have a similar util to this alr but might be wrong

if hasattr(value, "model_dump"):
return normalize_for_comparison(value.model_dump())
if isinstance(value, Mapping):
return {
str(key): normalize_for_comparison(val)
for key, val in value.items()
}
if isinstance(value, list):
return [normalize_for_comparison(item) for item in value]
return value

async def find_largest_prefix_match_tokens() -> list[int] | None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe a small docstring here as well

normalized_prompt_messages = normalize_for_comparison(prompt_messages)
best_prefix_len = -1
best_step_tokens = None
for step in reversed(state["trajectory"]):
step_tokens = step["tokens"]
if step_tokens is None:
continue
step_messages = cast(Any, [*step["prompt"], *step["completion"]])
step_prompt_messages, _ = await self.to_native_prompt(step_messages)
normalized_step_messages = normalize_for_comparison(
step_prompt_messages
)
prefix_len = len(normalized_step_messages)
if prefix_len <= best_prefix_len:
continue
if prefix_len > len(normalized_prompt_messages):
continue
if normalized_prompt_messages[:prefix_len] != normalized_step_messages:
continue
best_prefix_len = prefix_len
best_step_tokens = step_tokens
if best_prefix_len == len(normalized_prompt_messages):
break

if best_step_tokens is None:
return None
return best_step_tokens["prompt_ids"] + best_step_tokens["completion_ids"]

prev_turn_ids = await find_largest_prefix_match_tokens()
if prev_turn_ids is None:
return None

def compute_suffix_ids(lst: list[int], value: int) -> list[int]:
"""Returns all tokens after the last occurrence of `value` in `lst`, if any."""
Expand Down
Loading