Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 115 additions & 0 deletions tests/test_client_multimodal_types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import base64

import pytest
import numpy as np
from types import SimpleNamespace

from verifiers.clients.openai_chat_completions_client import OpenAIChatCompletionsClient
Expand Down Expand Up @@ -237,3 +240,115 @@ async def test_anthropic_tool_call_round_trips_thinking_blocks():
{"type": "thinking", "thinking": "hidden chain", "signature": "sig_1"},
{"type": "tool_use", "id": "call_1", "name": "lookup", "input": {"q": "x"}},
]


class _CaptureAnthropicMessages:
    """Stub for the Anthropic ``messages`` resource that records request kwargs.

    ``create`` stores the keyword arguments of the most recent call in
    ``last_kwargs`` and returns an empty ``SimpleNamespace`` placeholder so
    callers that only inspect the outgoing request do not crash.
    """

    # Keyword arguments of the most recent create() call, or None before any call.
    last_kwargs: dict | None

    def __init__(self) -> None:
        self.last_kwargs = None

    async def create(self, **kwargs):
        # Remember exactly what the client under test sent to the SDK.
        self.last_kwargs = kwargs
        return SimpleNamespace()


class _CaptureAnthropicClient:
    """Minimal stand-in for an AsyncAnthropic client exposing only ``messages``."""

    def __init__(self) -> None:
        # The capturing messages resource that tests inspect after each call.
        self.messages = _CaptureAnthropicMessages()


@pytest.mark.asyncio
async def test_anthropic_get_native_response_forwards_router_replay_with_extra_body():
    """Unknown sampling keys must be folded into extra_body, not sent top-level."""
    pytest.importorskip("anthropic")
    from verifiers.clients.anthropic_messages_client import AnthropicMessagesClient

    fake_native = _CaptureAnthropicClient()
    client = AnthropicMessagesClient(fake_native)

    sampling = {
        "max_tokens": 32,
        "temperature": 0.2,
        "extra_body": {"seed": 7},
        "routed_experts": [[[1]]],
    }
    await client.get_native_response(
        prompt=[{"role": "user", "content": "hello"}],
        model="claude-test",
        sampling_args=sampling,
    )

    captured = fake_native.messages.last_kwargs
    assert captured is not None
    assert captured["temperature"] == 0.2
    # routed_experts is merged into the caller-supplied extra_body instead of
    # being forwarded as a top-level kwarg the Anthropic SDK would reject.
    assert captured["extra_body"] == {"seed": 7, "routed_experts": [[[1]]]}
    assert "routed_experts" not in captured


@pytest.mark.asyncio
async def test_anthropic_get_native_response_defaults_max_tokens_when_missing():
    """Omitting max_tokens must fall back to the client's module-level default."""
    pytest.importorskip("anthropic")
    from verifiers.clients.anthropic_messages_client import AnthropicMessagesClient

    fake_native = _CaptureAnthropicClient()
    client = AnthropicMessagesClient(fake_native)

    await client.get_native_response(
        prompt=[{"role": "user", "content": "hello"}],
        model="claude-test",
        sampling_args={"temperature": 0.2},
    )

    captured = fake_native.messages.last_kwargs
    assert captured is not None
    # Anthropic requires max_tokens on every request; the client injects
    # ANTHROPIC_MAX_TOKENS (32768) when the caller does not supply one.
    assert captured["max_tokens"] == 32768
    assert captured["temperature"] == 0.2


@pytest.mark.asyncio
async def test_anthropic_from_native_response_extracts_tokens_and_router_replay():
    """Token ids, logprobs, and base85 router replay round-trip into tokens."""
    pytest.importorskip("anthropic")
    from verifiers.clients.anthropic_messages_client import AnthropicMessagesClient

    client = AnthropicMessagesClient(object())

    # Encode the router-replay array the way the backend ships it:
    # base85 over raw int32 bytes plus the original shape.
    expert_ids = np.array([[[11, 12]], [[21, 22]]], dtype=np.int32)
    encoded_experts = {
        "data": base64.b85encode(expert_ids.tobytes()).decode("utf-8"),
        "shape": list(expert_ids.shape),
    }
    fake_response = SimpleNamespace(
        id="msg_tokens",
        model="claude-haiku-4-5",
        stop_reason="end_turn",
        usage=SimpleNamespace(input_tokens=3, output_tokens=2),
        content=[SimpleNamespace(type="text", text="ok")],
        prompt_token_ids=[1, 2, 3],
        token_ids=[4, 5],
        logprobs={"content": [{"logprob": -0.1}, {"logprob": -0.2}]},
        routed_experts=encoded_experts,
    )

    parsed = await client.from_native_response(fake_response)

    tokens = parsed.message.tokens
    assert tokens is not None
    assert tokens.prompt_ids == [1, 2, 3]
    assert tokens.completion_ids == [4, 5]
    assert tokens.completion_logprobs == [-0.1, -0.2]
    assert tokens.routed_experts == expert_ids.tolist()


