Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 178 additions & 0 deletions tests/test_tool_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,60 @@

import verifiers as vf
from tests.conftest import faulty_tool, offset_tool, square_tool
from verifiers.utils.tool_utils import is_valid_tool_content_parts


class TestIsValidToolContentParts:
    """Unit tests for the is_valid_tool_content_parts predicate."""

    def test_valid_text_content_part(self):
        """A list holding a single text part is accepted."""
        parts = [{"type": "text", "text": "Hello world"}]
        assert is_valid_tool_content_parts(parts) is True

    def test_valid_image_url_content_part(self):
        """A list holding a single image_url part is accepted."""
        parts = [
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc123"}}
        ]
        assert is_valid_tool_content_parts(parts) is True

    def test_valid_mixed_content_parts(self):
        """Text and image_url parts may be mixed in one list."""
        parts = [
            {"type": "text", "text": "Here's the screenshot"},
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc123"}},
        ]
        assert is_valid_tool_content_parts(parts) is True

    def test_empty_list_is_valid(self):
        """An empty list contains no invalid parts, so it is accepted."""
        assert is_valid_tool_content_parts([]) is True

    def test_invalid_type_value(self):
        """A part whose type value is unrecognized is rejected."""
        parts = [{"type": "invalid_type", "data": "some data"}]
        assert is_valid_tool_content_parts(parts) is False

    def test_missing_type_key(self):
        """A part lacking the type key entirely is rejected."""
        parts = [{"text": "Hello world"}]
        assert is_valid_tool_content_parts(parts) is False

    def test_non_dict_item_in_list(self):
        """One non-dict element poisons the whole list."""
        parts = ["just a string", {"type": "text", "text": "hello"}]
        assert is_valid_tool_content_parts(parts) is False

    def test_non_list_input(self):
        """Anything that is not a list is rejected outright."""
        for value in ("just a string", {"type": "text", "text": "hi"}, 42, None):
            assert is_valid_tool_content_parts(value) is False

    def test_list_of_primitives(self):
        """Lists of bare primitives are not valid content parts."""
        assert is_valid_tool_content_parts([1, 2, 3]) is False
        assert is_valid_tool_content_parts(["a", "b", "c"]) is False


