OpenAI agents prototype #896
Draft
dandavison wants to merge 6 commits into temporalio:main from mfateev:openai-agents
Commits
15ebc0d Added openai_agents. (mfateev)
1d1faf6 added tools and trace interceptor. (mfateev)
a2b517a Lint error fixes (mfateev)
435eccc Lint error fixes (mfateev)
c1a4971 Missing docstrings added. (mfateev)
d9669f7 Fixed tool serialization. (mfateev)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""Support for running OpenAI agents as part of Temporal workflows.""" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import asyncio | ||
from functools import wraps | ||
from typing import Any, Awaitable, Callable, TypeVar, cast | ||
|
||
from temporalio import activity | ||
|
||
F = TypeVar("F", bound=Callable[..., Awaitable[Any]]) | ||
|
||
|
||
def _auto_heartbeater(fn: F) -> F: | ||
# We want to ensure that the type hints from the original callable are | ||
# available via our wrapper, so we use the functools wraps decorator | ||
@wraps(fn) | ||
async def wrapper(*args, **kwargs): | ||
heartbeat_timeout = activity.info().heartbeat_timeout | ||
heartbeat_task = None | ||
if heartbeat_timeout: | ||
# Heartbeat twice as often as the timeout | ||
heartbeat_task = asyncio.create_task( | ||
heartbeat_every(heartbeat_timeout.total_seconds() / 2) | ||
) | ||
try: | ||
return await fn(*args, **kwargs) | ||
finally: | ||
if heartbeat_task: | ||
heartbeat_task.cancel() | ||
# Wait for heartbeat cancellation to complete | ||
await asyncio.wait([heartbeat_task]) | ||
|
||
return cast(F, wrapper) | ||
|
||
|
||
async def heartbeat_every(delay: float, *details: Any) -> None: | ||
"""Heartbeat every so often while not cancelled""" | ||
while True: | ||
await asyncio.sleep(delay) | ||
activity.heartbeat(*details) |
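
The auto-heartbeat pattern in the diff above can be tried outside Temporal. The following is only a sketch: `with_periodic_callback` and the `beats` list are hypothetical names, and a plain callback stands in for `activity.heartbeat`, so nothing here is the SDK's API. It starts a background task before the wrapped coroutine runs, then cancels that task and waits for the cancellation to finish in `finally`.

```python
import asyncio
from functools import wraps
from typing import Any, Awaitable, Callable, List


def with_periodic_callback(interval: float, callback: Callable[[], None]):
    """Hypothetical helper: call `callback` every `interval` seconds while fn runs."""

    def decorator(fn: Callable[..., Awaitable[Any]]) -> Callable[..., Awaitable[Any]]:
        # wraps() keeps the name and type hints of the original callable.
        @wraps(fn)
        async def wrapper(*args: Any, **kwargs: Any) -> Any:
            async def ticker() -> None:
                while True:
                    await asyncio.sleep(interval)
                    callback()

            task = asyncio.create_task(ticker())
            try:
                return await fn(*args, **kwargs)
            finally:
                task.cancel()
                # Wait until the background task has actually finished cancelling.
                await asyncio.wait([task])

        return wrapper

    return decorator


beats: List[str] = []


@with_periodic_callback(0.01, lambda: beats.append("beat"))
async def slow_operation() -> str:
    # Stand-in for a long-running model call.
    await asyncio.sleep(0.1)
    return "done"


result = asyncio.run(slow_operation())
```

The interval is passed explicitly here; the diff derives it as half of the activity's `heartbeat_timeout`.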
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
from dataclasses import replace | ||
|
||
from agents import ( | ||
Agent, | ||
RunConfig, | ||
RunHooks, | ||
Runner, | ||
RunResult, | ||
RunResultStreaming, | ||
TContext, | ||
TResponseInputItem, | ||
) | ||
from agents.run import DEFAULT_MAX_TURNS, DEFAULT_RUNNER, DefaultRunner | ||
|
||
from temporalio import workflow | ||
from temporalio.contrib.openai_agents._temporal_model_stub import _TemporalModelStub | ||
|
||
# TODO: Uncomment when Agent.tools type accepts Callable | ||
# def _activities_as_tools(tools: list[Tool]) -> list[Tool]: | ||
# """Convert activities to tools.""" | ||
# return [activity_as_tool(tool) if isinstance(tool, Callable) else tool for tool in tools] | ||
|
||
|
||
class TemporalOpenAIRunner(Runner): | ||
"""Temporal Runner for OpenAI agents. | ||
|
||
Forwards model calls to a Temporal activity. | ||
|
||
TODO: Implement original runner forwarding | ||
""" | ||
|
||
def __init__(self): | ||
"""Initialize the Temporal OpenAI Runner.""" | ||
self._runner = DEFAULT_RUNNER or DefaultRunner() | ||
|
||
async def _run_impl( | ||
self, | ||
starting_agent: Agent[TContext], | ||
input: str | list[TResponseInputItem], | ||
*, | ||
context: TContext | None = None, | ||
max_turns: int = DEFAULT_MAX_TURNS, | ||
hooks: RunHooks[TContext] | None = None, | ||
run_config: RunConfig | None = None, | ||
previous_response_id: str | None = None, | ||
) -> RunResult: | ||
"""Run the agent in a Temporal workflow.""" | ||
if not workflow.in_workflow(): | ||
return await self._runner._run_impl( | ||
starting_agent, | ||
input, | ||
context=context, | ||
max_turns=max_turns, | ||
hooks=hooks, | ||
run_config=run_config, | ||
previous_response_id=previous_response_id, | ||
) | ||
if run_config is None: | ||
run_config = RunConfig() | ||
|
||
if run_config.model is not None and not isinstance(run_config.model, str): | ||
raise ValueError( | ||
"Temporal workflows require a model name to be a string in the run config." | ||
) | ||
updated_run_config = replace( | ||
run_config, model=_TemporalModelStub(run_config.model) | ||
) | ||
|
||
# TODO: Uncomment when Agent.tools type accepts Callable | ||
# tools = _activities_as_tools(starting_agent.tools) if starting_agent.tools else None | ||
# updated_starting_agent = replace(starting_agent, tools=tools) | ||
|
||
return await self._runner._run_impl( | ||
starting_agent=starting_agent, | ||
input=input, | ||
context=context, | ||
max_turns=max_turns, | ||
hooks=hooks, | ||
run_config=updated_run_config, | ||
previous_response_id=previous_response_id, | ||
) | ||
|
||
def _run_sync_impl( | ||
self, | ||
starting_agent: Agent[TContext], | ||
input: str | list[TResponseInputItem], | ||
*, | ||
context: TContext | None = None, | ||
max_turns: int = DEFAULT_MAX_TURNS, | ||
hooks: RunHooks[TContext] | None = None, | ||
run_config: RunConfig | None = None, | ||
previous_response_id: str | None = None, | ||
) -> RunResult: | ||
if not workflow.in_workflow(): | ||
return self._runner._run_sync_impl( | ||
starting_agent, | ||
input, | ||
context=context, | ||
max_turns=max_turns, | ||
hooks=hooks, | ||
run_config=run_config, | ||
previous_response_id=previous_response_id, | ||
) | ||
raise RuntimeError("Temporal workflows do not support synchronous model calls.") | ||
|
||
def _run_streamed_impl( | ||
self, | ||
starting_agent: Agent[TContext], | ||
input: str | list[TResponseInputItem], | ||
context: TContext | None = None, | ||
max_turns: int = DEFAULT_MAX_TURNS, | ||
hooks: RunHooks[TContext] | None = None, | ||
run_config: RunConfig | None = None, | ||
previous_response_id: str | None = None, | ||
) -> RunResultStreaming: | ||
if not workflow.in_workflow(): | ||
return self._runner._run_streamed_impl( | ||
starting_agent, | ||
input, | ||
context=context, | ||
max_turns=max_turns, | ||
hooks=hooks, | ||
run_config=run_config, | ||
previous_response_id=previous_response_id, | ||
) | ||
raise RuntimeError("Temporal workflows do not support streaming.") |
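
The model substitution in `_run_impl` relies on `dataclasses.replace`, which returns a modified copy and leaves the caller's config untouched. Below is a minimal sketch of that step in isolation. `FakeRunConfig`, `FakeModelStub`, `with_stub_model`, and the model name are hypothetical stand-ins chosen for illustration, not the `agents` or `temporalio` types.

```python
from dataclasses import dataclass, replace
from typing import Any, Optional


@dataclass
class FakeRunConfig:
    """Hypothetical stand-in for a run config; only `model` matters here."""

    model: Any = None
    workflow_name: str = "example"


class FakeModelStub:
    """Hypothetical stand-in for the stub that forwards calls to an activity."""

    def __init__(self, model_name: Optional[str]) -> None:
        self.model_name = model_name


def with_stub_model(config: FakeRunConfig) -> FakeRunConfig:
    # Mirror the diff's check: the model must be None or a name, not an object.
    if config.model is not None and not isinstance(config.model, str):
        raise ValueError("model must be given by name")
    # replace() builds a new instance; `config` itself is not mutated.
    return replace(config, model=FakeModelStub(config.model))


original = FakeRunConfig(model="some-model")
updated = with_stub_model(original)
```

After this runs, `original.model` is still the string, while `updated.model` is a stub carrying the same name.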
temporalio/contrib/openai_agents/_temporal_model_stub.py (157 additions, 0 deletions)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,157 @@ | ||
from __future__ import annotations | ||
|
||
from temporalio import workflow | ||
|
||
with workflow.unsafe.imports_passed_through(): | ||
from datetime import timedelta | ||
from typing import Any, AsyncIterator, Sequence, cast | ||
|
||
from agents import ( | ||
AgentOutputSchema, | ||
AgentOutputSchemaBase, | ||
ComputerTool, | ||
FileSearchTool, | ||
FunctionTool, | ||
Handoff, | ||
Model, | ||
ModelResponse, | ||
ModelSettings, | ||
ModelTracing, | ||
Tool, | ||
TResponseInputItem, | ||
WebSearchTool, | ||
) | ||
from agents.items import TResponseStreamEvent | ||
|
||
from temporalio.contrib.openai_agents.invoke_model_activity import ( | ||
ActivityModelInput, | ||
AgentOutputSchemaInput, | ||
FunctionToolInput, | ||
HandoffInput, | ||
ModelTracingInput, | ||
ToolInput, | ||
invoke_model_activity, | ||
) | ||
|
||
|
||
class _TemporalModelStub(Model): | ||
"""A stub that allows invoking models as Temporal activities.""" | ||
|
||
def __init__(self, model_name: str | None) -> None: | ||
self.model_name = model_name | ||
|
||
async def get_response( | ||
self, | ||
system_instructions: str | None, | ||
input: str | list[TResponseInputItem], | ||
model_settings: ModelSettings, | ||
tools: list[Tool], | ||
output_schema: AgentOutputSchemaBase | None, | ||
handoffs: list[Handoff], | ||
tracing: ModelTracing, | ||
*, | ||
previous_response_id: str | None, | ||
) -> ModelResponse: | ||
def get_summary(input: str | list[TResponseInputItem]) -> str: | ||
### Activity summary shown in the UI | ||
try: | ||
max_size = 100 | ||
if isinstance(input, str): | ||
return input[:max_size] | ||
elif isinstance(input, list): | ||
seq_input = cast(Sequence[Any], input) | ||
last_item = seq_input[-1] | ||
if isinstance(last_item, dict): | ||
return last_item.get("content", "")[:max_size] | ||
elif hasattr(last_item, "content"): | ||
return str(getattr(last_item, "content"))[:max_size] | ||
return str(last_item)[:max_size] | ||
elif isinstance(input, dict): | ||
return input.get("content", "")[:max_size] | ||
except Exception as e: | ||
print(f"Error getting summary: {e}") | ||
return "" | ||
|
||
def make_tool_info(tool: Tool) -> ToolInput: | ||
if isinstance(tool, FileSearchTool): | ||
return cast(FileSearchTool, tool) | ||
elif isinstance(tool, WebSearchTool): | ||
return cast(WebSearchTool, tool) | ||
elif isinstance(tool, ComputerTool): | ||
raise NotImplementedError( | ||
"Computer search preview is not supported in Temporal model" | ||
) | ||
elif isinstance(tool, FunctionTool): | ||
t = cast(FunctionToolInput, tool) | ||
return FunctionToolInput( | ||
name=t.name, | ||
description=t.description, | ||
params_json_schema=t.params_json_schema, | ||
strict_json_schema=t.strict_json_schema, | ||
) | ||
else: | ||
raise ValueError(f"Unknown tool type: {tool.name}") | ||
|
||
tool_infos = [make_tool_info(x) for x in tools] | ||
handoff_infos = [ | ||
HandoffInput( | ||
tool_name=x.tool_name, | ||
tool_description=x.tool_description, | ||
input_json_schema=x.input_json_schema, | ||
agent_name=x.agent_name, | ||
strict_json_schema=x.strict_json_schema, | ||
) | ||
for x in handoffs | ||
] | ||
if output_schema is not None and not isinstance( | ||
output_schema, AgentOutputSchema | ||
): | ||
raise TypeError( | ||
f"Only AgentOutputSchema is supported by Temporal Model, got {type(output_schema).__name__}" | ||
) | ||
agent_output_schema = cast(AgentOutputSchema, output_schema) | ||
output_schema_input = ( | ||
None | ||
if agent_output_schema is None | ||
else AgentOutputSchemaInput( | ||
output_type_name=agent_output_schema.name(), | ||
is_wrapped=agent_output_schema._is_wrapped, | ||
output_schema=agent_output_schema.json_schema() | ||
if not agent_output_schema.is_plain_text() | ||
else None, | ||
strict_json_schema=agent_output_schema.is_strict_json_schema(), | ||
) | ||
) | ||
|
||
activity_input = ActivityModelInput( | ||
model_name=self.model_name, | ||
system_instructions=system_instructions, | ||
input=input, | ||
model_settings=model_settings, | ||
tools=tool_infos, | ||
output_schema=output_schema_input, | ||
handoffs=handoff_infos, | ||
tracing=ModelTracingInput(tracing.value), | ||
previous_response_id=previous_response_id, | ||
) | ||
return await workflow.execute_activity( | ||
invoke_model_activity, | ||
activity_input, | ||
start_to_close_timeout=timedelta(seconds=60), | ||
heartbeat_timeout=timedelta(seconds=10), | ||
summary=get_summary(input), | ||
) | ||
|
||
def stream_response( | ||
self, | ||
system_instructions: str | None, | ||
input: str | list[TResponseInputItem], | ||
model_settings: ModelSettings, | ||
tools: list[Tool], | ||
output_schema: AgentOutputSchemaBase | None, | ||
handoffs: list[Handoff], | ||
tracing: ModelTracing, | ||
*, | ||
previous_response_id: str | None, | ||
) -> AsyncIterator[TResponseStreamEvent]: | ||
raise NotImplementedError("Temporal model doesn't support streams yet") |
This is all a bit suspicious to me. `cast` doesn't actually do anything except signal to the type checker that a value has a certain type. I don't think this would survive the serialization boundary.

The corresponding bit here https://github.com/temporalio/sdk-python/pull/896/files#diff-4d31f8abb21e7ac3886618347fb821f5a2679c41146ce9363c20166b913b0bc2R137 is similarly a bit funky. I think the `isinstance` checks after deserialization are made to work via the custom converter: https://github.com/temporalio/sdk-python/pull/896/files#diff-7307f8c289cf7271f54f43ba96cdd40af15e7aacf1989d6d27785dfbeb62ff5fR67

So maybe the casting is just to ensure the converter does the right thing, but it reads a bit funny.
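
To make the point about `cast` concrete, here is a small standalone sketch. `ToolA` and `ToolB` are hypothetical dataclasses, and a plain JSON round trip stands in for the serialization boundary; the PR's custom converter may well behave differently. It shows that `typing.cast` returns the very same object with its runtime type unchanged, and that a naive round trip yields a plain `dict`, so any later `isinstance` check depends entirely on what the converter reconstructs.

```python
import json
from dataclasses import asdict, dataclass
from typing import cast


@dataclass
class ToolA:
    name: str


@dataclass
class ToolB:
    name: str


tool = ToolA(name="search")

# cast() only informs the type checker; at runtime it returns its argument.
as_b = cast(ToolB, tool)
same_object = as_b is tool
still_tool_a = type(as_b) is ToolA

# A plain JSON round trip (stand-in for the activity boundary) drops the class.
round_tripped = json.loads(json.dumps(asdict(as_b)))
```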