Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .changeset/tasty-brooms-exercise.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
---
"livekit-plugins-anthropic": patch
"livekit-plugins-google": patch
"livekit-plugins-openai": patch
"livekit-agents": patch
---

Added a field to the LLM capabilities class indicating whether a model provider supports function-call history in the chat context without requiring the corresponding function definitions.
13 changes: 11 additions & 2 deletions livekit-agents/livekit/agents/llm/fallback_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ..types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions
from .chat_context import ChatContext
from .function_context import CalledFunction, FunctionCallInfo, FunctionContext
from .llm import LLM, ChatChunk, LLMStream, ToolChoice
from .llm import LLM, ChatChunk, LLMCapabilities, LLMStream, ToolChoice

DEFAULT_FALLBACK_API_CONNECT_OPTIONS = APIConnectOptions(
max_retry=0, timeout=DEFAULT_API_CONNECT_OPTIONS.timeout
Expand Down Expand Up @@ -45,7 +45,16 @@ def __init__(
if len(llm) < 1:
raise ValueError("at least one LLM instance must be provided.")

super().__init__()
super().__init__(
capabilities=LLMCapabilities(
supports_choices_on_int=all(
t.capabilities.supports_choices_on_int for t in llm
),
requires_persistent_functions=all(
t.capabilities.requires_persistent_functions for t in llm
),
)
)

self._llm_instances = llm
self._attempt_timeout = attempt_timeout
Expand Down
9 changes: 7 additions & 2 deletions livekit-agents/livekit/agents/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ class Choice:
@dataclass
class LLMCapabilities:
supports_choices_on_int: bool = True
"""check whether the LLM supports integer enums choices as function arguments"""
requires_persistent_functions: bool = False
"""if the LLM requires function definition when previous function calls exist in chat context"""


@dataclass
Expand All @@ -73,9 +76,11 @@ class LLM(
rtc.EventEmitter[Union[Literal["metrics_collected"], TEvent]],
Generic[TEvent],
):
def __init__(self) -> None:
def __init__(self, *, capabilities: LLMCapabilities | None = None) -> None:
super().__init__()
self._capabilities = LLMCapabilities()
if capabilities is None:
capabilities = LLMCapabilities()
self._capabilities = capabilities
self._label = f"{type(self).__module__}.{type(self).__name__}"

@property
Expand Down
2 changes: 2 additions & 0 deletions livekit-agents/livekit/agents/pipeline/pipeline_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -981,6 +981,7 @@ async def _execute_function_calls() -> None:
fnc_ctx
and new_speech_handle.fnc_nested_depth
>= self._opts.max_nested_fnc_calls
and not self._llm.capabilities.requires_persistent_functions
):
if len(fnc_ctx.ai_functions) > 1:
logger.info(
Expand All @@ -991,6 +992,7 @@ async def _execute_function_calls() -> None:
},
)
fnc_ctx = None

answer_llm_stream = self._llm.chat(
chat_ctx=chat_ctx,
fnc_ctx=fnc_ctx,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
llm,
utils,
)
from livekit.agents.llm import ToolChoice
from livekit.agents.llm import LLMCapabilities, ToolChoice
from livekit.agents.llm.function_context import (
_create_ai_function_info,
_is_optional_type,
Expand Down Expand Up @@ -82,7 +82,13 @@ def __init__(
``api_key`` must be set to your Anthropic API key, either using the argument or by setting
the ``ANTHROPIC_API_KEY`` environmental variable.
"""
super().__init__()

super().__init__(
capabilities=LLMCapabilities(
requires_persistent_functions=True,
supports_choices_on_int=True,
)
)

# throw an error on our end
api_key = api_key or os.environ.get("ANTHROPIC_API_KEY")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
llm,
utils,
)
from livekit.agents.llm import ToolChoice, _create_ai_function_info
from livekit.agents.llm import LLMCapabilities, ToolChoice, _create_ai_function_info
from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions

from google import genai
Expand Down Expand Up @@ -99,8 +99,12 @@ def __init__(
frequency_penalty (float, optional): Penalizes the model for repeating words. Defaults to None.
tool_choice (ToolChoice or Literal["auto", "required", "none"], optional): Specifies whether to use tools during response generation. Defaults to "auto".
"""
super().__init__()
self._capabilities = llm.LLMCapabilities(supports_choices_on_int=False)
super().__init__(
capabilities=LLMCapabilities(
supports_choices_on_int=False,
requires_persistent_functions=False,
)
)
self._project_id = project or os.environ.get("GOOGLE_CLOUD_PROJECT", None)
self._location = location or os.environ.get(
"GOOGLE_CLOUD_LOCATION", "us-central1"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import httpx
from livekit import rtc
from livekit.agents import llm, utils
from livekit.agents.llm import ToolChoice
from livekit.agents.llm import LLMCapabilities, ToolChoice
from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions

from openai import AsyncAssistantEventHandler, AsyncClient
Expand Down Expand Up @@ -99,7 +99,12 @@ def __init__(
base_url: str | None = None,
on_file_uploaded: OnFileUploaded | None = None,
) -> None:
super().__init__()
super().__init__(
capabilities=LLMCapabilities(
supports_choices_on_int=True,
requires_persistent_functions=False,
)
)

test_ctx = llm.ChatContext()
if not hasattr(test_ctx, "_metadata"):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@
APITimeoutError,
llm,
)
from livekit.agents.llm import ToolChoice, _create_ai_function_info
from livekit.agents.llm import (
LLMCapabilities,
ToolChoice,
_create_ai_function_info,
)
from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions

import openai
Expand Down Expand Up @@ -85,8 +89,12 @@ def __init__(
``api_key`` must be set to your OpenAI API key, either using the argument or by setting the
``OPENAI_API_KEY`` environmental variable.
"""
super().__init__()
self._capabilities = llm.LLMCapabilities(supports_choices_on_int=True)
super().__init__(
capabilities=LLMCapabilities(
supports_choices_on_int=True,
requires_persistent_functions=False,
)
)

self._opts = LLMOptions(
model=model,
Expand Down
Loading