Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 90 additions & 9 deletions temporalio/contrib/openai_agents/temporal_openai_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import json
from contextlib import contextmanager
from datetime import timedelta
from typing import Any, AsyncIterator, Callable, Optional, Union, overload
from typing import Any, AsyncIterator, Callable, Optional, Type, Union

import nexusrpc
from agents import (
Agent,
AgentOutputSchemaBase,
Handoff,
Model,
Expand All @@ -19,20 +19,14 @@
TResponseInputItem,
set_trace_provider,
)
from agents.function_schema import DocstringStyle, function_schema
from agents.function_schema import function_schema
from agents.items import TResponseStreamEvent
from agents.run import get_default_agent_runner, set_default_agent_runner
from agents.tool import (
FunctionTool,
ToolErrorFunction,
ToolFunction,
ToolParams,
default_tool_error_function,
function_tool,
)
from agents.tracing import get_trace_provider
from agents.tracing.provider import DefaultTraceProvider
from agents.util._types import MaybeAwaitable
from openai.types.responses import ResponsePromptParam

from temporalio import activity
Expand Down Expand Up @@ -266,3 +260,90 @@ async def run_activity(ctx: RunContextWrapper[Any], input: str) -> Any:
on_invoke_tool=run_activity,
strict_json_schema=True,
)

@classmethod
def nexus_operation_as_tool(
cls,
operation: nexusrpc.Operation[Any, Any],
*,
service: Type[Any],
endpoint: str,
schedule_to_close_timeout: Optional[timedelta] = None,
) -> Tool:
"""Convert a Nexus operation into an OpenAI agent tool.

.. warning::
This API is experimental and may change in future versions.
Use with caution in production environments.

This function takes a Nexus operation and converts it into an
OpenAI agent tool that can be used by the agent to execute the operation
during workflow execution. The tool will automatically handle the conversion
of inputs and outputs between the agent and the operation: the tool input is
parsed as JSON and the operation result is returned as ``str(result)``.

The agent's run context is not forwarded to the Nexus operation; the
operation receives only its single input argument.

Args:
operation: The Nexus operation to convert into a tool.
service: The Nexus service class that contains the operation.
endpoint: The Nexus endpoint to use for the operation.
schedule_to_close_timeout: Optional timeout passed through as
``schedule_to_close_timeout`` when the operation is executed.
Defaults to ``None``.

Returns:
An OpenAI agent tool that wraps the provided operation.

Raises:
ApplicationError: Raised by the returned tool, when it is invoked with
input that is not valid JSON.
ToolSerializationError: Raised by the returned tool, when the operation
result cannot be converted with ``str()``.

Example:
>>> @nexusrpc.service
... class WeatherService:
... get_weather_object_nexus_operation: nexusrpc.Operation[WeatherInput, Weather]
>>>
>>> # Create a tool backed by the Nexus operation
>>> tool = openai_agents.workflow.nexus_operation_as_tool(
... WeatherService.get_weather_object_nexus_operation,
... service=WeatherService,
... endpoint="weather-service",
... )
>>> # Use tool with an OpenAI agent
"""

# Placeholder that is never meant to be called: it only carries a signature
# from which function_schema() can derive the tool schema.
# NOTE(review): intentionally documented with comments rather than a docstring;
# presumably function_schema() would pick a docstring up as the tool
# description -- confirm before adding one.
def operation_callable(input):
raise NotImplementedError("This function definition is used as a type only")

# Give the placeholder the operation's declared input and output types so the
# derived schema describes a single "input" parameter of the operation's type.
operation_callable.__annotations__ = {
"input": operation.input_type,
"return": operation.output_type,
}
# Rename the placeholder after the operation; schema.name is what the tool is
# registered under further down (presumably taken from __name__ -- confirm).
operation_callable.__name__ = operation.name

schema = function_schema(operation_callable)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We support activities having access to an agent context as the first param, but I am concerned that the single-parameter limitation of Nexus operations forbids having context as the first param and input as the second. Suggestions? I think we may just have to document for now that tools backed by Nexus operations can't have access to context at this time.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that is ok since we provide nexus_operation_as_tool as a convenience method. You should still be able to define a regular function tool, grab what you need from the context, and invoke the Nexus operation from there.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, the behavior is OK; it's just that the optics are confusing/inconsistent. "So, in-workflow tools can have mutable context, activity tools can have immutable context, and Nexus tools can have no context?" I think we need to at least document this.


