58 changes: 45 additions & 13 deletions python/restate/ext/adk/plugin.py
@@ -35,17 +35,25 @@

from restate.extensions import current_context

from restate.ext.turnstile import Turnstile


def _create_turnstile(s: LlmResponse) -> Turnstile:
ids = _get_function_call_ids(s)
turnstile = Turnstile(ids)
return turnstile


class RestatePlugin(BasePlugin):
"""A plugin to integrate Restate with the ADK framework."""

_models: dict[str, BaseLlm]
_locks: dict[str, asyncio.Lock]
_turnstiles: dict[str, Turnstile | None]

def __init__(self, *, max_model_call_retries: int = 10):
super().__init__(name="restate_plugin")
self._models = {}
self._locks = {}
self._turnstiles = {}
self._max_model_call_retries = max_model_call_retries

async def before_agent_callback(
@@ -62,7 +70,7 @@ async def before_agent_callback(
)
model = agent.model if isinstance(agent.model, BaseLlm) else LLMRegistry.new_llm(agent.model)
self._models[callback_context.invocation_id] = model
self._locks[callback_context.invocation_id] = asyncio.Lock()
self._turnstiles[callback_context.invocation_id] = None

id = callback_context.invocation_id
event = ctx.request().attempt_finished_event
@@ -73,7 +81,7 @@ async def release_task():
await event.wait()
finally:
self._models.pop(id, None)
self._locks.pop(id, None)
self._turnstiles.pop(id, None)

_ = asyncio.create_task(release_task())
return None
@@ -82,7 +90,7 @@ async def after_agent_callback(
self, *, agent: BaseAgent, callback_context: CallbackContext
) -> Optional[types.Content]:
self._models.pop(callback_context.invocation_id, None)
self._locks.pop(callback_context.invocation_id, None)
self._turnstiles.pop(callback_context.invocation_id, None)
return None

async def after_run_callback(self, *, invocation_context: InvocationContext) -> None:
@@ -100,6 +108,8 @@ async def before_model_callback(
"No Restate context found, the restate plugin must be used from within a restate handler."
)
response = await _generate_content_async(ctx, self._max_model_call_retries, model, llm_request)
turnstile = _create_turnstile(response)
self._turnstiles[callback_context.invocation_id] = turnstile
return response

async def before_tool_callback(
@@ -109,11 +119,17 @@ async def before_tool_callback(
tool_args: dict[str, Any],
tool_context: ToolContext,
) -> Optional[dict]:
lock = self._locks[tool_context.invocation_id]
turnstile = self._turnstiles[tool_context.invocation_id]
assert turnstile is not None, "Turnstile not found for tool invocation."

id = tool_context.function_call_id
assert id is not None, "Function call ID is required for tool invocation."

await turnstile.wait_for(id)

ctx = current_context()
await lock.acquire()
tool_context.session.state["restate_context"] = ctx
# TODO: if we want we can also automatically wrap tools with ctx.run_typed here

return None

async def after_tool_callback(
@@ -125,8 +141,11 @@ async def after_tool_callback(
result: dict,
) -> Optional[dict]:
tool_context.session.state.pop("restate_context", None)
lock = self._locks[tool_context.invocation_id]
lock.release()
turnstile = self._turnstiles[tool_context.invocation_id]
assert turnstile is not None, "Turnstile not found for tool invocation."
id = tool_context.function_call_id
assert id is not None, "Function call ID is required for tool invocation."
turnstile.allow_next_after(id)
return None

async def on_tool_error_callback(
@@ -138,13 +157,26 @@ async def on_tool_error_callback(
error: Exception,
) -> Optional[dict]:
tool_context.session.state.pop("restate_context", None)
lock = self._locks[tool_context.invocation_id]
lock.release()
turnstile = self._turnstiles[tool_context.invocation_id]
assert turnstile is not None, "Turnstile not found for tool invocation."
id = tool_context.function_call_id
assert id is not None, "Function call ID is required for tool invocation."
turnstile.allow_next_after(id)
return None

async def close(self):
self._models.clear()
self._locks.clear()
self._turnstiles.clear()


def _get_function_call_ids(s: LlmResponse) -> list[str]:
ids = []
if s.content and s.content.parts:
for part in s.content.parts:
if part.function_call:
if part.function_call.id:
ids.append(part.function_call.id)
return ids


def _generate_client_function_call_id(s: LlmResponse) -> None:
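The core change in plugin.py is replacing the per-invocation asyncio.Lock with a Turnstile built from the function-call IDs of the model response, so tool calls run one at a time in the order the model emitted them. The Turnstile implementation lives in restate.ext.turnstile and is not part of this diff; the sketch below only approximates the interface the plugin relies on (everything beyond wait_for and allow_next_after is an assumption, not the actual implementation):

import asyncio


class TurnstileSketch:
    """Admits waiters strictly in the order of the given call IDs."""

    def __init__(self, ids: list[str]) -> None:
        self._ids = ids
        # One gate per call ID; only the first ID starts open.
        self._gates = {call_id: asyncio.Event() for call_id in ids}
        if ids:
            self._gates[ids[0]].set()

    async def wait_for(self, call_id: str) -> None:
        # Block until it is this call's turn.
        await self._gates[call_id].wait()

    def allow_next_after(self, call_id: str) -> None:
        # Open the gate for the call ID that follows, if any.
        idx = self._ids.index(call_id)
        if idx + 1 < len(self._ids):
            self._gates[self._ids[idx + 1]].set()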
4 changes: 3 additions & 1 deletion python/restate/ext/openai/__init__.py
@@ -14,10 +14,12 @@

import typing

from .runner_wrapper import DurableRunner, continue_on_terminal_errors, raise_terminal_errors, LlmRetryOpts
from restate import ObjectContext, Context
from restate.server_context import current_context

from .runner_wrapper import DurableRunner, continue_on_terminal_errors, raise_terminal_errors
from .models import LlmRetryOpts


def restate_object_context() -> ObjectContext:
"""Get the current Restate ObjectContext."""
95 changes: 95 additions & 0 deletions python/restate/ext/openai/functions.py
@@ -0,0 +1,95 @@
#
# Copyright (c) 2023-2025 - Restate Software, Inc., Restate GmbH
#
# This file is part of the Restate SDK for Python,
# which is released under the MIT license.
#
# You can find a copy of the license in file LICENSE in the root
# directory of this repository or package, or at
# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
#

from typing import List, Any
import dataclasses

from agents import (
Handoff,
TContext,
Agent,
)

from agents.tool import FunctionTool, Tool
from agents.tool_context import ToolContext
from agents.items import TResponseOutputItem

from .models import State


def get_function_call_ids(response: list[TResponseOutputItem]) -> List[str]:
"""Extract function call IDs from the model response."""
# TODO: support function calls in other response types
return [item.call_id for item in response if item.type == "function_call"]


def _create_wrapper(state, captured_tool):
async def on_invoke_tool_wrapper(tool_context: ToolContext[Any], tool_input: Any) -> Any:
turnstile = state.turnstile
call_id = tool_context.tool_call_id
try:
await turnstile.wait_for(call_id)
return await captured_tool.on_invoke_tool(tool_context, tool_input)
finally:
turnstile.allow_next_after(call_id)

return on_invoke_tool_wrapper


def wrap_agent_tools(
agent: Agent[TContext],
state: State,
) -> Agent[TContext]:
"""
Wrap the tools of an agent so that each tool invocation is gated through the shared Restate turnstile.

Returns:
A new agent with wrapped tools.
"""
wrapped_tools: list[Tool] = []
for tool in agent.tools:
if isinstance(tool, FunctionTool):
wrapped = _create_wrapper(state, tool)
wrapped_tools.append(dataclasses.replace(tool, on_invoke_tool=wrapped))
else:
wrapped_tools.append(tool)

wrapped_handoffs: list[Agent[Any] | Handoff[Any]] = []
for handoff in agent.handoffs:
if isinstance(handoff, Agent):
wrapped_handoff = wrap_agent_tools(handoff, state)
wrapped_handoffs.append(wrapped_handoff)
elif isinstance(handoff, Handoff):
wrapped_handoffs.append(wrap_agent_handoff_tools(handoff, state))
else:
raise TypeError(f"Unsupported handoff type: {type(handoff)}")

return agent.clone(tools=wrapped_tools, handoffs=wrapped_handoffs)


def wrap_agent_handoff_tools(
handoff: Handoff[TContext],
state: State,
) -> Handoff[TContext]:
"""
Wrap the tools of the agent returned by a handoff so that they are gated through the shared Restate turnstile.

Returns:
A new handoff with wrapped tools.
"""

original_on_invoke_handoff = handoff.on_invoke_handoff

async def wrapped(*args, **kwargs) -> Any:
agent = await original_on_invoke_handoff(*args, **kwargs)
return wrap_agent_tools(agent, state)

return dataclasses.replace(handoff, on_invoke_handoff=wrapped)
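For illustration only, wrapping an agent might look like the snippet below; the agent, the tool, and the import paths are assumptions inferred from the file locations above, not part of the diff:

from agents import Agent, function_tool

from restate.ext.openai.functions import wrap_agent_tools
from restate.ext.openai.models import State


@function_tool
def add(a: int, b: int) -> int:
    """Add two numbers."""
    return a + b


state = State()
agent = Agent(name="calculator", instructions="Do arithmetic.", tools=[add])

# The returned clone calls the same tools, but every invocation first waits
# on state.turnstile and releases the next call ID when it finishes.
wrapped_agent = wrap_agent_tools(agent, state)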
79 changes: 79 additions & 0 deletions python/restate/ext/openai/models.py
@@ -0,0 +1,79 @@
#
# Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH
#
# This file is part of the Restate SDK for Python,
# which is released under the MIT license.
#
# You can find a copy of the license in file LICENSE in the root
# directory of this repository or package, or at
# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
#
"""
This module contains the optional OpenAI integration for Restate.
"""

import dataclasses

from agents import (
Usage,
)
from agents.items import TResponseOutputItem
from agents.items import TResponseInputItem
from datetime import timedelta
from typing import Optional
from pydantic import BaseModel

from restate.ext.turnstile import Turnstile


class State:
__slots__ = ("turnstile",)

def __init__(self) -> None:
self.turnstile = Turnstile([])


@dataclasses.dataclass
class LlmRetryOpts:
max_attempts: Optional[int] = 10
"""Max number of attempts (including the initial), before giving up.

When giving up, the LLM call will throw a `TerminalError` wrapping the original error message."""
max_duration: Optional[timedelta] = None
"""Max duration of retries, before giving up.

When giving up, the LLM call will throw a `TerminalError` wrapping the original error message."""
initial_retry_interval: Optional[timedelta] = timedelta(seconds=1)
"""Initial interval for the first retry attempt.
Retry interval will grow by a factor specified in `retry_interval_factor`.

If any of the other retry related fields is specified, the default for this field is 50 milliseconds, otherwise Restate will fall back to the overall invocation retry policy."""
max_retry_interval: Optional[timedelta] = None
"""Max interval between retries.
Retry interval will grow by a factor specified in `retry_interval_factor`.

The default is 10 seconds."""
retry_interval_factor: Optional[float] = None
"""Exponentiation factor to use when computing the next retry delay.

If any of the other retry related fields is specified, the default for this field is `2`, meaning the retry interval will double at each attempt, otherwise Restate will fall back to the overall invocation retry policy."""
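
For illustration only (not part of models.py), these options might be constructed as follows; how a DurableRunner consumes them is not shown in this diff:

from datetime import timedelta

# Retry LLM calls up to 5 times within 2 minutes, starting at 50 ms and
# doubling the interval on each attempt, capped at 10 seconds.
retry_opts = LlmRetryOpts(
    max_attempts=5,
    max_duration=timedelta(minutes=2),
    initial_retry_interval=timedelta(milliseconds=50),
    max_retry_interval=timedelta(seconds=10),
    retry_interval_factor=2.0,
)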


# The OpenAI ModelResponse class is a dataclass with Pydantic fields.
# The Restate SDK cannot serialize this, so we turn the ModelResponse into a Pydantic model.
class RestateModelResponse(BaseModel):
output: list[TResponseOutputItem]
"""A list of outputs (messages, tool calls, etc) generated by the model"""

usage: Usage
"""The usage information for the response."""

response_id: str | None
"""An ID for the response which can be used to refer to the response in subsequent calls to the
model. Not supported by all model providers.
If using OpenAI models via the Responses API, this is the `response_id` parameter, and it can
be passed to `Runner.run`.
"""

def to_input_items(self) -> list[TResponseInputItem]:
return [it.model_dump(exclude_unset=True) for it in self.output] # type: ignore
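
A minimal sketch of the conversion the comment above describes, assuming the OpenAI ModelResponse dataclass exposes the same three fields (output, usage, response_id); the helper name is hypothetical:

from agents.items import ModelResponse


def to_restate_model_response(resp: ModelResponse) -> RestateModelResponse:
    # Copy the dataclass fields into the Pydantic model so the Restate SDK
    # can serialize the result of the LLM call.
    return RestateModelResponse(
        output=resp.output,
        usage=resp.usage,
        response_id=resp.response_id,
    )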