Skip to content
Merged
28 changes: 26 additions & 2 deletions temporalio/contrib/openai_agents/_model_parameters.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,32 @@
"""Parameters for configuring Temporal activity execution for model calls."""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import timedelta
from typing import Optional
from typing import Any, Callable, Optional, Union

from agents import Agent, TResponseInputItem

from temporalio.common import Priority, RetryPolicy
from temporalio.workflow import ActivityCancellationType, VersioningIntent


class ModelSummaryProvider(ABC):
    """Interface for objects that produce a summary for a model invocation.

    Conceptually this is a plain callable; it is expressed as a class so that
    the (fairly involved) arguments carry explicit names.
    """

    @abstractmethod
    def provide(
        self,
        agent: Optional[Agent[Any]],
        instructions: Optional[str],
        input: Union[str, list[TResponseInputItem]],
    ) -> str:
        """Build the summary string for a model invocation activity.

        Args:
            agent: The agent the model call belongs to, or ``None``.
            instructions: The instructions for the call, or ``None``.
            input: The model input, either a string or a list of input items.

        Returns:
            The summary text.
        """


@dataclass
class ModelActivityParameters:
"""Parameters for configuring Temporal activity execution for model calls.
Expand Down Expand Up @@ -41,7 +60,12 @@ class ModelActivityParameters:
versioning_intent: Optional[VersioningIntent] = None
"""Versioning intent for the activity."""

summary_override: Optional[str] = None
summary_override: Optional[
Union[
str,
ModelSummaryProvider,
]
] = None
"""Summary for the activity execution."""

priority: Priority = Priority.default
Expand Down
85 changes: 70 additions & 15 deletions temporalio/contrib/openai_agents/_openai_runner.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import dataclasses
import json
import typing
from dataclasses import replace
from typing import Any, Union
from typing import Any, Optional, Union

from agents import (
Agent,
Handoff,
RunConfig,
RunContextWrapper,
RunResult,
RunResultStreaming,
SQLiteSession,
Expand Down Expand Up @@ -77,26 +79,70 @@ async def run(
if run_config is None:
run_config = RunConfig()

model_name = run_config.model or starting_agent.model
if model_name is not None and not isinstance(model_name, str):
raise ValueError(
"Temporal workflows require a model name to be a string in the run config and/or agent."
if run_config.model:
if not isinstance(run_config.model, str):
raise ValueError(
"Temporal workflows require a model name to be a string in the run config."
)
run_config = dataclasses.replace(
run_config,
model=_TemporalModelStub(
run_config.model, model_params=self.model_params, agent=None
),
)

# Recursively replace models in all agents
def convert_agent(
    agent: Agent[Any], seen: Optional[dict[int, Agent[Any]]]
) -> Agent[Any]:
    """Return a copy of ``agent`` whose model is a ``_TemporalModelStub``.

    Handoffs are converted recursively: ``Agent`` handoffs are converted
    directly, and ``Handoff`` objects get their ``on_invoke_handoff``
    wrapped so the agent they resolve to is converted when invoked.

    Args:
        agent: The agent to convert.
        seen: Maps ``id()`` of each agent already visited in this traversal
            to its converted copy. Pass ``None`` to start a new traversal.

    Returns:
        The converted agent, or ``agent`` itself if its model is already a
        ``_TemporalModelStub``.

    Raises:
        ValueError: If a handoff is neither an ``Agent`` nor a ``Handoff``,
            or if ``_model_name`` rejects the agent's model.
    """
    if seen is None:
        seen = {}

    # Circular handoffs: hand back the converted copy, not the original, so
    # every path through the cycle ends up on an agent with a stub model.
    already_converted = seen.get(id(agent))
    if already_converted is not None:
        return already_converted

    # This agent has already been processed in some other run
    if isinstance(agent.model, _TemporalModelStub):
        return agent

    # Validate the model name before registering anything for this agent.
    name = _model_name(agent)

    # Register a shallow copy up front so recursive calls that loop back to
    # this agent can reference it; its model and handoffs are filled in below.
    new_agent = dataclasses.replace(agent)
    seen[id(agent)] = new_agent

    def wrap_invoke(original_invoke: Any) -> Any:
        # A factory gives each wrapper its own binding of the callback.
        # Defining the closure directly in the loop would make every wrapper
        # call whichever callback the loop assigned last.
        async def on_invoke(context: RunContextWrapper[Any], args: str) -> Agent:
            handoff_agent = await original_invoke(context, args)
            return convert_agent(handoff_agent, seen)

        return on_invoke

    new_handoffs: list[Union[Agent, Handoff]] = []
    for handoff in agent.handoffs:
        if isinstance(handoff, Agent):
            new_handoffs.append(convert_agent(handoff, seen))
        elif isinstance(handoff, Handoff):
            new_handoffs.append(
                dataclasses.replace(
                    handoff,
                    on_invoke_handoff=wrap_invoke(handoff.on_invoke_handoff),
                )
            )
        else:
            raise ValueError(f"Unknown handoff type: {type(handoff)}")

    new_agent.model = _TemporalModelStub(
        model_name=name,
        model_params=self.model_params,
        agent=agent,
    )
    new_agent.handoffs = new_handoffs
    return new_agent
updated_run_config = replace(
run_config,
model=_TemporalModelStub(
model_name=model_name,
model_params=self.model_params,
),
)

return await self._runner.run(
starting_agent=starting_agent,
starting_agent=convert_agent(starting_agent, None),
input=input,
context=context,
max_turns=max_turns,
hooks=hooks,
run_config=updated_run_config,
run_config=run_config,
previous_response_id=previous_response_id,
session=session,
)
Expand Down Expand Up @@ -130,3 +176,12 @@ def run_streamed(
**kwargs,
)
raise RuntimeError("Temporal workflows do not support streaming.")


def _model_name(agent: Agent[Any]) -> Optional[str]:
    """Return ``agent.model`` when it is a string or ``None``.

    Raises:
        ValueError: If the agent's model is set to anything other than a string.
    """
    model = agent.model
    if model is None or isinstance(model, str):
        return model
    raise ValueError(
        "Temporal workflows require a model name to be a string in the agent."
    )
22 changes: 20 additions & 2 deletions temporalio/contrib/openai_agents/_temporal_model_stub.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing import Any, AsyncIterator, Union, cast

from agents import (
Agent,
AgentOutputSchema,
AgentOutputSchemaBase,
CodeInterpreterTool,
Expand Down Expand Up @@ -50,9 +51,11 @@ def __init__(
model_name: Optional[str],
*,
model_params: ModelActivityParameters,
agent: Optional[Agent[Any]],
) -> None:
self.model_name = model_name
self.model_params = model_params
self.agent = agent

async def get_response(
self,
Expand Down Expand Up @@ -124,7 +127,7 @@ def make_tool_info(tool: Tool) -> ToolInput:
activity_input = ActivityModelInput(
model_name=self.model_name,
system_instructions=system_instructions,
input=cast(Union[str, list[TResponseInputItem]], input),
input=input,
model_settings=model_settings,
tools=tool_infos,
output_schema=output_schema_input,
Expand All @@ -134,10 +137,25 @@ def make_tool_info(tool: Tool) -> ToolInput:
prompt=prompt,
)

if self.model_params.summary_override:
summary = (
self.model_params.summary_override
if isinstance(self.model_params.summary_override, str)
else (
self.model_params.summary_override.provide(
self.agent, system_instructions, input
)
)
)
elif self.agent:
summary = self.agent.name
else:
summary = None

return await workflow.execute_activity_method(
ModelActivity.invoke_model_activity,
activity_input,
summary=self.model_params.summary_override or _extract_summary(input),
summary=summary,
task_queue=self.model_params.task_queue,
schedule_to_close_timeout=self.model_params.schedule_to_close_timeout,
schedule_to_start_timeout=self.model_params.schedule_to_start_timeout,
Expand Down
143 changes: 143 additions & 0 deletions tests/contrib/openai_agents/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
OpenAIChatCompletionsModel,
OpenAIResponsesModel,
OutputGuardrailTripwireTriggered,
RunConfig,
RunContextWrapper,
Runner,
SQLiteSession,
Expand Down Expand Up @@ -85,6 +86,7 @@
TestModel,
TestModelProvider,
)
from temporalio.contrib.openai_agents._model_parameters import ModelSummaryProvider
from temporalio.contrib.openai_agents._temporal_model_stub import _extract_summary
from temporalio.contrib.pydantic import pydantic_data_converter
from temporalio.exceptions import ApplicationError, CancelledError
Expand Down Expand Up @@ -2041,3 +2043,144 @@ async def test_hosted_mcp_tool(client: Client, use_local_model):
result = await workflow_handle.result()
if use_local_model:
assert result == "Some language"


class AssertDifferentModelProvider(ModelProvider):
    """Model provider that records every model name it is asked to resolve.

    Every request is answered with the single model supplied at construction;
    the requested names accumulate in ``model_names`` for later inspection.
    """

    model_names: set[Optional[str]]

    def __init__(self, model: Model):
        self.model_names = set()
        self._model = model

    def get_model(self, model_name: Union[str, None]) -> Model:
        """Record ``model_name`` and return the wrapped model."""
        self.model_names.add(model_name)
        return self._model


class MultipleModelsModel(StaticTestModel):
    """Static test model with two canned responses: a handoff, then a message."""

    responses = [
        # A tool call named "transfer_to_underling" with empty arguments.
        ResponseBuilders.tool_call("{}", "transfer_to_underling"),
        # The final output message.
        ResponseBuilders.output_message(
            "I'm here to help! Was there a specific task you needed assistance with regarding the storeroom?"
        ),
    ]


@workflow.defn
class MultipleModelWorkflow:
    """Workflow running two agents: one with an explicit model, one without.

    The starting agent names the model "gpt-4o-mini" and hands off to an
    agent that specifies no model. When requested, a run config naming
    "gpt-4o" is passed to the runner.
    """

    @workflow.run
    async def run(self, use_run_config: bool):
        """Run the agents and return the final output.

        Args:
            use_run_config: Whether to pass ``RunConfig(model="gpt-4o")``
                instead of ``None`` as the run config.
        """
        delegate = Agent[None](
            name="Underling",
            instructions="You do all the work you are told.",
        )
        delegator = Agent[None](
            name="Lazy Assistant",
            model="gpt-4o-mini",
            instructions="You delegate all your work to another agent.",
            handoffs=[delegate],
        )

        if use_run_config:
            config = RunConfig(model="gpt-4o")
        else:
            config = None

        outcome = await Runner.run(
            starting_agent=delegator,
            input="Have you cleaned the store room yet?",
            run_config=config,
        )
        return outcome.final_output


async def test_multiple_models(client: Client):
    """Without a run config, each agent's own model name reaches the provider."""
    provider = AssertDifferentModelProvider(MultipleModelsModel())

    config = client.config()
    plugin = openai_agents.OpenAIAgentsPlugin(
        model_params=ModelActivityParameters(
            start_to_close_timeout=timedelta(seconds=120)
        ),
        model_provider=provider,
    )
    config["plugins"] = [plugin]
    client = Client(**config)

    async with new_worker(client, MultipleModelWorkflow) as worker:
        handle = await client.start_workflow(
            MultipleModelWorkflow.run,
            False,
            id=f"multiple-model-{uuid.uuid4()}",
            task_queue=worker.task_queue,
            execution_timeout=timedelta(seconds=10),
        )
        # Wait for completion so every model lookup has been recorded.
        await handle.result()

        # One agent names "gpt-4o-mini"; the other specifies no model (None).
        expected_names = {None, "gpt-4o-mini"}
        assert provider.model_names == expected_names


