Skip to content

Commit 3d1ffdb

Browse files
feat: unified tools= parameter for tool calling (#191)
Replace individual boolean parameters (web_search, x_search, code_execution) with a single tools= list parameter that accepts Tool instances, user-defined dicts, and raw passthrough dicts. - Add Tool, WebSearch, XSearch, CodeExecution classes with provider-specific ToolMappers that translate to wire format - Add ToolCall on Output and ToolResult(Message) for multi-turn tool use - Add ToolSupport constraint for model-level tool validation - Add _parse_tool_calls to all providers (Anthropic, OpenAI, xAI, Google, OpenResponses) for non-streaming tool call extraction - Add _aggregate_tool_calls to streaming for all providers - Update templates for new tools parameter pattern
1 parent 7d10721 commit 3d1ffdb

File tree

30 files changed

+720
-157
lines changed

30 files changed

+720
-157
lines changed

src/celeste/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
StrictJsonSchemaGenerator,
5252
StrictRefResolvingJsonSchemaGenerator,
5353
)
54+
from celeste.tools import CodeExecution, Tool, ToolCall, ToolResult, WebSearch, XSearch
5455
from celeste.types import Content, JsonValue, Message, Role
5556
from celeste.websocket import WebSocketClient, WebSocketConnection, close_all_ws_clients
5657

