-
Notifications
You must be signed in to change notification settings - Fork 2.3k
[draft] async text mode #4337
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[draft] async text mode #4337
Changes from all commits
1c7b33e
8eba4e9
de63ced
fedadd3
e0f387b
4e2b715
d8bf768
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,98 @@ | ||
| import logging | ||
|
|
||
| from dotenv import load_dotenv | ||
|
|
||
| from livekit.agents import ( | ||
| Agent, | ||
| AgentServer, | ||
| AgentSession, | ||
| JobContext, | ||
| RunContext, | ||
| TextMessageContext, | ||
| cli, | ||
| ) | ||
| from livekit.agents.llm import function_tool | ||
| from livekit.plugins import silero | ||
| from livekit.plugins.turn_detector.multilingual import MultilingualModel | ||
|
|
||
| # uncomment to enable Krisp background voice/noise cancellation | ||
| # from livekit.plugins import noise_cancellation | ||
|
|
||
| logger = logging.getLogger("basic-agent") | ||
|
|
||
| load_dotenv() | ||
|
|
||
|
|
||
class MyAgent(Agent):
    """Example agent persona ("Kelly") exposing a single sample weather tool.

    Args:
        greet_on_enter: When true, ``on_enter`` asks the session to generate
            a reply as soon as it runs. Pass ``False`` to suppress that
            unsolicited first reply.
    """

    def __init__(self, *, greet_on_enter: bool = True) -> None:
        super().__init__(
            # Adjacent string literals are joined verbatim, so every sentence
            # except the last ends with a space; without it the prompt reads
            # "...via voice.with that in mind...".
            instructions=(
                "Your name is Kelly. You would interact with users via voice. "
                "with that in mind keep your responses concise and to the point. "
                "do not use emojis, asterisks, markdown, or other special characters in your responses. "
                "You are curious and friendly, and have a sense of humor. "
                "you will speak english to the user"
            ),
        )
        self._greet_on_enter = greet_on_enter

    async def on_enter(self) -> None:
        """Generate an opening reply when greeting is enabled."""
        if self._greet_on_enter:
            logger.debug("greeting the user")
            # NOTE(review): presumably allow_interruptions=False keeps the
            # greeting from being cut off by the user - confirm against the
            # AgentSession.generate_reply documentation.
            self.session.generate_reply(allow_interruptions=False)

    # all functions annotated with @function_tool will be passed to the LLM when this
    # agent is active
    @function_tool
    async def lookup_weather(
        self, context: RunContext, location: str, latitude: str, longitude: str
    ) -> str:
        """Called when the user asks for weather related information.
        Ensure the user's location (city or region) is provided.
        When given a location, please estimate the latitude and longitude of the location and
        do not ask the user for them.
        Args:
            location: The location they are asking for
            latitude: The latitude of the location, do not ask user for it
            longitude: The longitude of the location, do not ask user for it
        """
        # The docstring above is kept verbatim on purpose: it reads as the
        # tool description addressed to the LLM, not as developer docs.

        logger.info(f"Looking up weather for {location}")

        # Placeholder result: the same forecast is returned for every input;
        # context, latitude and longitude are accepted but not used here.
        return "sunny with a temperature of 70 degrees."
