Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions src/republic/clients/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -1637,6 +1637,19 @@ def _make_tool_context(prepared: PreparedChat, provider_name: str, model_id: str
def _extract_text(response: Any) -> str:
if isinstance(response, str):
return response
output = getattr(response, "output", None)
if output:
parts: list[str] = []
for item in output:
if getattr(item, "type", None) != "message":
continue
content = getattr(item, "content", None) or []
for entry in content:
if getattr(entry, "type", None) == "output_text":
text = getattr(entry, "text", None)
if text:
parts.append(text)
return "".join(parts)
choices = getattr(response, "choices", None)
if not choices:
return ""
Expand All @@ -1647,6 +1660,31 @@ def _extract_text(response: Any) -> str:

@staticmethod
def _extract_tool_calls(response: Any) -> list[dict[str, Any]]:
    """Pick the tool-call parser that matches the response shape.

    A Responses-API result carries a truthy ``output`` attribute; anything
    else is treated as a chat-completion result.
    """
    responses_output = getattr(response, "output", None)
    if not responses_output:
        return ChatClient._extract_completion_tool_calls(response)
    return ChatClient._extract_responses_tool_calls(responses_output)

@staticmethod
def _extract_responses_tool_calls(output: list[Any]) -> list[dict[str, Any]]:
calls: list[dict[str, Any]] = []
for item in output:
if getattr(item, "type", None) != "function_call":
continue
name = getattr(item, "name", None)
arguments = getattr(item, "arguments", None)
if not name:
continue
entry: dict[str, Any] = {"function": {"name": name, "arguments": arguments or ""}}
call_id = getattr(item, "call_id", None) or getattr(item, "id", None)
if call_id:
entry["id"] = call_id
entry["type"] = "function"
calls.append(entry)
return calls

@staticmethod
def _extract_completion_tool_calls(response: Any) -> list[dict[str, Any]]:
choices = getattr(response, "choices", None)
if not choices:
return []
Expand Down
155 changes: 145 additions & 10 deletions src/republic/core/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def __init__(
api_key: str | dict[str, str] | None,
api_base: str | dict[str, str] | None,
client_args: dict[str, Any],
use_responses: bool,
verbose: int,
error_classifier: Callable[[Exception], ErrorKind | None] | None = None,
) -> None:
Expand All @@ -71,6 +72,7 @@ def __init__(
self._api_key = api_key
self._api_base = api_base
self._client_args = client_args
self._use_responses = use_responses
self._verbose = verbose
self._error_classifier = error_classifier
self._client_cache: dict[str, AnyLLM] = {}
Expand Down Expand Up @@ -344,6 +346,133 @@ def _decide_kwargs_for_provider(
return kwargs
return {**kwargs, "max_completion_tokens": max_tokens}

def _decide_responses_kwargs(self, max_tokens: int | None, kwargs: dict[str, Any]) -> dict[str, Any]:
if "max_output_tokens" in kwargs:
return {k: v for k, v in kwargs.items() if k != "extra_headers"}
return {**{k: v for k, v in kwargs.items() if k != "extra_headers"}, "max_output_tokens": max_tokens}

def _should_use_responses(self, client: AnyLLM, *, stream: bool) -> bool:
return not stream and self._use_responses and bool(getattr(client, "SUPPORTS_RESPONSES", False))

def _call_client_sync(
self,
*,
client: AnyLLM,
provider_name: str,
model_id: str,
messages_payload: list[dict[str, Any]],
tools_payload: list[dict[str, Any]] | None,
max_tokens: int | None,
stream: bool,
reasoning_effort: Any | None,
kwargs: dict[str, Any],
) -> Any:
if self._should_use_responses(client, stream=stream):
instructions, input_items = self._split_messages_for_responses(messages_payload)
return client.responses(
model=model_id,
input_data=input_items,
tools=tools_payload,
stream=stream,
instructions=instructions,
**self._decide_responses_kwargs(max_tokens, kwargs),
)
return client.completion(
model=model_id,
messages=messages_payload,
tools=tools_payload,
stream=stream,
reasoning_effort=reasoning_effort,
**self._decide_kwargs_for_provider(provider_name, max_tokens, kwargs),
)

async def _call_client_async(
self,
*,
client: AnyLLM,
provider_name: str,
model_id: str,
messages_payload: list[dict[str, Any]],
tools_payload: list[dict[str, Any]] | None,
max_tokens: int | None,
stream: bool,
reasoning_effort: Any | None,
kwargs: dict[str, Any],
) -> Any:
if self._should_use_responses(client, stream=stream):
instructions, input_items = self._split_messages_for_responses(messages_payload)
return await client.aresponses(
model=model_id,
input_data=input_items,
tools=tools_payload,
stream=stream,
instructions=instructions,
**self._decide_responses_kwargs(max_tokens, kwargs),
)
return await client.acompletion(
model=model_id,
messages=messages_payload,
tools=tools_payload,
stream=stream,
reasoning_effort=reasoning_effort,
**self._decide_kwargs_for_provider(provider_name, max_tokens, kwargs),
)

@staticmethod
def _split_messages_for_responses(
    messages: list[dict[str, Any]],
) -> tuple[str | None, list[dict[str, Any]]]:
    """Separate instruction text from conversational messages.

    ``system``/``developer`` message contents are joined (double newline)
    into a single instructions string — ``None`` when there is no
    non-whitespace text — and every remaining message is converted into
    Responses-API input items.
    """
    system_chunks: list[str] = []
    remaining: list[dict[str, Any]] = []
    for msg in messages:
        if msg.get("role") in {"system", "developer"}:
            text = msg.get("content")
            if text not in (None, ""):
                system_chunks.append(str(text))
        else:
            remaining.append(msg)

    joined = "\n\n".join(chunk for chunk in system_chunks if chunk.strip())
    return (joined or None), LLMCore._convert_messages_to_responses_input(remaining)

@staticmethod
def _convert_messages_to_responses_input(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
input_items: list[dict[str, Any]] = []
for message in messages:
role = message.get("role")
content = message.get("content")
if role in {"user", "assistant"} and content not in (None, ""):
input_items.append({"role": role, "content": content, "type": "message"})

if role == "assistant":
tool_calls = message.get("tool_calls") or []
for index, tool_call in enumerate(tool_calls):
func = tool_call.get("function") or {}
name = func.get("name")
if not name:
continue
call_id = tool_call.get("id") or tool_call.get("call_id") or f"call_{index}"
input_items.append({
"type": "function_call",
"name": name,
"arguments": func.get("arguments", ""),
"call_id": call_id,
})

if role == "tool":
call_id = message.get("tool_call_id") or message.get("call_id")
if not call_id:
continue
input_items.append({
"type": "function_call_output",
"call_id": call_id,
"output": message.get("content", ""),
})
return input_items

def run_chat_sync(
self,
*,
Expand All @@ -365,13 +494,16 @@ def run_chat_sync(
last_provider, last_model = provider_name, model_id
for attempt in range(self.max_attempts()):
try:
response = client.completion(
model=model_id,
messages=messages_payload,
tools=tools_payload,
response = self._call_client_sync(
client=client,
provider_name=provider_name,
model_id=model_id,
messages_payload=messages_payload,
tools_payload=tools_payload,
max_tokens=max_tokens,
stream=stream,
reasoning_effort=reasoning_effort,
**self._decide_kwargs_for_provider(provider_name, max_tokens, kwargs),
kwargs=kwargs,
)
except Exception as exc:
outcome = self._handle_attempt_error(exc, provider_name, model_id, attempt)
Expand Down Expand Up @@ -415,13 +547,16 @@ async def run_chat_async(
last_provider, last_model = provider_name, model_id
for attempt in range(self.max_attempts()):
try:
response = await client.acompletion(
model=model_id,
messages=messages_payload,
tools=tools_payload,
response = await self._call_client_async(
client=client,
provider_name=provider_name,
model_id=model_id,
messages_payload=messages_payload,
tools_payload=tools_payload,
max_tokens=max_tokens,
stream=stream,
reasoning_effort=reasoning_effort,
**self._decide_kwargs_for_provider(provider_name, max_tokens, kwargs),
kwargs=kwargs,
)
except Exception as exc:
outcome = self._handle_attempt_error(exc, provider_name, model_id, attempt)
Expand Down
2 changes: 2 additions & 0 deletions src/republic/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(
api_key: str | dict[str, str] | None = None,
api_base: str | dict[str, str] | None = None,
client_args: dict[str, Any] | None = None,
use_responses: bool = False,
verbose: int = 0,
tape_store: TapeStore | None = None,
context: TapeContext | None = None,
Expand All @@ -62,6 +63,7 @@ def __init__(
api_key=api_key,
api_base=api_base,
client_args=client_args or {},
use_responses=use_responses,
verbose=verbose,
error_classifier=error_classifier,
)
Expand Down
44 changes: 44 additions & 0 deletions tests/fakes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,15 @@ def __init__(self, provider: str, kind: str) -> None:


class FakeAnyLLMClient:
SUPPORTS_RESPONSES = True

def __init__(self, provider: str) -> None:
self.provider = provider
self.calls: list[dict[str, Any]] = []
self.completion_queue: deque[Any] = deque()
self.acompletion_queue: deque[Any] = deque()
self.responses_queue: deque[Any] = deque()
self.aresponses_queue: deque[Any] = deque()
self.embedding_queue: deque[Any] = deque()
self.aembedding_queue: deque[Any] = deque()

Expand All @@ -25,6 +29,12 @@ def queue_completion(self, *items: Any) -> None:
def queue_acompletion(self, *items: Any) -> None:
self.acompletion_queue.extend(items)

def queue_responses(self, *items: Any) -> None:
    """Enqueue canned results for sync ``responses()`` calls."""
    for item in items:
        self.responses_queue.append(item)

def queue_aresponses(self, *items: Any) -> None:
    """Enqueue canned results for async ``aresponses()`` calls."""
    for item in items:
        self.aresponses_queue.append(item)

def queue_embedding(self, *items: Any) -> None:
self.embedding_queue.extend(items)

Expand All @@ -48,6 +58,15 @@ async def acompletion(self, **kwargs: Any) -> Any:
queue = self.acompletion_queue if self.acompletion_queue else self.completion_queue
return self._next(queue, "acompletion")

def responses(self, **kwargs: Any) -> Any:
    """Record the call (tagged ``responses``) and pop the next queued result."""
    record: dict[str, Any] = {"responses": True}
    record.update(kwargs)
    self.calls.append(record)
    return self._next(self.responses_queue, "responses")

async def aresponses(self, **kwargs: Any) -> Any:
    """Async variant; falls back to the sync queue when its own is empty."""
    record: dict[str, Any] = {"responses": True}
    record.update(kwargs)
    self.calls.append(record)
    source = self.aresponses_queue or self.responses_queue
    return self._next(source, "aresponses")

def _embedding(self, **kwargs: Any) -> Any:
self.calls.append({"embedding": True, **dict(kwargs)})
return self._next(self.embedding_queue, "embedding")
Expand Down Expand Up @@ -105,3 +124,28 @@ def make_chunk(
delta = SimpleNamespace(content=text, tool_calls=tool_calls or [])
choice = SimpleNamespace(delta=delta)
return SimpleNamespace(choices=[choice], usage=usage)


def make_responses_output_text(text: str) -> Any:
    """Build a fake ``output_text`` content entry."""
    entry = SimpleNamespace()
    entry.type = "output_text"
    entry.text = text
    return entry


def make_responses_message(text: str) -> Any:
    """Build a fake ``message`` output item wrapping one text entry."""
    return SimpleNamespace(
        type="message",
        content=[make_responses_output_text(text)],
    )


def make_responses_function_call(name: str, arguments: str, call_id: str = "call_1") -> Any:
    """Build a fake ``function_call`` output item."""
    call = SimpleNamespace()
    call.type = "function_call"
    call.name = name
    call.arguments = arguments
    call.call_id = call_id
    return call


def make_responses_response(
    *,
    text: str = "",
    tool_calls: list[Any] | None = None,
) -> Any:
    """Assemble a fake Responses API response from optional text and tool calls."""
    items: list[Any] = []
    if text:
        items.append(make_responses_message(text))
    items.extend(tool_calls or [])
    return SimpleNamespace(output=items)
61 changes: 61 additions & 0 deletions tests/test_responses_handling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from __future__ import annotations

from republic import LLM
from republic.clients.chat import ChatClient
from republic.core.execution import LLMCore

from .fakes import make_responses_function_call, make_responses_response


def test_llm_use_responses_calls_responses(fake_anyllm) -> None:
    """LLM(use_responses=True) routes the call through client.responses."""
    provider_client = fake_anyllm.ensure("openai")
    provider_client.queue_responses(make_responses_response(text="hello"))

    answer = LLM(model="openai:gpt-4o-mini", api_key="dummy", use_responses=True).chat("hi")

    assert answer == "hello"
    last_call = provider_client.calls[-1]
    assert last_call.get("responses") is True
    assert last_call["input_data"][0]["role"] == "user"


def test_extract_tool_calls_from_responses() -> None:
    """Responses-style function_call output maps to completion-style dicts."""
    fake = make_responses_response(
        tool_calls=[make_responses_function_call("echo", '{"text":"hi"}')]
    )

    extracted = ChatClient._extract_tool_calls(fake)

    expected = {
        "function": {"name": "echo", "arguments": '{"text":"hi"}'},
        "id": "call_1",
        "type": "function",
    }
    assert extracted == [expected]


def test_split_messages_for_responses() -> None:
    """System text becomes instructions; the rest becomes Responses input items."""
    conversation = [
        {"role": "system", "content": "sys"},
        {"role": "user", "content": "hi"},
        {
            "role": "assistant",
            "content": "",
            "tool_calls": [
                {
                    "id": "call_1",
                    "type": "function",
                    "function": {"name": "echo", "arguments": '{"text":"hi"}'},
                }
            ],
        },
        {"role": "tool", "tool_call_id": "call_1", "content": '{"ok":true}'},
    ]

    instructions, input_items = LLMCore._split_messages_for_responses(conversation)

    assert instructions == "sys"
    expected_items = [
        {"role": "user", "content": "hi", "type": "message"},
        {"type": "function_call", "name": "echo", "arguments": '{"text":"hi"}', "call_id": "call_1"},
        {"type": "function_call_output", "call_id": "call_1", "output": '{"ok":true}'},
    ]
    assert input_items == expected_items