
wip #1046


Draft
wants to merge 1 commit into base: rm/pr1045


90 changes: 90 additions & 0 deletions examples/realtime/demo.py
@@ -0,0 +1,90 @@
import asyncio
import base64
import os
import sys
from typing import TYPE_CHECKING

import numpy as np

# Add the current directory to path so we can import ui
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from agents import function_tool
from agents.realtime import RealtimeAgent, RealtimeSession, RealtimeSessionEvent

if TYPE_CHECKING:
    from .ui import AppUI
else:
    # At runtime, try both import styles
    try:
        # Try relative import first (when used as a package)
        from .ui import AppUI
    except ImportError:
        # Fall back to direct import (when run as a script)
        from ui import AppUI


@function_tool
def get_weather(city: str) -> str:
    """Get the weather in a city."""
    return f"The weather in {city} is sunny."


agent = RealtimeAgent(
    name="Assistant",
    instructions="You always greet the user with 'Top of the morning to you'.",
    tools=[get_weather],
)


class Example:
    def __init__(self) -> None:
        self.session = RealtimeSession(agent)
        self.ui = AppUI()
        self.ui.connected = asyncio.Event()
        self.ui.last_audio_item_id = None
        # Set the audio callback
        self.ui.set_audio_callback(self.on_audio_recorded)

    async def run(self) -> None:
        self.session.add_listener(self.on_event)
        await self.session.connect()
        self.ui.set_is_connected(True)
        await self.ui.run_async()

    async def on_audio_recorded(self, audio_bytes: bytes) -> None:
        """Called when audio is recorded by the UI."""
        try:
            # Send the audio to the session
            await self.session.send_audio(audio_bytes)
        except Exception as e:
            self.ui.log_message(f"Error sending audio: {e}")

    async def on_event(self, event: RealtimeSessionEvent) -> None:
        # Display event in the UI
        try:
            if event.type == "raw_transport_event" and event.data.type == "other":
                # self.ui.log_message(f"{event.data}, {type(event.data.data)}")
                if event.data.data["type"] == "response.audio.delta":
                    self.ui.log_message("audio deltas")
                    delta_b64_string = event.data.data["delta"]
                    delta_bytes = base64.b64decode(delta_b64_string)
                    audio_data = np.frombuffer(delta_bytes, dtype=np.int16)
                    self.ui.play_audio(audio_data)

            # Handle audio from model
            if event.type == "audio":
                try:
                    # Convert bytes to numpy array for audio player
                    audio_data = np.frombuffer(event.audio.data, dtype=np.int16)
                    self.ui.play_audio(audio_data)
                except Exception as e:
                    self.ui.log_message(f"Audio play error: {e}")
        except Exception:
            # This can happen if the UI has already exited
            pass


if __name__ == "__main__":
    example = Example()
    asyncio.run(example.run())
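
For reference on the raw transport branch above: the handler assumes each response.audio.delta payload is base64-encoded 16-bit PCM in the same 24 kHz mono format that ui.py opens its output stream with. A minimal decode sketch under that assumption (the helper name is illustrative, not part of this PR):

import base64

import numpy as np
import numpy.typing as npt


def decode_audio_delta(delta_b64: str) -> npt.NDArray[np.int16]:
    """Decode one response.audio.delta payload into int16 PCM samples."""
    # Base64 string -> raw bytes -> 16-bit samples (mono, 24 kHz assumed).
    pcm_bytes = base64.b64decode(delta_b64)
    return np.frombuffer(pcm_bytes, dtype=np.int16)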
221 changes: 221 additions & 0 deletions examples/realtime/ui.py
@@ -0,0 +1,221 @@
from __future__ import annotations

import asyncio
from collections.abc import Coroutine
from typing import Any, Callable

import numpy as np
import numpy.typing as npt
import sounddevice as sd
from textual import events
from textual.app import App, ComposeResult
from textual.containers import Container
from textual.reactive import reactive
from textual.widgets import RichLog, Static
from typing_extensions import override

CHUNK_LENGTH_S = 0.05  # 50ms
SAMPLE_RATE = 24000
FORMAT = np.int16
CHANNELS = 1


class Header(Static):
    """A header widget."""

    @override
    def render(self) -> str:
        return "Realtime Demo"


class AudioStatusIndicator(Static):
    """A widget that shows the current audio recording status."""

    is_recording = reactive(False)

    @override
    def render(self) -> str:
        status = (
            "🔴 Conversation started."
            if self.is_recording
            else "⚪ Press SPACE to start the conversation (q to quit)"
        )
        return status


class AppUI(App[None]):
    CSS = """
    Screen {
        background: #1a1b26; /* Dark blue-grey background */
    }

    Container {
        border: double rgb(91, 164, 91);
    }

    #input-container {
        height: 5; /* Explicit height for input container */
        margin: 1 1;
        padding: 1 2;
    }

    #bottom-pane {
        width: 100%;
        height: 82%; /* Reduced to make room for session display */
        border: round rgb(205, 133, 63);
        content-align: center middle;
    }

    #status-indicator {
        height: 3;
        content-align: center middle;
        background: #2a2b36;
        border: solid rgb(91, 164, 91);
        margin: 1 1;
    }

    #session-display {
        height: 3;
        content-align: center middle;
        background: #2a2b36;
        border: solid rgb(91, 164, 91);
        margin: 1 1;
    }

    Static {
        color: white;
    }
    """

    should_send_audio: asyncio.Event
    connected: asyncio.Event
    last_audio_item_id: str | None
    audio_callback: Callable[[bytes], Coroutine[Any, Any, None]] | None

    def __init__(self) -> None:
        super().__init__()
        self.audio_player = sd.OutputStream(
            samplerate=SAMPLE_RATE,
            channels=CHANNELS,
            dtype=FORMAT,
        )
        self.should_send_audio = asyncio.Event()
        self.connected = asyncio.Event()
        self.audio_callback = None

    @override
    def compose(self) -> ComposeResult:
        """Create child widgets for the app."""
        with Container():
            yield Header(id="session-display")
            yield AudioStatusIndicator(id="status-indicator")
            yield RichLog(id="bottom-pane", wrap=True, highlight=True, markup=True)

    def set_is_connected(self, is_connected: bool) -> None:
        self.connected.set() if is_connected else self.connected.clear()

    def set_audio_callback(self, callback: Callable[[bytes], Coroutine[Any, Any, None]]) -> None:
        """Set a callback function to be called when audio is recorded."""
        self.audio_callback = callback

    # High-level methods for UI operations
    def set_header_text(self, text: str) -> None:
        """Update the header text."""
        header = self.query_one("#session-display", Header)
        header.update(text)

    def set_recording_status(self, is_recording: bool) -> None:
        """Set the recording status indicator."""
        status_indicator = self.query_one(AudioStatusIndicator)
        status_indicator.is_recording = is_recording

    def log_message(self, message: str) -> None:
        """Add a message to the log pane."""
        try:
            bottom_pane = self.query_one("#bottom-pane", RichLog)
            bottom_pane.write(message)
        except Exception:
            # Handle the case where the widget might not be available
            pass

    def play_audio(self, audio_data: npt.NDArray[np.int16]) -> None:
        """Play audio data through the audio player."""
        try:
            self.audio_player.write(audio_data)
        except Exception as e:
            self.log_message(f"Audio play error: {e}")

    async def on_mount(self) -> None:
        """Set up audio player and start the audio capture worker."""
        self.audio_player.start()
        self.run_worker(self.capture_audio())

    async def capture_audio(self) -> None:
        """Capture audio from the microphone and send to the session."""
        # Wait for connection to be established
        await self.connected.wait()

        self.log_message("Connected to agent. Press space to start the conversation")

        # Set up audio input stream
        stream = sd.InputStream(
            channels=CHANNELS,
            samplerate=SAMPLE_RATE,
            dtype=FORMAT,
        )

        try:
            # Wait for user to press spacebar to start
            await self.should_send_audio.wait()

            stream.start()
            self.set_recording_status(True)
            self.log_message("Recording started - speak to the agent")

            # Buffer size in samples
            read_size = int(SAMPLE_RATE * CHUNK_LENGTH_S)

            while True:
                # Check if there's enough data to read
                if stream.read_available < read_size:
                    await asyncio.sleep(0.01)  # Small sleep to avoid CPU hogging
                    continue

                # Read audio data
                data, _ = stream.read(read_size)

                # Convert numpy array to bytes
                audio_bytes = data.tobytes()

                # Call audio callback if set
                if self.audio_callback:
                    try:
                        await self.audio_callback(audio_bytes)
                    except Exception as e:
                        self.log_message(f"Audio callback error: {e}")

                # Yield control back to event loop
                await asyncio.sleep(0)

        except Exception as e:
            self.log_message(f"Audio capture error: {e}")
        finally:
            if stream.active:
                stream.stop()
            stream.close()

    async def on_key(self, event: events.Key) -> None:
        """Handle key press events."""
        # add the keypress to the log
        self.log_message(f"Key pressed: {event.key}")

        if event.key == "q":
            self.audio_player.stop()
            self.audio_player.close()
            self.exit()
            return

        if event.key == "space":  # Spacebar
            if not self.should_send_audio.is_set():
                self.should_send_audio.set()
                self.set_recording_status(True)
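
As a sanity check on the capture loop above: at 24 kHz mono int16, each 50 ms chunk is 1,200 samples, i.e. 2,400 bytes handed to the audio callback per read. A small sketch of that arithmetic, using the constants from the top of ui.py:

SAMPLE_RATE = 24000      # Hz, mono (matches ui.py)
CHUNK_LENGTH_S = 0.05    # 50 ms per chunk (matches ui.py)
BYTES_PER_SAMPLE = 2     # int16

read_size = int(SAMPLE_RATE * CHUNK_LENGTH_S)   # 1200 samples per read
chunk_bytes = read_size * BYTES_PER_SAMPLE      # 2400 bytes passed to audio_callback
print(read_size, chunk_bytes)                   # 1200 2400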
54 changes: 27 additions & 27 deletions src/agents/agent.py
@@ -94,6 +94,33 @@ class AgentBase:
    mcp_config: MCPConfig = field(default_factory=lambda: MCPConfig())
    """Configuration for MCP servers."""

    async def get_mcp_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]:
        """Fetches the available tools from the MCP servers."""
        convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
        return await MCPUtil.get_all_function_tools(
            self.mcp_servers, convert_schemas_to_strict, run_context, self
        )

    async def get_all_tools(self, run_context: RunContextWrapper[Any]) -> list[Tool]:
        """All agent tools, including MCP tools and function tools."""
        mcp_tools = await self.get_mcp_tools(run_context)

        async def _check_tool_enabled(tool: Tool) -> bool:
            if not isinstance(tool, FunctionTool):
                return True

            attr = tool.is_enabled
            if isinstance(attr, bool):
                return attr
            res = attr(run_context, self)
            if inspect.isawaitable(res):
                return bool(await res)
            return bool(res)

        results = await asyncio.gather(*(_check_tool_enabled(t) for t in self.tools))
        enabled: list[Tool] = [t for t, ok in zip(self.tools, results) if ok]
        return [*mcp_tools, *enabled]


@dataclass
class Agent(AgentBase, Generic[TContext]):
@@ -262,30 +289,3 @@ async def get_prompt(
    ) -> ResponsePromptParam | None:
        """Get the prompt for the agent."""
        return await PromptUtil.to_model_input(self.prompt, run_context, self)

    async def get_mcp_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]:
        """Fetches the available tools from the MCP servers."""
        convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
        return await MCPUtil.get_all_function_tools(
            self.mcp_servers, convert_schemas_to_strict, run_context, self
        )

    async def get_all_tools(self, run_context: RunContextWrapper[Any]) -> list[Tool]:
        """All agent tools, including MCP tools and function tools."""
        mcp_tools = await self.get_mcp_tools(run_context)

        async def _check_tool_enabled(tool: Tool) -> bool:
            if not isinstance(tool, FunctionTool):
                return True

            attr = tool.is_enabled
            if isinstance(attr, bool):
                return attr
            res = attr(run_context, self)
            if inspect.isawaitable(res):
                return bool(await res)
            return bool(res)

        results = await asyncio.gather(*(_check_tool_enabled(t) for t in self.tools))
        enabled: list[Tool] = [t for t, ok in zip(self.tools, results) if ok]
        return [*mcp_tools, *enabled]
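
For context on the relocated get_all_tools logic: FunctionTool.is_enabled may be a plain bool or a (run_context, agent) callable, sync or async, and tools whose check comes back falsy are dropped before MCP tools and function tools are combined. A minimal sketch of a conditionally enabled tool, assuming function_tool accepts an is_enabled argument and forwards it to the FunctionTool it builds (that wiring is not shown in this diff):

from agents import function_tool


def only_for_admins(ctx, agent) -> bool:
    # Hypothetical gate matching the (run_context, self) call above:
    # enable the tool only when the run context carries an admin flag.
    return bool(getattr(ctx.context, "is_admin", False))


@function_tool(is_enabled=only_for_admins)  # assumption: is_enabled kwarg reaches FunctionTool
def purge_cache() -> str:
    """Admin-only maintenance tool."""
    return "cache purged"


# get_all_tools(run_context) would then include purge_cache only when only_for_admins(...) is truthy.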