|
|
||
|
|
||
| server = AgentServer() | ||
|
|
||
|
|
||
| @server.sms_handler() | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's position this as a text_handler instead of SMS. Getting text right would mean this is accessible on LK Cloud: |
||
| async def sms_handler(ctx: TextMessageContext): | ||
| logger.info(f"SMS received: {ctx.text}") | ||
|
|
||
| session = AgentSession(llm="openai/gpt-4.1-mini") | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can we let the user set up some configuration for the SMS handler here?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. And maybe: |
||
| if ctx.session_data: | ||
| await session.rehydrate(ctx.session_data) | ||
| else: | ||
| await session.start(agent=MyAgent(greet_on_enter=False)) | ||
|
|
||
| result = await session.run(user_input=ctx.text) | ||
|
|
||
| await ctx.send_result(result) | ||
| await ctx.save_session(session) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Q: how do you feel about automatically saving with a context manager? async with ctx.resume(session=session, session_data=session_data, agent=...):
...
so users don't have to call ctx.save_session explicitly?
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So for
It feels out of sync to me if we have to use
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Agreed, it's a bit odd to have this "out of sync" issue. I think the agent code shouldn't be sending SMS outside the sms_handler, though, just for separation of concerns. |
||
|
|
||
|
|
||
@server.rtc_session()
async def entrypoint(ctx: JobContext):
    """RTC session entry point: start a voice session for ``MyAgent`` in the job's room."""
    # Keyword arguments for the voice pipeline, listed in the order they are
    # evaluated and handed to AgentSession.
    pipeline_options = {
        "stt": "deepgram/nova-3",
        "llm": "openai/gpt-4.1-mini",
        "tts": "cartesia/sonic-2:9626c31c-bec5-4cca-baa8-f8ba9e84c8bc",
        "turn_detection": MultilingualModel(),
        "vad": silero.VAD.load(),
        "preemptive_generation": True,
    }
    session = AgentSession(**pipeline_options)

    # MyAgent() keeps its default greet_on_enter=True here.
    agent = MyAgent()
    await session.start(agent=agent, room=ctx.room)
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| cli.run_app(server) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,11 +13,12 @@ | |
| import re | ||
| import signal | ||
| import sys | ||
| import tempfile | ||
| import textwrap | ||
| import threading | ||
| import time | ||
| import traceback | ||
| from collections.abc import Iterator | ||
| from collections.abc import Awaitable, Iterator | ||
| from contextlib import contextmanager | ||
| from types import FrameType | ||
| from typing import TYPE_CHECKING, Annotated, Any, Callable, Literal, Optional, Union | ||
|
|
@@ -37,7 +38,7 @@ | |
| from livekit import rtc | ||
|
|
||
| from .._exceptions import CLIError | ||
| from ..job import JobExecutorType | ||
| from ..job import JobExecutorType, TextMessageContext | ||
| from ..log import logger | ||
| from ..plugin import Plugin | ||
| from ..utils import aio | ||
|
|
@@ -299,7 +300,7 @@ def __init__(self) -> None: | |
| self._console_directory, f"session-{datetime.datetime.now().strftime('%m-%d-%H%M%S')}" | ||
| ) | ||
|
|
||
| def acquire_io(self, *, loop: asyncio.AbstractEventLoop, session: AgentSession) -> None: | ||
| def acquire_io(self, *, loop: asyncio.AbstractEventLoop, session: AgentSession | None) -> None: | ||
| with self._lock: | ||
| if self._io_acquired: | ||
| raise RuntimeError("the ConsoleIO was already acquired by another session") | ||
|
|
@@ -317,9 +318,10 @@ def acquire_io(self, *, loop: asyncio.AbstractEventLoop, session: AgentSession) | |
| self._io_acquired_event.set() | ||
| self._io_session = session | ||
|
|
||
| self._update_sess_io( | ||
| session, self.console_mode, self._io_audio_input, self._io_audio_output | ||
| ) | ||
| if session: | ||
| self._update_sess_io( | ||
| session, self.console_mode, self._io_audio_input, self._io_audio_output | ||
| ) | ||
|
|
||
| @property | ||
| def enabled(self) -> bool: | ||
|
|
@@ -348,7 +350,7 @@ def io_acquired(self) -> bool: | |
|
|
||
| @property | ||
| def io_session(self) -> AgentSession: | ||
| if not self._io_acquired: | ||
| if not self._io_acquired or not self._io_session: | ||
| raise RuntimeError("AgentsConsole is not acquired") | ||
|
|
||
| return self._io_session | ||
|
|
@@ -985,8 +987,17 @@ def update(new_text: str | Text | None = None) -> None: | |
| yield update | ||
|
|
||
|
|
||
| def _text_mode(c: AgentsConsole) -> None: | ||
| def _text_mode( | ||
| c: AgentsConsole, | ||
| *, | ||
| sms_handler: Callable[[TextMessageContext], Awaitable[None]] | None = None, | ||
| sess_data_file: str | None = None, | ||
| ) -> None: | ||
| def _key_read(ch: str) -> None: | ||
| if sms_handler: | ||
| # sms console doesn't support toggling mode | ||
| return | ||
|
|
||
| if ch == key.CTRL_T: | ||
| raise _ToggleMode() | ||
|
|
||
|
|
@@ -1006,8 +1017,34 @@ def _key_read(ch: str) -> None: | |
|
|
||
| def _generate_with_context(text: str, result_fut: asyncio.Future[list[RunEvent]]) -> None: | ||
| async def _generate(text: str) -> list[RunEvent]: | ||
| sess = await c.io_session.run(user_input=text) # type: ignore | ||
| return sess.events.copy() | ||
| if sms_handler is not None: | ||
| # simulate a sms received event | ||
| assert sess_data_file | ||
|
|
||
| session_data: bytes | None = None | ||
| if os.path.isfile(sess_data_file): | ||
| with open(sess_data_file, "rb") as f: | ||
| session_data = f.read() | ||
|
|
||
| text_context = TextMessageContext(text=text, session_data=session_data) | ||
| await sms_handler(text_context) | ||
|
|
||
| # serialize the state of the session | ||
| if text_context.session_data: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. QQ: how about we handle this inside ctx.save_session? |
||
| with open(sess_data_file, "wb") as f: | ||
| f.write(text_context.session_data) | ||
| logger.debug( | ||
| "session state serialized", extra={"session_data_file": sess_data_file} | ||
| ) | ||
|
|
||
| result = text_context.result | ||
| if result is None: | ||
| logger.warning("result is not set from the sms handler") | ||
| return [] | ||
| else: | ||
| result = await c.io_session.run(user_input=text) | ||
|
|
||
| return result.events.copy() | ||
|
|
||
| def _done_callback(task: asyncio.Task[list[RunEvent]]) -> None: | ||
| if exception := task.exception(): | ||
|
|
@@ -1125,8 +1162,9 @@ def _listen_for_toggle() -> None: | |
|
|
||
|
|
||
| class _ConsoleWorker: | ||
| def __init__(self, *, server: AgentServer, shutdown_cb: Callable) -> None: | ||
| def __init__(self, *, server: AgentServer, shutdown_cb: Callable, sms_job: bool) -> None: | ||
| self._loop = asyncio.new_event_loop() | ||
| self._sms_job = sms_job | ||
| self._server = server | ||
| self._shutdown_cb = shutdown_cb | ||
| self._lock = threading.Lock() | ||
|
|
@@ -1158,7 +1196,10 @@ async def _async_main() -> None: | |
| def _simulate_job() -> None: | ||
| asyncio.run_coroutine_threadsafe( | ||
| self._server.simulate_job( | ||
| "console-room", agent_identity="console", fake_job=True | ||
| "console-room", | ||
| agent_identity="console", | ||
| fake_job=True, | ||
| sms_job=self._sms_job, | ||
| ), | ||
| self._loop, | ||
| ) | ||
|
|
@@ -1225,7 +1266,9 @@ def _handle_exit(sig: int, frame: FrameType | None) -> None: | |
| for sig in HANDLED_SIGNALS: | ||
| signal.signal(sig, _handle_exit) | ||
|
|
||
| console_worker = _ConsoleWorker(server=server, shutdown_cb=_on_worker_shutdown) | ||
| console_worker = _ConsoleWorker( | ||
| server=server, shutdown_cb=_on_worker_shutdown, sms_job=False | ||
| ) | ||
| console_worker.start() | ||
|
|
||
| # TODO: wait for a session request the agents console context before showing any of the mode | ||
|
|
@@ -1255,6 +1298,63 @@ def _handle_exit(sig: int, frame: FrameType | None) -> None: | |
| raise typer.Exit(code=1) from None | ||
|
|
||
|
|
||
def _run_sms_console(*, server: AgentServer, sess_data_file: str) -> None:
    """Run the text console against the server's registered SMS handler.

    Puts the shared ``AgentsConsole`` into text mode, installs exit signal
    handlers, starts a ``_ConsoleWorker`` with ``sms_job=True`` and then runs
    ``_text_mode`` until the CLI exits.

    Args:
        server: Agent server whose ``_sms_handler_fnc`` receives each message.
        sess_data_file: Path passed through to ``_text_mode`` as the session
            data file.

    Raises:
        ValueError: If the server has no SMS handler registered.
        typer.Exit: With code 1 when a ``CLIError`` is raised while running.
    """
    console = AgentsConsole.get_instance()
    console.console_mode = "text"
    console.enabled = True

    _configure_logger(console, logging.DEBUG)
    console.print("Starting SMS console mode 🚀", tag="Agents")
    console.print(" ")

    try:
        already_exiting = False

        def _on_worker_shutdown() -> None:
            # Raise SIGTERM in this process; if that raises, fall back to
            # SIGINT; if that raises too, give up silently.
            for fallback_signal in (signal.SIGTERM, signal.SIGINT):
                try:
                    signal.raise_signal(fallback_signal)
                except Exception:
                    continue
                break

        def _handle_exit(sig: int, frame: FrameType | None) -> None:
            nonlocal already_exiting
            if already_exiting:
                # Any signal after the first shuts the worker down directly.
                console_worker.shutdown()
                return

            # First signal: unwind to the `except _ExitCli` handler below.
            already_exiting = True
            raise _ExitCli()

        for handled_signal in HANDLED_SIGNALS:
            signal.signal(handled_signal, _handle_exit)

        sms_handler = server._sms_handler_fnc
        if not sms_handler:
            raise ValueError("sms_handler is required when simulating SMS")

        console_worker = _ConsoleWorker(
            server=server,
            shutdown_cb=_on_worker_shutdown,
            sms_job=True,
        )
        console_worker.start()

        try:
            console.wait_for_io_acquisition()
            _text_mode(console, sms_handler=sms_handler, sess_data_file=sess_data_file)
        except _ExitCli:
            pass
        finally:
            # Always stop the worker and wait for it, however text mode ended.
            console_worker.shutdown()
            console_worker.join()

    except CLIError as cli_error:
        console.print(" ")
        console.print(f"[error]{cli_error}")
        console.print(" ")
        raise typer.Exit(code=1) from None
