Merged
5 changes: 5 additions & 0 deletions .changeset/kind-parrots-help.md
@@ -0,0 +1,5 @@
---
"livekit-plugins-google": patch
---

Gemini Realtime: transcribe model audio via the Gemini API and use the latest model as the default for the Google plugin
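
For reviewers, a minimal usage sketch of the new defaults follows. The import path and the `RealtimeModel` class name are assumptions not shown in this diff; the keyword arguments come from the `__init__` signature changed below.

```python
# Hypothetical usage sketch: the import path and class name are assumptions;
# only the keyword arguments below appear in this PR's __init__ signature.
from livekit.plugins.google.beta.realtime import RealtimeModel

model = RealtimeModel(
    instructions="You are a helpful voice assistant.",
    model="gemini-2.0-flash-001",           # new default model
    voice="Puck",
    modalities=["AUDIO"],
    enable_user_audio_transcription=True,   # user speech -> TranscriberSession
    enable_agent_audio_transcription=True,  # agent audio -> ModelTranscriber (added here)
)
```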
@@ -6,7 +6,7 @@

from ..._utils import _build_gemini_ctx, _build_tools

LiveAPIModels = Literal["gemini-2.0-flash-exp"]
LiveAPIModels = Literal["gemini-2.0-flash-001",]

Voice = Literal["Puck", "Charon", "Kore", "Fenrir", "Aoede"]

@@ -37,7 +37,7 @@
_build_gemini_ctx,
_build_tools,
)
from .transcriber import TranscriberSession, TranscriptionContent
from .transcriber import ModelTranscriber, TranscriberSession, TranscriptionContent

EventTypes = Literal[
"start_session",
@@ -104,7 +104,7 @@ def __init__(
self,
*,
instructions: str | None = None,
model: LiveAPIModels | str = "gemini-2.0-flash-exp",
model: LiveAPIModels | str = "gemini-2.0-flash-001",
api_key: str | None = None,
voice: Voice | str = "Puck",
modalities: list[Modality] = ["AUDIO"],
@@ -136,7 +136,7 @@ def __init__(
instructions (str, optional): Initial system instructions for the model. Defaults to "".
api_key (str or None, optional): Google Gemini API key. If None, will attempt to read from the environment variable GOOGLE_API_KEY.
modalities (list[Modality], optional): Modalities to use, such as ["TEXT", "AUDIO"]. Defaults to ["AUDIO"].
model (str or None, optional): The name of the model to use. Defaults to "gemini-2.0-flash-exp".
model (str or None, optional): The name of the model to use. Defaults to "gemini-2.0-flash-001".
voice (api_proto.Voice, optional): Voice setting for audio outputs. Defaults to "Puck".
enable_user_audio_transcription (bool, optional): Whether to enable user audio transcription. Defaults to True
enable_agent_audio_transcription (bool, optional): Whether to enable agent audio transcription. Defaults to True
@@ -301,7 +301,7 @@ def __init__(
)
self._transcriber.on("input_speech_done", self._on_input_speech_done)
if self._opts.enable_agent_audio_transcription:
self._agent_transcriber = TranscriberSession(
self._agent_transcriber = ModelTranscriber(
client=self._client, model=self._opts.model
)
self._agent_transcriber.on("input_speech_done", self._on_agent_speech_done)
@@ -382,7 +382,7 @@ def _on_input_speech_done(self, content: TranscriptionContent) -> None:
# TODO: implement sync mechanism to make sure the transcribed user speech is inside the chat_ctx and always before the generated agent speech

def _on_agent_speech_done(self, content: TranscriptionContent) -> None:
Member:
when interrupted, are we only transcribing until the moment of interruption?

@jayeshp19 (Contributor Author), Feb 6, 2025:
No, the current implementation transcribes the entire text (all the frames which are received before interruption). It's hard to determine the exact point of interruption since we receive frames faster than the actual playback.

Member:
got it. I think that's fine.. in the v1 branch, the synchronization/truncation logic will be downstream from the model.. model should just produce the entire thing.

if not self._is_interrupted and content.response_id and content.text:
if content.response_id and content.text:
self.emit(
"agent_speech_transcription_completed",
InputTranscription(
@@ -439,10 +439,12 @@ async def _recv_task():
// 2,
)
if self._opts.enable_agent_audio_transcription:
self._agent_transcriber._push_audio(frame)
content.audio.append(frame)
content.audio_stream.send_nowait(frame)

if server_content.interrupted or server_content.turn_complete:
if self._opts.enable_agent_audio_transcription:
self._agent_transcriber._push_audio(content.audio)
for stream in (content.text_stream, content.audio_stream):
if isinstance(stream, utils.aio.Chan):
stream.close()
@@ -7,32 +7,28 @@

import websockets
from livekit import rtc
from livekit.agents import utils
from livekit.agents import APIConnectionError, APIStatusError, utils

from google import genai
from google.genai import types
from google.genai.errors import APIError, ClientError, ServerError

from ...log import logger
from .api_proto import ClientEvents, LiveAPIModels

EventTypes = Literal[
"input_speech_started",
"input_speech_done",
]
EventTypes = Literal["input_speech_started", "input_speech_done"]

DEFAULT_LANGUAGE = "English"

SYSTEM_INSTRUCTIONS = f"""
You are an **Audio Transcriber**. Your task is to convert audio content into accurate and precise text.

- Transcribe verbatim; exclude non-speech sounds.
- Provide only transcription; no extra text or explanations.
- If audio is unclear, respond with: `...`
- Ensure error-free transcription, preserving meaning and context.
- Use proper punctuation and formatting.
- Do not add explanations, comments, or extra information.
- Do not include timestamps, speaker labels, or annotations unless specified.

- Audio Language: {DEFAULT_LANGUAGE}
"""

@@ -44,30 +40,24 @@ class TranscriptionContent:


class TranscriberSession(utils.EventEmitter[EventTypes]):
def __init__(
self,
*,
client: genai.Client,
model: LiveAPIModels | str,
):
"""
Initializes a TranscriberSession instance for interacting with Google's Realtime API.
"""
"""
Handles live audio transcription using the realtime API.
"""

def __init__(self, *, client: genai.Client, model: LiveAPIModels | str):
super().__init__()
self._client = client
self._model = model
self._needed_sr = 16000
self._closed = False

system_instructions = types.Content(
parts=[types.Part(text=SYSTEM_INSTRUCTIONS)]
)

self._config = types.LiveConnectConfig(
response_modalities=["TEXT"],
system_instruction=system_instructions,
generation_config=types.GenerationConfig(
temperature=0.0,
),
generation_config=types.GenerationConfig(temperature=0.0),
)
self._main_atask = asyncio.create_task(
self._main_task(), name="gemini-realtime-transcriber"
@@ -187,6 +177,93 @@ async def _recv_task():
await self._session.close()


class ModelTranscriber(utils.EventEmitter[EventTypes]):
"""
Transcribes agent audio using model generation.
"""

def __init__(self, *, client: genai.Client, model: LiveAPIModels | str):
super().__init__()
self._client = client
self._model = model
self._needed_sr = 16000
self._system_instructions = types.Content(
parts=[types.Part(text=SYSTEM_INSTRUCTIONS)]
)
self._config = types.GenerateContentConfig(
temperature=0.0,
system_instruction=self._system_instructions,
# TODO: add response_schema
)
self._resampler: rtc.AudioResampler | None = None
self._buffer: rtc.AudioFrame | None = None
self._audio_ch = utils.aio.Chan[rtc.AudioFrame]()
self._main_atask = asyncio.create_task(
self._main_task(), name="gemini-model-transcriber"
)

async def aclose(self) -> None:
if self._audio_ch.closed:
return
self._audio_ch.close()
await self._main_atask

def _push_audio(self, frames: list[rtc.AudioFrame]) -> None:
if not frames:
return

buffer = utils.merge_frames(frames)

if buffer.sample_rate != self._needed_sr:
if self._resampler is None:
self._resampler = rtc.AudioResampler(
input_rate=buffer.sample_rate,
output_rate=self._needed_sr,
quality=rtc.AudioResamplerQuality.HIGH,
)

buffer = utils.merge_frames(self._resampler.push(buffer))

self._audio_ch.send_nowait(buffer)

@utils.log_exceptions(logger=logger)
async def _main_task(self):
request_id = utils.shortuuid()
try:
async for buffer in self._audio_ch:
# TODO: stream content for better latency
response = await self._client.aio.models.generate_content(
model=self._model,
contents=[
types.Content(
parts=[
types.Part(text=SYSTEM_INSTRUCTIONS),
types.Part.from_bytes(
data=buffer.to_wav_bytes(),
mime_type="audio/wav",
),
],
role="user",
)
],
config=self._config,
)
content = TranscriptionContent(
response_id=request_id, text=clean_transcription(response.text)
)
self.emit("input_speech_done", content)

except (ClientError, ServerError, APIError) as e:
raise APIStatusError(
f"model transcriber error: {e}",
status_code=e.code,
body=e.message,
request_id=request_id,
) from e
except Exception as e:
raise APIConnectionError("Error generating transcription") from e


def clean_transcription(text: str) -> str:
text = text.replace("\n", " ")
text = re.sub(r"\s+", " ", text)
@@ -60,7 +60,7 @@ class LLM(llm.LLM):
def __init__(
self,
*,
model: ChatModels | str = "gemini-2.0-flash-exp",
model: ChatModels | str = "gemini-2.0-flash-001",
api_key: str | None = None,
vertexai: bool = False,
project: str | None = None,
@@ -85,7 +85,7 @@ def __init__(
- For Google Gemini API: Set the `api_key` argument or the `GOOGLE_API_KEY` environment variable.

Args:
model (ChatModels | str, optional): The model name to use. Defaults to "gemini-2.0-flash-exp".
model (ChatModels | str, optional): The model name to use. Defaults to "gemini-2.0-flash-001".
api_key (str, optional): The API key for Google Gemini. If not provided, it attempts to read from the `GOOGLE_API_KEY` environment variable.
vertexai (bool, optional): Whether to use VertexAI. Defaults to False.
project (str, optional): The Google Cloud project to use (only for VertexAI). Defaults to None.
@@ -94,4 +94,9 @@

AudioEncoding = Literal["wav", "mp3", "ogg", "mulaw", "alaw", "linear16"]

ChatModels = Literal["gemini-2.0-flash-exp", "gemini-1.5-pro"]
ChatModels = Literal[
"gemini-2.0-flash-001",
"gemini-2.0-flash-lite-preview-02-05",
"gemini-2.0-pro-exp-02-05",
"gemini-1.5-pro",
]
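
To illustrate the default-model change, a short usage sketch follows. The `livekit.plugins.google` import path is an assumption; the constructor parameters and the `gemini-2.0-flash-001` default come from the `LLM.__init__` diff above.

```python
# Hypothetical usage sketch: the import path is assumed; the parameters and the
# new "gemini-2.0-flash-001" default are taken from the LLM.__init__ diff above.
from livekit.plugins import google

gemini_llm = google.LLM(
    model="gemini-2.0-flash-001",  # new default; any ChatModels literal or custom string
    # api_key=None falls back to the GOOGLE_API_KEY environment variable
    vertexai=False,                # set True to use Vertex AI with a Google Cloud project
)
```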
30 changes: 19 additions & 11 deletions tests/test_llm.py
@@ -48,14 +48,22 @@ async def toggle_light(
await asyncio.sleep(60)

# used to test arrays as arguments
@ai_callable(description="Select currencies of a specific area")
def select_currencies(
@ai_callable(description="Schedule recurring events on selected days")
def schedule_meeting(
self,
currencies: Annotated[
meeting_days: Annotated[
list[str],
TypeInfo(
description="The currencies to select",
choices=["usd", "eur", "gbp", "jpy", "sek"],
description="The days of the week on which meetings will occur",
choices=[
"monday",
"tuesday",
"wednesday",
"thursday",
"friday",
"saturday",
"sunday",
],
),
],
) -> None: ...
@@ -207,21 +215,21 @@ async def test_calls_arrays(llm_factory: Callable[[], llm.LLM]):

stream = await _request_fnc_call(
input_llm,
"Can you select all currencies in Europe at once from given choices using function call `select_currencies`?",
"can you schedule a meeting on monday and wednesday?",
fnc_ctx,
temperature=0.2,
)
calls = stream.execute_functions()
await asyncio.gather(*[f.task for f in calls])
await stream.aclose()

assert len(calls) == 1, "select_currencies should have been called only once"
assert len(calls) == 1, "schedule_meeting should have been called only once"

call = calls[0]
currencies = call.call_info.arguments["currencies"]
assert len(currencies) == 3, "select_currencies should have 3 currencies"
assert "eur" in currencies and "gbp" in currencies and "sek" in currencies, (
"select_currencies should have eur, gbp, sek"
meeting_days = call.call_info.arguments["meeting_days"]
assert len(meeting_days) == 2, "schedule_meeting should have 2 days"
assert "monday" in meeting_days and "wednesday" in meeting_days, (
"meeting_days should have monday, wednesday"
)