async def run_operation(ctx: RunContextWrapper[Any], input: str) -> Any:
    """Invoke the wrapped Nexus operation on behalf of the agent.

    ``input`` is the JSON string of tool arguments. It is decoded, validated
    against the derived schema, and the single resulting argument is sent to
    the Nexus operation. ``ctx`` is accepted to satisfy the tool callback
    signature but is not forwarded to the operation.

    Returns:
        ``str(result)`` of the operation result.

    Raises:
        ApplicationError: If ``input`` cannot be decoded as JSON.
        ToolSerializationError: If calling ``str()`` on the result fails.
    """
    try:
        decoded = json.loads(input)
    except Exception as err:
        raise ApplicationError(
            f"Invalid JSON input for tool {schema.name}: {input}"
        ) from err

    # The client is created before the arguments are validated, matching the
    # established order of operations.
    client = temporal_workflow.create_nexus_client(
        service=service, endpoint=endpoint
    )

    validated = schema.params_pydantic_model(**decoded)
    positional, _ = schema.to_call_args(validated)
    # Invariant: the schema was derived from a one-parameter placeholder.
    assert len(positional) == 1, "Nexus operations must have exactly one argument"
    (operation_input,) = positional

    outcome = await client.execute_operation(
        operation,
        operation_input,
        schedule_to_close_timeout=schedule_to_close_timeout,
    )

    try:
        rendered = str(outcome)
    except Exception as err:
        raise ToolSerializationError(
            "You must return a string representation of the tool output, or something we can call str() on"
        ) from err
    return rendered

# Expose the operation as a FunctionTool: the name, description and JSON
# parameter schema all come from the schema derived above, and invocations
# are routed to run_operation. A missing description becomes "".
return FunctionTool(
name=schema.name,
description=schema.description or "",
params_json_schema=schema.params_json_schema,
on_invoke_tool=run_operation,
strict_json_schema=True,
)
154 changes: 154 additions & 0 deletions tests/contrib/openai_agents/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from datetime import timedelta
from typing import Any, Optional, Union, no_type_check

import nexusrpc
import pytest
from agents import (
Agent,
Expand Down Expand Up @@ -59,10 +60,12 @@
)
from temporalio.contrib.pydantic import pydantic_data_converter
from temporalio.exceptions import CancelledError
from temporalio.testing import WorkflowEnvironment
from tests.contrib.openai_agents.research_agents.research_manager import (
ResearchManager,
)
from tests.helpers import new_worker
from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name

response_index: int = 0

Expand Down Expand Up @@ -192,6 +195,22 @@ async def get_weather_context(ctx: RunContextWrapper[str], city: str) -> Weather
return Weather(city=city, temperature_range="14-20C", conditions=ctx.context)


@nexusrpc.service
class WeatherService:
# Nexus service contract used by the Nexus tool tests: a single operation
# that takes a WeatherInput and returns a Weather.
get_weather_nexus_operation: nexusrpc.Operation[WeatherInput, Weather]


@nexusrpc.handler.service_handler(service=WeatherService)
class WeatherServiceHandler:
    # Test handler for WeatherService: answers every request with the same
    # canned forecast, echoing back only the requested city.

    @nexusrpc.handler.sync_operation
    async def get_weather_nexus_operation(
        self, ctx: nexusrpc.handler.StartOperationContext, input: WeatherInput
    ) -> Weather:
        # Only ``input.city`` varies; the range and conditions are fixed so
        # the test can look for "Sunny with wind" in the operation result.
        forecast = Weather(
            city=input.city,
            temperature_range="14-20C",
            conditions="Sunny with wind.",
        )
        return forecast