@@ -245,6 +246,7 @@ def create_client(
245246
"Authentication",
246247
"Capability",
247248
"ClientNotFoundError",
249+
"CodeExecution",
248250
"ConstraintViolationError",
249251
"Content",
250252
"Error",
@@ -270,14 +272,19 @@ def create_client(
270272
"StreamingNotSupportedError",
271273
"StrictJsonSchemaGenerator",
272274
"StrictRefResolvingJsonSchemaGenerator",
275+
"Tool",
276+
"ToolCall",
277+
"ToolResult",
273278
"UnsupportedCapabilityError",
274279
"UnsupportedParameterError",
275280
"UnsupportedProviderError",
276281
"Usage",
277282
"UsageField",
278283
"ValidationError",
284+
"WebSearch",
279285
"WebSocketClient",
280286
"WebSocketConnection",
287+
"XSearch",
281288
"audio",
282289
"close_all_http_clients",
283290
"close_all_ws_clients",

src/celeste/client.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from celeste.models import Model
1818
from celeste.parameters import ParameterMapper, Parameters
1919
from celeste.streaming import Stream, enrich_stream_errors
20+
from celeste.tools import ToolCall
2021
from celeste.types import RawUsage
2122

2223

@@ -200,13 +201,19 @@ async def _predict(
200201
)
201202
content = self._parse_content(response_data, **parameters)
202203
content = self._transform_output(content, **parameters)
204+
tool_calls = self._parse_tool_calls(response_data)
203205
return self._output_class()(
204206
content=content,
205207
usage=self._get_usage(response_data),
206208
finish_reason=self._get_finish_reason(response_data),
207209
metadata=self._build_metadata(response_data),
210+
tool_calls=tool_calls,
208211
)
209212

213+
def _parse_tool_calls(self, response_data: dict[str, Any]) -> list[ToolCall]:
214+
"""Parse tool calls from response. Override in providers that support tools."""
215+
return []
216+
210217
def _stream(
211218
self,
212219
inputs: In,

src/celeste/constraints.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55
from abc import ABC, abstractmethod
66
from typing import Any, ClassVar, get_args, get_origin
77

8-
from pydantic import BaseModel, Field, computed_field
8+
from pydantic import BaseModel, Field, computed_field, field_serializer
99

1010
from celeste.artifacts import AudioArtifact, ImageArtifact, VideoArtifact
1111
from celeste.exceptions import ConstraintViolationError
1212
from celeste.mime_types import AudioMimeType, ImageMimeType, MimeType, VideoMimeType
13+
from celeste.tools import Tool
1314

1415

1516
class Constraint(BaseModel, ABC):
@@ -367,6 +368,26 @@ class AudiosConstraint(_MediaListConstraint[AudioMimeType]):
367368
_media_label = "audio"
368369

369370

371+
class ToolSupport(Constraint):
372+
"""Tool support constraint - validates Tool instances are supported by the model."""
373+
374+
tools: list[type[Tool]]
375+
376+
@field_serializer("tools")
377+
@classmethod
378+
def _serialize_tools(cls, v: list[type[Tool]]) -> list[str]:
379+
return [t.__name__ for t in v]
380+
381+
def __call__(self, value: list) -> list:
382+
"""Validate tools list against supported tools."""
383+
for item in value:
384+
if isinstance(item, Tool) and type(item) not in self.tools:
385+
supported = [t.__name__ for t in self.tools]
386+
msg = f"Tool '{type(item).__name__}' not supported. Supported: {supported}"
387+
raise ConstraintViolationError(msg)
388+
return value
389+
390+
370391
__all__ = [
371392
"AudioConstraint",
372393
"AudiosConstraint",
@@ -382,6 +403,7 @@ class AudiosConstraint(_MediaListConstraint[AudioMimeType]):
382403
"Range",
383404
"Schema",
384405
"Str",
406+
"ToolSupport",
385407
"VideoConstraint",
386408
"VideosConstraint",
387409
]

src/celeste/io.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from celeste.artifacts import AudioArtifact, ImageArtifact, VideoArtifact
1010
from celeste.constraints import Constraint
1111
from celeste.core import InputType
12+
from celeste.tools import ToolCall
1213

1314

1415
class Input(BaseModel):
@@ -38,6 +39,7 @@ class Output[Content](BaseModel):
3839
usage: Usage = Field(default_factory=Usage)
3940
finish_reason: FinishReason | None = None
4041
metadata: dict[str, Any] = Field(default_factory=dict)
42+
tool_calls: list[ToolCall] = Field(default_factory=list)
4143

4244

4345
class Chunk[Content](BaseModel):

src/celeste/modalities/text/parameters.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from pydantic import BaseModel
1010

1111
from celeste.parameters import Parameters
12+
from celeste.tools import ToolDefinition
1213

1314

1415
class TextParameter(StrEnum):
@@ -23,10 +24,8 @@ class TextParameter(StrEnum):
2324
THINKING_BUDGET = "thinking_budget"
2425
THINKING_LEVEL = "thinking_level"
2526
OUTPUT_SCHEMA = "output_schema"
26-
WEB_SEARCH = "web_search"
27+
TOOLS = "tools"
2728
VERBOSITY = "verbosity"
28-
X_SEARCH = "x_search"
29-
CODE_EXECUTION = "code_execution"
3029

3130
# Media input declarations (for optional_input_types)
3231
IMAGE = "image"
@@ -46,10 +45,8 @@ class TextParameters(Parameters):
4645
thinking_budget: int | str
4746
thinking_level: str
4847
output_schema: type[BaseModel]
49-
web_search: bool
48+
tools: list[ToolDefinition]
5049
verbosity: str
51-
x_search: bool
52-
code_execution: bool
5350

5451

5552
__all__ = [

src/celeste/modalities/text/providers/anthropic/client.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Anthropic text client (modality)."""
22

33
import base64
4+
import contextlib
45
from typing import Any, Unpack
56

67
from celeste.artifacts import ImageArtifact
@@ -10,6 +11,7 @@
1011
from celeste.providers.anthropic.messages.streaming import (
1112
AnthropicMessagesStream as _AnthropicMessagesStream,
1213
)
14+
from celeste.tools import ToolCall, ToolResult
1315
from celeste.types import ImageContent, Message, TextContent, VideoContent
1416
from celeste.utils import detect_mime_type
1517

@@ -30,15 +32,31 @@ class AnthropicTextStream(_AnthropicMessagesStream, TextStream):
3032
def __init__(self, *args: Any, **kwargs: Any) -> None:
3133
super().__init__(*args, **kwargs)
3234
self._message_start: dict[str, Any] | None = None
35+
self._tool_calls: dict[int, dict[str, Any]] = {}
3336

3437
def _parse_chunk(self, event_data: dict[str, Any]) -> TextChunk | None:
35-
"""Parse one SSE event into a typed chunk (captures message_start)."""
38+
"""Parse one SSE event into a typed chunk (captures message_start and tool_use)."""
3639
event_type = event_data.get("type")
3740
if event_type == "message_start":
3841
message = event_data.get("message")
3942
if isinstance(message, dict):
4043
self._message_start = message
4144
return None
45+
if event_type == "content_block_start":
46+
block = event_data.get("content_block", {})
47+
if block.get("type") == "tool_use":
48+
idx = event_data.get("index", len(self._tool_calls))
49+
self._tool_calls[idx] = {
50+
"id": block.get("id", ""),
51+
"name": block.get("name", ""),
52+
"input_json": "",
53+
}
54+
elif event_type == "content_block_delta":
55+
delta = event_data.get("delta", {})
56+
if delta.get("type") == "input_json_delta":
57+
idx = event_data.get("index", -1)
58+
if idx in self._tool_calls:
59+
self._tool_calls[idx]["input_json"] += delta.get("partial_json", "")
4260
return super()._parse_chunk(event_data)
4361

4462
def _aggregate_event_data(self, chunks: list[TextChunk]) -> list[dict[str, Any]]:
@@ -49,6 +67,21 @@ def _aggregate_event_data(self, chunks: list[TextChunk]) -> list[dict[str, Any]]
4967
events.extend(super()._aggregate_event_data(chunks))
5068
return events
5169

70+
def _aggregate_tool_calls(
71+
self, chunks: list[TextChunk], raw_events: list[dict[str, Any]]
72+
) -> list[ToolCall]:
73+
"""Reconstruct tool calls from accumulated content_block events."""
74+
import json as _json
75+
76+
result: list[ToolCall] = []
77+
for tc in self._tool_calls.values():
78+
arguments = {}
79+
if tc["input_json"]:
80+
with contextlib.suppress(ValueError, TypeError):
81+
arguments = _json.loads(tc["input_json"])
82+
result.append(ToolCall(id=tc["id"], name=tc["name"], arguments=arguments))
83+
return result
84+
5285

5386
class AnthropicTextClient(AnthropicMessagesClient, TextClient):
5487
"""Anthropic text client."""
@@ -86,10 +119,12 @@ def _init_request(self, inputs: TextInput) -> dict[str, Any]:
86119
if inputs.messages is not None:
87120
system_blocks: list[dict[str, Any]] = []
88121
messages: list[dict[str, Any]] = []
122+
pending_tool_results: list[dict[str, Any]] = []
89123

90124
for message in inputs.messages:
91125
role = message.role
92126
content = message.content
127+
93128
if role in {"system", "developer"}:
94129
if isinstance(content, list):
95130
for block in content:
@@ -105,8 +140,27 @@ def _init_request(self, inputs: TextInput) -> dict[str, Any]:
105140
system_blocks.append({"type": "text", "text": str(content)})
106141
continue
107142

143+
if isinstance(message, ToolResult):
144+
pending_tool_results.append(
145+
{
146+
"type": "tool_result",
147+
"tool_use_id": message.tool_call_id,
148+
"content": str(content),
149+
}
150+
)
151+
continue
152+
153+
# Flush pending tool results as a single user message
154+
if pending_tool_results:
155+
messages.append({"role": "user", "content": pending_tool_results})
156+
pending_tool_results = []
157+
108158
messages.append({"role": role, "content": content})
109159

160+
# Flush remaining tool results
161+
if pending_tool_results:
162+
messages.append({"role": "user", "content": pending_tool_results})
163+
110164
request: dict[str, Any] = {"messages": messages}
111165
if system_blocks:
112166
request["system"] = system_blocks
@@ -166,6 +220,16 @@ def _parse_content(
166220

167221
return text_content
168222

223+
def _parse_tool_calls(self, response_data: dict[str, Any]) -> list[ToolCall]:
224+
"""Parse tool calls from Anthropic response."""
225+
return [
226+
ToolCall(
227+
id=block["id"], name=block["name"], arguments=block.get("input", {})
228+
)
229+
for block in response_data.get("content", [])
230+
if block.get("type") == "tool_use"
231+
]
232+
169233
def _stream_class(self) -> type[TextStream]:
170234
"""Return the Stream class for this provider."""
171235
return AnthropicTextStream

src/celeste/modalities/text/providers/anthropic/models.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
"""Anthropic models for text modality."""
22

3-
from celeste.constraints import Bool, ImagesConstraint, Range, Schema
3+
from celeste.constraints import ImagesConstraint, Range, Schema, ToolSupport
44
from celeste.core import Modality, Operation, Parameter, Provider
55
from celeste.models import Model
6+
from celeste.tools import WebSearch
67

78
from ...parameters import TextParameter
89

@@ -17,7 +18,7 @@
1718
Parameter.MAX_TOKENS: Range(min=1, max=64000),
1819
TextParameter.THINKING_BUDGET: Range(min=-1, max=64000),
1920
TextParameter.OUTPUT_SCHEMA: Schema(),
20-
TextParameter.WEB_SEARCH: Bool(),
21+
TextParameter.TOOLS: ToolSupport(tools=[WebSearch]),
2122
TextParameter.IMAGE: ImagesConstraint(),
2223
},
2324
),
@@ -31,7 +32,7 @@
3132
Parameter.MAX_TOKENS: Range(min=1, max=64000),
3233
TextParameter.THINKING_BUDGET: Range(min=-1, max=32000),
3334
TextParameter.OUTPUT_SCHEMA: Schema(),
34-
TextParameter.WEB_SEARCH: Bool(),
35+
TextParameter.TOOLS: ToolSupport(tools=[WebSearch]),
3536
TextParameter.IMAGE: ImagesConstraint(),
3637
},
3738
),
@@ -45,7 +46,7 @@
4546
Parameter.MAX_TOKENS: Range(min=1, max=32000),
4647
TextParameter.THINKING_BUDGET: Range(min=-1, max=32000),
4748
TextParameter.OUTPUT_SCHEMA: Schema(),
48-
TextParameter.WEB_SEARCH: Bool(),
49+
TextParameter.TOOLS: ToolSupport(tools=[WebSearch]),
4950
TextParameter.IMAGE: ImagesConstraint(),
5051
},
5152
),
@@ -59,7 +60,7 @@
5960
Parameter.MAX_TOKENS: Range(min=1, max=64000),
6061
TextParameter.THINKING_BUDGET: Range(min=-1, max=32000),
6162
TextParameter.OUTPUT_SCHEMA: Schema(),
62-
TextParameter.WEB_SEARCH: Bool(),
63+
TextParameter.TOOLS: ToolSupport(tools=[WebSearch]),
6364
TextParameter.IMAGE: ImagesConstraint(),
6465
},
6566
),
@@ -73,7 +74,7 @@
7374
Parameter.MAX_TOKENS: Range(min=1, max=64000),
7475
TextParameter.THINKING_BUDGET: Range(min=-1, max=32000),
7576
TextParameter.OUTPUT_SCHEMA: Schema(),
76-
TextParameter.WEB_SEARCH: Bool(),
77+
TextParameter.TOOLS: ToolSupport(tools=[WebSearch]),
7778
TextParameter.IMAGE: ImagesConstraint(),
7879
},
7980
),
@@ -87,7 +88,7 @@
8788
Parameter.MAX_TOKENS: Range(min=1, max=64000),
8889
TextParameter.THINKING_BUDGET: Range(min=-1, max=64000),
8990
TextParameter.OUTPUT_SCHEMA: Schema(),
90-
TextParameter.WEB_SEARCH: Bool(),
91+
TextParameter.TOOLS: ToolSupport(tools=[WebSearch]),
9192
TextParameter.IMAGE: ImagesConstraint(),
9293
},
9394
),
@@ -100,7 +101,7 @@
100101
parameter_constraints={
101102
Parameter.MAX_TOKENS: Range(min=1, max=64000),
102103
TextParameter.THINKING_BUDGET: Range(min=-1, max=64000),
103-
TextParameter.WEB_SEARCH: Bool(),
104+
TextParameter.TOOLS: ToolSupport(tools=[WebSearch]),
104105
TextParameter.IMAGE: ImagesConstraint(),
105106
},
106107
),
@@ -113,7 +114,7 @@
113114
parameter_constraints={
114115
Parameter.MAX_TOKENS: Range(min=1, max=32000),
115116
TextParameter.THINKING_BUDGET: Range(min=-1, max=32000),
116-
TextParameter.WEB_SEARCH: Bool(),
117+
TextParameter.TOOLS: ToolSupport(tools=[WebSearch]),
117118
TextParameter.IMAGE: ImagesConstraint(),
118119
},
119120
),

0 commit comments

Comments
 (0)