@pytest.mark.asyncio
async def test_anthropic_from_native_response_requires_logprobs_for_tokens():
    """Token ids without logprobs must yield tokens=None, not partial data."""
    pytest.importorskip("anthropic")
    from verifiers.clients.anthropic_messages_client import AnthropicMessagesClient

    client = AnthropicMessagesClient(object())
    fake_response = SimpleNamespace(
        id="msg_tokens_missing",
        model="claude-haiku-4-5",
        stop_reason="end_turn",
        usage=SimpleNamespace(input_tokens=2, output_tokens=1),
        content=[SimpleNamespace(type="text", text="ok")],
        prompt_token_ids=[1, 2],
        token_ids=[3],
        logprobs=None,
    )

    parsed = await client.from_native_response(fake_response)

    # Missing logprobs invalidates the whole token payload.
    assert parsed.message.tokens is None
125 changes: 119 additions & 6 deletions verifiers/clients/anthropic_messages_client.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import base64
import functools
import json
import time
from collections.abc import Mapping
from collections.abc import Iterable, Mapping
from typing import Any, cast

import numpy as np

from anthropic import (
AsyncAnthropic,
AuthenticationError,
Expand Down Expand Up @@ -38,6 +41,7 @@
Messages,
Response,
ResponseMessage,
ResponseTokens,
SamplingArgs,
SystemMessage,
TextMessage,
Expand All @@ -49,6 +53,10 @@
)
from verifiers.utils.client_utils import setup_anthropic_client

# Default output-token limit used when callers omit max_tokens.
# Anthropic /v1/messages requires max_tokens on every request.
ANTHROPIC_MAX_TOKENS: int = 32768