async def test_run_config_models(client: Client):
    """With a run config model set, only that model name reaches the provider."""
    provider = AssertDifferentModelProvider(MultipleModelsModel())

    config = client.config()
    plugin = openai_agents.OpenAIAgentsPlugin(
        model_params=ModelActivityParameters(
            start_to_close_timeout=timedelta(seconds=120)
        ),
        model_provider=provider,
    )
    config["plugins"] = [plugin]
    client = Client(**config)

    async with new_worker(client, MultipleModelWorkflow) as worker:
        handle = await client.start_workflow(
            MultipleModelWorkflow.run,
            True,
            id=f"run-config-model-{uuid.uuid4()}",
            task_queue=worker.task_queue,
            execution_timeout=timedelta(seconds=10),
        )
        # Wait for completion so every model lookup has been recorded.
        await handle.result()

        # Only the model from the runconfig override is used
        expected_names = {"gpt-4o"}
        assert provider.model_names == expected_names


async def test_summary_provider(client: Client):
    """A ``ModelSummaryProvider`` override sets the scheduled activity's summary."""

    class SummaryProvider(ModelSummaryProvider):
        def provide(
            self,
            agent: Optional[Agent[Any]],
            instructions: Optional[str],
            input: Union[str, list[TResponseInputItem]],
        ) -> str:
            return "My summary"

    new_config = client.config()
    new_config["plugins"] = [
        openai_agents.OpenAIAgentsPlugin(
            model_params=ModelActivityParameters(
                start_to_close_timeout=timedelta(seconds=120),
                summary_override=SummaryProvider(),
            ),
            model_provider=TestModelProvider(TestHelloModel()),
        )
    ]
    client = Client(**new_config)

    async with new_worker(
        client,
        HelloWorldAgent,
    ) as worker:
        workflow_handle = await client.start_workflow(
            HelloWorldAgent.run,
            "Prompt",
            id=f"summary-provider-model-{uuid.uuid4()}",
            task_queue=worker.task_queue,
            execution_timeout=timedelta(seconds=10),
        )
        # Wait for completion so the history contains the scheduled activity.
        await workflow_handle.result()

        scheduled_activities = 0
        async for e in workflow_handle.fetch_history_events():
            if e.HasField("activity_task_scheduled_event_attributes"):
                scheduled_activities += 1
                assert e.user_metadata.summary.data == b'"My summary"'

        # Without this check the loop above passes vacuously when no
        # activity-scheduled event is present in the history.
        assert (
            scheduled_activities > 0
        ), "expected at least one scheduled activity in the workflow history"
Loading