class TestWeatherModel(StaticTestModel):
responses = [
ModelResponse(
Expand Down Expand Up @@ -272,6 +291,44 @@ class TestWeatherModel(StaticTestModel):
]


class TestNexusWeatherModel(StaticTestModel):
# Canned model output for the Nexus tool test (presumably served in list
# order by StaticTestModel -- confirm against its implementation).
responses = [
# First reply: a function call to the tool named after the Nexus operation,
# with the arguments nested under the single "input" parameter.
ModelResponse(
output=[
ResponseFunctionToolCall(
arguments='{"input":{"city":"Tokyo"}}',
call_id="call",
name="get_weather_nexus_operation",
type="function_call",
id="id",
status="completed",
)
],
usage=Usage(),
response_id=None,
),
# Second reply: the final assistant message whose text the test asserts on.
ModelResponse(
output=[
ResponseOutputMessage(
id="",
content=[
ResponseOutputText(
text="Test nexus weather result",
annotations=[],
type="output_text",
)
],
role="assistant",
status="completed",
type="message",
)
],
usage=Usage(),
response_id=None,
),
]


@workflow.defn
class ToolsWorkflow:
@workflow.run
Expand Down Expand Up @@ -300,6 +357,28 @@ async def run(self, question: str) -> str:
return result.final_output


@workflow.defn
class NexusToolsWorkflow:
    # Workflow under test: runs an agent whose only tool is backed by the
    # WeatherService Nexus operation.

    @workflow.run
    async def run(self, question: str) -> str:
        # The endpoint name is derived from this workflow's own task queue.
        endpoint_name = make_nexus_endpoint_name(workflow.info().task_queue)
        weather_tool = openai_agents.workflow.nexus_operation_as_tool(
            WeatherService.get_weather_nexus_operation,
            service=WeatherService,
            endpoint=endpoint_name,
            schedule_to_close_timeout=timedelta(seconds=10),
        )

        agent = Agent(
            name="Nexus Tools Workflow",
            instructions="You are a helpful agent.",
            tools=[weather_tool],
        )  # type: Agent

        run_result = await Runner.run(
            starting_agent=agent, input=question, context="Stormy"
        )
        return run_result.final_output


@pytest.mark.parametrize("use_local_model", [True, False])
async def test_tool_workflow(client: Client, use_local_model: bool):
if not use_local_model and not os.environ.get("OPENAI_API_KEY"):
Expand Down Expand Up @@ -404,6 +483,81 @@ async def test_tool_workflow(client: Client, use_local_model: bool):
)


# End-to-end test of an agent tool backed by a Nexus operation: starts
# NexusToolsWorkflow on a worker that also hosts WeatherServiceHandler and,
# for the local-model variant, checks the result and the recorded history.
@pytest.mark.parametrize("use_local_model", [True, False])
async def test_nexus_tool_workflow(
client: Client, env: WorkflowEnvironment, use_local_model: bool
):
# The non-local variant needs an OpenAI API key in the environment.
if not use_local_model and not os.environ.get("OPENAI_API_KEY"):
pytest.skip("No openai API key")

if env.supports_time_skipping:
pytest.skip("Nexus tests don't work with time-skipping server")

# Rebuild the client from its own config with the pydantic data converter.
new_config = client.config()
new_config["data_converter"] = pydantic_data_converter
client = Client(**new_config)

model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=30))
with set_open_ai_agent_temporal_overrides(model_params):
# Local variant: serve the canned TestNexusWeatherModel responses.
# Otherwise ModelActivity receives None (presumably falling back to the
# default model provider -- confirm against ModelActivity).
model_activity = ModelActivity(
TestModelProvider(
TestNexusWeatherModel( # type: ignore
)
)
if use_local_model
else None
)
# One worker hosts the workflow, the model activity and the Nexus handler.
async with new_worker(
client,
NexusToolsWorkflow,
activities=[
model_activity.invoke_model_activity,
],
nexus_service_handlers=[WeatherServiceHandler()],
interceptors=[OpenAIAgentsTracingInterceptor()],
) as worker:
# Create the Nexus endpoint for this worker's task queue before the
# workflow starts; the workflow derives the same endpoint name from its
# task queue.
await create_nexus_endpoint(worker.task_queue, client)

workflow_handle = await client.start_workflow(
NexusToolsWorkflow.run,
"What is the weather in Tokio?",
id=f"nexus-tools-workflow-{uuid.uuid4()}",
task_queue=worker.task_queue,
execution_timeout=timedelta(seconds=30),
)
result = await workflow_handle.result()

# Exact-output checks only make sense with the canned local model.
if use_local_model:
assert result == "Test nexus weather result"

# Keep only completed-activity and completed-Nexus-operation events.
events = []
async for e in workflow_handle.fetch_history_events():
if e.HasField(
"activity_task_completed_event_attributes"
) or e.HasField("nexus_operation_completed_event_attributes"):
events.append(e)

# Expected order: model activity returning the function call, the Nexus
# operation returning the handler's forecast, then the model activity
# returning the final message.
assert len(events) == 3
assert (
"function_call"
in events[0]
.activity_task_completed_event_attributes.result.payloads[0]
.data.decode()
)
assert (
"Sunny with wind"
in events[
1
].nexus_operation_completed_event_attributes.result.data.decode()
)
assert (
"Test nexus weather result"
in events[2]
.activity_task_completed_event_attributes.result.payloads[0]
.data.decode()
)


@no_type_check
class TestResearchModel(StaticTestModel):
responses = [
Expand Down
Loading