Skip to content

OpenAI agents prototype #896

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions temporalio/contrib/openai_agents/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Support for running OpenAI agents as part of Temporal workflows."""
37 changes: 37 additions & 0 deletions temporalio/contrib/openai_agents/_heartbeat_decorator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import asyncio
from functools import wraps
from typing import Any, Awaitable, Callable, TypeVar, cast

from temporalio import activity

F = TypeVar("F", bound=Callable[..., Awaitable[Any]])


def _auto_heartbeater(fn: F) -> F:
    """Wrap an async activity so it heartbeats automatically while running."""

    # functools.wraps keeps the original callable's metadata (including type
    # hints) visible through the wrapper.
    @wraps(fn)
    async def wrapper(*args, **kwargs):
        timeout = activity.info().heartbeat_timeout
        beat_task = None
        if timeout:
            # Beat at half the timeout interval so a heartbeat always lands
            # before the server-side deadline.
            interval = timeout.total_seconds() / 2
            beat_task = asyncio.create_task(heartbeat_every(interval))
        try:
            return await fn(*args, **kwargs)
        finally:
            if beat_task is not None:
                beat_task.cancel()
                # Let the cancellation fully settle before returning.
                await asyncio.wait([beat_task])

    return cast(F, wrapper)


async def heartbeat_every(delay: float, *details: Any) -> None:
    """Record an activity heartbeat every *delay* seconds until cancelled.

    Cancellation propagates out of ``asyncio.sleep``, ending the loop.
    """
    while True:
        await asyncio.sleep(delay)
        activity.heartbeat(*details)
126 changes: 126 additions & 0 deletions temporalio/contrib/openai_agents/_openai_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
from dataclasses import replace

from agents import (
Agent,
RunConfig,
RunHooks,
Runner,
RunResult,
RunResultStreaming,
TContext,
TResponseInputItem,
)
from agents.run import DEFAULT_MAX_TURNS, DEFAULT_RUNNER, DefaultRunner

from temporalio import workflow
from temporalio.contrib.openai_agents._temporal_model_stub import _TemporalModelStub

# TODO: Uncomment when Agent.tools type accepts Callable
# def _activities_as_tools(tools: list[Tool]) -> list[Tool]:
# """Convert activities to tools."""
# return [activity_as_tool(tool) if isinstance(tool, Callable) else tool for tool in tools]


class TemporalOpenAIRunner(Runner):
    """Temporal Runner for OpenAI agents.

    Forwards model calls to a Temporal activity.

    TODO: Implement original runner forwarding
    """

    def __init__(self):
        """Initialize the Temporal OpenAI Runner."""
        self._runner = DEFAULT_RUNNER or DefaultRunner()

    async def _run_impl(
        self,
        starting_agent: Agent[TContext],
        input: str | list[TResponseInputItem],
        *,
        context: TContext | None = None,
        max_turns: int = DEFAULT_MAX_TURNS,
        hooks: RunHooks[TContext] | None = None,
        run_config: RunConfig | None = None,
        previous_response_id: str | None = None,
    ) -> RunResult:
        """Run the agent in a Temporal workflow."""
        # Outside a workflow there is nothing to intercept: defer to the
        # wrapped default runner untouched.
        if not workflow.in_workflow():
            return await self._runner._run_impl(
                starting_agent,
                input,
                context=context,
                max_turns=max_turns,
                hooks=hooks,
                run_config=run_config,
                previous_response_id=previous_response_id,
            )

        effective_config = run_config if run_config is not None else RunConfig()

        model = effective_config.model
        if model is not None and not isinstance(model, str):
            raise ValueError(
                "Temporal workflows require a model name to be a string in the run config."
            )
        # Replace the model with a stub that proxies every call through a
        # Temporal activity.
        stubbed_config = replace(effective_config, model=_TemporalModelStub(model))

        # TODO: Uncomment when Agent.tools type accepts Callable
        # tools = _activities_as_tools(starting_agent.tools) if starting_agent.tools else None
        # updated_starting_agent = replace(starting_agent, tools=tools)

        return await self._runner._run_impl(
            starting_agent=starting_agent,
            input=input,
            context=context,
            max_turns=max_turns,
            hooks=hooks,
            run_config=stubbed_config,
            previous_response_id=previous_response_id,
        )

    def _run_sync_impl(
        self,
        starting_agent: Agent[TContext],
        input: str | list[TResponseInputItem],
        *,
        context: TContext | None = None,
        max_turns: int = DEFAULT_MAX_TURNS,
        hooks: RunHooks[TContext] | None = None,
        run_config: RunConfig | None = None,
        previous_response_id: str | None = None,
    ) -> RunResult:
        """Synchronous run: only valid outside a workflow."""
        if workflow.in_workflow():
            raise RuntimeError(
                "Temporal workflows do not support synchronous model calls."
            )
        return self._runner._run_sync_impl(
            starting_agent,
            input,
            context=context,
            max_turns=max_turns,
            hooks=hooks,
            run_config=run_config,
            previous_response_id=previous_response_id,
        )

    def _run_streamed_impl(
        self,
        starting_agent: Agent[TContext],
        input: str | list[TResponseInputItem],
        context: TContext | None = None,
        max_turns: int = DEFAULT_MAX_TURNS,
        hooks: RunHooks[TContext] | None = None,
        run_config: RunConfig | None = None,
        previous_response_id: str | None = None,
    ) -> RunResultStreaming:
        """Streamed run: only valid outside a workflow."""
        if workflow.in_workflow():
            raise RuntimeError("Temporal workflows do not support streaming.")
        return self._runner._run_streamed_impl(
            starting_agent,
            input,
            context=context,
            max_turns=max_turns,
            hooks=hooks,
            run_config=run_config,
            previous_response_id=previous_response_id,
        )
157 changes: 157 additions & 0 deletions temporalio/contrib/openai_agents/_temporal_model_stub.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
from __future__ import annotations

from temporalio import workflow

with workflow.unsafe.imports_passed_through():
from datetime import timedelta
from typing import Any, AsyncIterator, Sequence, cast

from agents import (
AgentOutputSchema,
AgentOutputSchemaBase,
ComputerTool,
FileSearchTool,
FunctionTool,
Handoff,
Model,
ModelResponse,
ModelSettings,
ModelTracing,
Tool,
TResponseInputItem,
WebSearchTool,
)
from agents.items import TResponseStreamEvent

from temporalio.contrib.openai_agents.invoke_model_activity import (
ActivityModelInput,
AgentOutputSchemaInput,
FunctionToolInput,
HandoffInput,
ModelTracingInput,
ToolInput,
invoke_model_activity,
)


class _TemporalModelStub(Model):
    """A stub that allows invoking models as Temporal activities.

    Installed in place of the real model by the Temporal runner; every
    ``get_response`` call is serialized into an ``ActivityModelInput`` and
    executed via ``invoke_model_activity``. Streaming is not supported.
    """

    def __init__(self, model_name: str | None) -> None:
        # None means "use the activity-side default model".
        self.model_name = model_name

    async def get_response(
        self,
        system_instructions: str | None,
        input: str | list[TResponseInputItem],
        model_settings: ModelSettings,
        tools: list[Tool],
        output_schema: AgentOutputSchemaBase | None,
        handoffs: list[Handoff],
        tracing: ModelTracing,
        *,
        previous_response_id: str | None,
    ) -> ModelResponse:
        """Execute the model call as a Temporal activity and return its response.

        Raises:
            NotImplementedError: If a ``ComputerTool`` is present.
            ValueError: If a tool of unknown type is present.
            TypeError: If ``output_schema`` is not an ``AgentOutputSchema``.
        """

        def get_summary(input: str | list[TResponseInputItem]) -> str:
            """Best-effort, truncated summary of the input (activity summary
            shown in the UI). Never raises; returns "" when no summary can
            be derived."""
            max_size = 100
            try:
                if isinstance(input, str):
                    return input[:max_size]
                elif isinstance(input, list):
                    seq_input = cast(Sequence[Any], input)
                    last_item = seq_input[-1]
                    if isinstance(last_item, dict):
                        return last_item.get("content", "")[:max_size]
                    elif hasattr(last_item, "content"):
                        return str(getattr(last_item, "content"))[:max_size]
                    return str(last_item)[:max_size]
                elif isinstance(input, dict):
                    return input.get("content", "")[:max_size]
            except Exception as e:
                # The summary is cosmetic only — never fail the model call
                # over it. Log instead of print so output goes through the
                # workflow-safe logger.
                workflow.logger.warning("Error getting summary: %s", e)
            # Unknown input type: fall back to an empty summary rather than
            # implicitly returning None (the declared return type is str).
            return ""

        def make_tool_info(tool: Tool) -> ToolInput:
            """Convert a Tool into its serializable activity-input form."""
            if isinstance(tool, (FileSearchTool, WebSearchTool)):
                # These are plain data and cross the serialization boundary
                # as-is; the previous cast() calls were runtime no-ops.
                return tool
            elif isinstance(tool, ComputerTool):
                raise NotImplementedError(
                    "Computer search preview is not supported in Temporal model"
                )
            elif isinstance(tool, FunctionTool):
                # Only the declarative parts are shipped to the activity; the
                # on_invoke callable stays on the workflow side.
                return FunctionToolInput(
                    name=tool.name,
                    description=tool.description,
                    params_json_schema=tool.params_json_schema,
                    strict_json_schema=tool.strict_json_schema,
                )
            else:
                raise ValueError(f"Unknown tool type: {tool.name}")

        tool_infos = [make_tool_info(x) for x in tools]
        handoff_infos = [
            HandoffInput(
                tool_name=x.tool_name,
                tool_description=x.tool_description,
                input_json_schema=x.input_json_schema,
                agent_name=x.agent_name,
                strict_json_schema=x.strict_json_schema,
            )
            for x in handoffs
        ]
        if output_schema is not None and not isinstance(
            output_schema, AgentOutputSchema
        ):
            raise TypeError(
                f"Only AgentOutputSchema is supported by Temporal Model, got {type(output_schema).__name__}"
            )
        # After the guard above, output_schema is AgentOutputSchema | None.
        output_schema_input = (
            None
            if output_schema is None
            else AgentOutputSchemaInput(
                output_type_name=output_schema.name(),
                is_wrapped=output_schema._is_wrapped,
                output_schema=output_schema.json_schema()
                if not output_schema.is_plain_text()
                else None,
                strict_json_schema=output_schema.is_strict_json_schema(),
            )
        )

        activity_input = ActivityModelInput(
            model_name=self.model_name,
            system_instructions=system_instructions,
            input=input,
            model_settings=model_settings,
            tools=tool_infos,
            output_schema=output_schema_input,
            handoffs=handoff_infos,
            tracing=ModelTracingInput(tracing.value),
            previous_response_id=previous_response_id,
        )
        return await workflow.execute_activity(
            invoke_model_activity,
            activity_input,
            start_to_close_timeout=timedelta(seconds=60),
            heartbeat_timeout=timedelta(seconds=10),
            summary=get_summary(input),
        )

    def stream_response(
        self,
        system_instructions: str | None,
        input: str | list[TResponseInputItem],
        model_settings: ModelSettings,
        tools: list[Tool],
        output_schema: AgentOutputSchemaBase | None,
        handoffs: list[Handoff],
        tracing: ModelTracing,
        *,
        previous_response_id: str | None,
    ) -> AsyncIterator[TResponseStreamEvent]:
        """Streaming is not supported through the activity boundary."""
        raise NotImplementedError("Temporal model doesn't support streams yet")
Loading