def _handle_anthropic_overlong_prompt(func):
"""Decorator to handle overlong prompt errors from the Anthropic API."""
Expand Down Expand Up @@ -87,6 +95,12 @@ class AnthropicMessagesClient(
"""Wrapper for Messages API via AsyncAnthropic client."""

    def setup_client(self, config: ClientConfig) -> AsyncAnthropic:
        """Create the AsyncAnthropic SDK client for this wrapper.

        Also logs a reminder that the Anthropic /v1/messages endpoint requires
        ``max_tokens`` on every request and that ``ANTHROPIC_MAX_TOKENS`` is
        used as the default when the caller omits it.
        """
        # Log the default and remind that max_tokens is required for Anthropic.
        self.logger.info(
            "Anthropic client initialized. max_tokens is required on every request; "
            "defaulting to ANTHROPIC_MAX_TOKENS=%d when not provided.",
            ANTHROPIC_MAX_TOKENS,
        )
        return setup_anthropic_client(config)

async def close(self) -> None:
Expand Down Expand Up @@ -345,13 +359,37 @@ def normalize_sampling_args(sampling_args: SamplingArgs) -> dict:
max_tokens = sampling_args.pop("max_tokens", None)
sampling_args.pop("n", None)
sampling_args.pop("stop", None)
if max_tokens is None:
self.logger.warning(
"max_tokens is not set but Anthropic /v1/messages endpoint requires it, falling back to max_tokens=4096"
extra_body = sampling_args.pop("extra_body", {})
if not isinstance(extra_body, Mapping):
raise TypeError(
"sampling_args['extra_body'] must be a mapping when provided"
)
max_tokens = 4096
if max_tokens is None:
# Anthropic /v1/messages requires max_tokens to be set in every request.
max_tokens = ANTHROPIC_MAX_TOKENS
sampling_args["max_tokens"] = max_tokens

# Anthropic SDK validates top-level request fields.
# Forward unknown model args through extra_body
# so backend-specific payloads (e.g. routed_experts) can be passed via
# sampling_args without custom provider branching.
known_anthropic_args = {
"max_tokens",
"metadata",
"service_tier",
"stop_sequences",
"temperature",
"thinking",
"top_k",
"top_p",
}
extra_body_dict: dict[str, Any] = dict(extra_body)
for key in list(sampling_args.keys()):
if key not in known_anthropic_args:
extra_body_dict[key] = sampling_args.pop(key)
if extra_body_dict:
sampling_args["extra_body"] = extra_body_dict

return {k: v for k, v in sampling_args.items() if v is not None}

# Remove internal framework keys not recognized by the Anthropic SDK
Expand Down Expand Up @@ -440,6 +478,81 @@ def parse_finish_reason(response: AnthropicMessage) -> FinishReason:
case _:
return None

def parse_completion_logprobs(logprobs: Any) -> list[float] | None:
if isinstance(logprobs, Mapping):
content = logprobs.get("content")
else:
content = getattr(logprobs, "content", None)
if content is None:
return None
if isinstance(content, Mapping):
content_items: Iterable[Any] = [content]
elif isinstance(content, list):
content_items = content
elif isinstance(content, Iterable) and not isinstance(
content, (str, bytes)
):
content_items = list(content)
else:
return None
values: list[float] = []
for token in content_items:
if isinstance(token, Mapping):
value = token.get("logprob")
else:
value = getattr(token, "logprob", None)
if not isinstance(value, (float, int)):
return None
values.append(float(value))
return values

def parse_tokens(response: AnthropicMessage) -> ResponseTokens | None:
    """Build ``ResponseTokens`` from optional token-level fields on a response.

    Returns ``None`` unless the backend supplied well-formed lists of
    ``prompt_token_ids`` and ``token_ids`` plus per-token logprobs.  The
    optional ``routed_experts`` router-replay payload is decoded
    best-effort; a malformed payload is dropped rather than allowed to
    raise, matching the defensive behavior of the other checks here.
    """
    prompt_ids = getattr(response, "prompt_token_ids", None)
    completion_ids = getattr(response, "token_ids", None)
    if not isinstance(prompt_ids, list) or not isinstance(completion_ids, list):
        return None
    if not all(isinstance(token_id, int) for token_id in prompt_ids):
        return None
    if not all(isinstance(token_id, int) for token_id in completion_ids):
        return None

    completion_logprobs = parse_completion_logprobs(
        getattr(response, "logprobs", None)
    )
    if completion_logprobs is None:
        # Token ids without logprobs are treated as an absent payload rather
        # than emitting partial token data.
        return None

    routed_experts = getattr(response, "routed_experts", None)
    if (
        isinstance(routed_experts, dict)
        and "data" in routed_experts
        and "shape" in routed_experts
    ):
        try:
            # Router replay ships as base85-encoded raw int32 bytes plus the
            # original array shape; rebuild the nested list structure.
            decoded = np.frombuffer(
                base64.b85decode(routed_experts["data"]), dtype=np.int32
            ).reshape(routed_experts["shape"])
            routed_experts = cast(list[list[list[int]]], decoded.tolist())
        except (ValueError, TypeError):
            # Corrupt base85 data or a shape mismatch must not discard an
            # otherwise valid response; drop only the router replay.
            routed_experts = None
    else:
        routed_experts = None

    return ResponseTokens(
        prompt_ids=prompt_ids,
        # Prompt tokens are masked out; completion tokens are kept.
        prompt_mask=[0] * len(prompt_ids),
        completion_ids=completion_ids,
        completion_mask=[1] * len(completion_ids),
        completion_logprobs=completion_logprobs,
        routed_experts=routed_experts,
    )

content, reasoning_content, tool_calls, thinking_blocks = parse_content(
response.content
)
Expand All @@ -465,6 +578,6 @@ def parse_finish_reason(response: AnthropicMessage) -> FinishReason:
tool_calls=tool_calls or None,
finish_reason=parse_finish_reason(response),
is_truncated=is_truncated,
tokens=None,
tokens=parse_tokens(response),
),
)
Loading