|
|
||
|
|
||
| def _run_worker(server: AgentServer, args: proto.CliArgs, jupyter: bool = False) -> None: | ||
| c: AgentsConsole | None = None | ||
| if args.devmode: | ||
|
|
@@ -1393,6 +1493,24 @@ def console( | |
| record=record, | ||
| ) | ||
|
|
||
@app.command()
def sms_console(
    *,
    sess_data_file: Annotated[
        Optional[str],  # noqa: UP007
        typer.Option(help="Path to the serialized AgentSession data file in SMS mode"),
    ] = None,
) -> None:
    """Run the console in SMS mode, simulating incoming text messages."""
    if sess_data_file:
        # Caller-supplied file: use it as-is and leave it in place afterwards.
        _run_sms_console(server=server, sess_data_file=sess_data_file)
        return

    # No file given: keep the session data in a throwaway directory that is
    # removed when the console exits, including on error. The `delete=`
    # keyword of TemporaryDirectory is deliberately not used: it only exists
    # on Python 3.12+ and raises TypeError on older interpreters.
    with tempfile.TemporaryDirectory(prefix="lk_") as temp_dir:
        _run_sms_console(
            server=server,
            sess_data_file=os.path.join(temp_dir, "session_data.pkl"),
        )
|
|
||
| @app.command() | ||
| def start( | ||
| *, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we want to add a function called on_rehydrate?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is technically going to be the default Python `def __setstate__` method.