def _build_tool_call(name: str, arguments: dict, tool_call_id: str = "call_0"):
Expand Down Expand Up @@ -250,3 +304,127 @@ def test_add_tool_updates_tool_monitor_rubric(

assert "offset_tool" in env.tool_monitor_rubric.tool_names
assert len(env.tool_monitor_rubric.tool_names) == 2

@pytest.mark.asyncio
async def test_call_tool_returns_valid_text_content_parts(
self, mock_openai_client, sample_chat_dataset
):
"""Test that call_tool preserves valid text content parts."""

def text_parts_tool() -> list:
return [{"type": "text", "text": "Hello world"}]

env = vf.ToolEnv(
tools=[text_parts_tool],
client=mock_openai_client,
model="test-model",
dataset=sample_chat_dataset,
)

result = await env.call_tool("text_parts_tool", {}, "call_0")
assert result["content"] == [{"type": "text", "text": "Hello world"}]

@pytest.mark.asyncio
async def test_call_tool_returns_valid_image_url_content_parts(
self, mock_openai_client, sample_chat_dataset
):
"""Test that call_tool preserves valid image_url content parts."""

def image_tool() -> list:
return [
{"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}}
]

env = vf.ToolEnv(
tools=[image_tool],
client=mock_openai_client,
model="test-model",
dataset=sample_chat_dataset,
)

result = await env.call_tool("image_tool", {}, "call_0")
assert result["content"] == [
{"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}}
]

@pytest.mark.asyncio
async def test_call_tool_returns_mixed_content_parts(
self, mock_openai_client, sample_chat_dataset
):
"""Test that call_tool preserves mixed valid content parts."""

def mixed_tool() -> list:
return [
{"type": "text", "text": "Here's the screenshot"},
{"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
]

env = vf.ToolEnv(
tools=[mixed_tool],
client=mock_openai_client,
model="test-model",
dataset=sample_chat_dataset,
)

result = await env.call_tool("mixed_tool", {}, "call_0")
assert result["content"] == [
{"type": "text", "text": "Here's the screenshot"},
{"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
]

@pytest.mark.asyncio
async def test_call_tool_casts_invalid_list_to_str(
self, mock_openai_client, sample_chat_dataset
):
"""Test that call_tool casts invalid lists (not content parts) to str."""

def list_tool() -> list:
return [1, 2, 3]

env = vf.ToolEnv(
tools=[list_tool],
client=mock_openai_client,
model="test-model",
dataset=sample_chat_dataset,
)

result = await env.call_tool("list_tool", {}, "call_0")
assert result["content"] == "[1, 2, 3]"

@pytest.mark.asyncio
async def test_call_tool_casts_list_missing_type_to_str(
self, mock_openai_client, sample_chat_dataset
):
"""Test that call_tool casts list with missing type keys to str."""

def bad_list_tool() -> list:
return [{"text": "no type key"}]

env = vf.ToolEnv(
tools=[bad_list_tool],
client=mock_openai_client,
model="test-model",
dataset=sample_chat_dataset,
)

result = await env.call_tool("bad_list_tool", {}, "call_0")
assert result["content"] == "[{'text': 'no type key'}]"

@pytest.mark.asyncio
async def test_call_tool_casts_list_with_invalid_type_to_str(
self, mock_openai_client, sample_chat_dataset
):
"""Test that call_tool casts list with invalid type values to str."""

def invalid_type_tool() -> list:
return [{"type": "audio", "data": "base64data"}]

env = vf.ToolEnv(
tools=[invalid_type_tool],
client=mock_openai_client,
model="test-model",
dataset=sample_chat_dataset,
)

result = await env.call_tool("invalid_type_tool", {}, "call_0")
assert result["content"] == "[{'type': 'audio', 'data': 'base64data'}]"
4 changes: 2 additions & 2 deletions verifiers/envs/tool_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import verifiers as vf
from verifiers.types import Messages
from verifiers.utils.async_utils import maybe_await
from verifiers.utils.tool_utils import convert_func_to_oai_tool
from verifiers.utils.tool_utils import convert_func_to_oai_tool, is_valid_tool_content_parts


class ToolMonitorRubric(vf.Rubric):
Expand Down Expand Up @@ -133,7 +133,7 @@ async def call_tool(
"""Call a tool based on JSON command."""
tool_func = self.tool_map[tool_name]
result = await maybe_await(tool_func, **tool_args)
content = result if isinstance(result, list) else str(result)
content = result if is_valid_tool_content_parts(result) else str(result)
return cast(
vf.Message,
{"role": "tool", "content": content, "tool_call_id": tool_call_id},
Expand Down
17 changes: 17 additions & 0 deletions verifiers/utils/tool_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,23 @@
from agents.function_schema import function_schema
from openai.types.chat import ChatCompletionFunctionToolParam

# Content-part "type" values accepted in tool message content.
VALID_TOOL_CONTENT_PART_TYPES = frozenset({"text", "image_url"})


def is_valid_tool_content_parts(value: Any) -> bool:
    """Check whether *value* is a well-formed list of tool content parts.

    A well-formed content part is a dict whose "type" key maps to one of
    VALID_TOOL_CONTENT_PART_TYPES ("text" or "image_url"). An empty list
    is considered valid, since it contains no invalid parts. Any non-list
    value, non-dict element, or unrecognized/missing "type" makes the
    whole value invalid.
    """
    if not isinstance(value, list):
        return False
    # all() short-circuits on the first malformed part; the isinstance
    # guard runs before .get() so non-dict elements never raise.
    return all(
        isinstance(part, dict)
        and part.get("type") in VALID_TOOL_CONTENT_PART_TYPES
        for part in value
    )


def convert_func_to_oai_tool(func: Any) -> ChatCompletionFunctionToolParam:
"""Convert *func* to an OpenAI function-calling tool schema.
Expand Down
Loading