RFC: automatically use litellm if possible #534

Merged: 2 commits, Apr 21, 2025
41 changes: 41 additions & 0 deletions examples/model_providers/litellm_auto.py
@@ -0,0 +1,41 @@
from __future__ import annotations

import asyncio

from agents import Agent, Runner, function_tool, set_tracing_disabled

"""This example uses the built-in support for LiteLLM. To use this, ensure you have the
ANTHROPIC_API_KEY environment variable set.
"""

set_tracing_disabled(disabled=True)


@function_tool
def get_weather(city: str):
    print(f"[debug] getting weather for {city}")
    return f"The weather in {city} is sunny."


async def main():
    agent = Agent(
        name="Assistant",
        instructions="You only respond in haikus.",
        # We prefix with litellm/ to tell the Runner to use the LitellmModel
        model="litellm/anthropic/claude-3-5-sonnet-20240620",
        tools=[get_weather],
    )

    result = await Runner.run(agent, "What's the weather in Tokyo?")
    print(result.final_output)


if __name__ == "__main__":
    import os

    if os.getenv("ANTHROPIC_API_KEY") is None:
        raise ValueError(
            "ANTHROPIC_API_KEY is not set. Please set the environment variable and try again."
        )

    asyncio.run(main())
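
For comparison, the "litellm/" prefix above resolves to a LitellmModel under the hood; an equivalent, more explicit form (a sketch using the same model string, minus the routing prefix) would be:

```python
from agents import Agent
from agents.extensions.models.litellm_model import LitellmModel

# Equivalent to model="litellm/anthropic/claude-3-5-sonnet-20240620" above,
# but constructs the model object directly instead of relying on the prefix.
agent = Agent(
    name="Assistant",
    instructions="You only respond in haikus.",
    model=LitellmModel("anthropic/claude-3-5-sonnet-20240620"),
)
```
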
21 changes: 21 additions & 0 deletions src/agents/extensions/models/litellm_provider.py
@@ -0,0 +1,21 @@
from ...models.interface import Model, ModelProvider
from .litellm_model import LitellmModel

DEFAULT_MODEL: str = "gpt-4.1"
Member:
Would importing https://github.com/openai/openai-agents-python/blob/v0.0.11/src/agents/models/openai_provider.py#L11 instead be better? Also, huge 👍 to switching from gpt-4o to gpt-4.1

Collaborator (author):

Feels like a minor breaking change. Though probably fine! Good call.

Member:

@rm-openai Perhaps you're already aware of this, but switching to gpt-4.1 might break existing CUA apps (the tool is not yet available with 4.1 while web_search_preview works with 4.1), so indeed switching the default model could be a breaking change
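
Concretely, the import suggested above would look something like this (a sketch; it assumes the linked line exposes the constant as DEFAULT_MODEL):

```python
# Instead of redefining the literal in this file:
from ...models.openai_provider import DEFAULT_MODEL
```
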



class LitellmProvider(ModelProvider):
    """A ModelProvider that uses LiteLLM to route to any model provider. You can use it via:
    ```python
    Runner.run(agent, input, run_config=RunConfig(model_provider=LitellmProvider()))
    ```
    See supported models here: [litellm models](https://docs.litellm.ai/docs/providers).
    NOTE: API keys must be set via environment variables. If you're using models that require
    additional configuration (e.g. Azure API base or version), those must also be set via the
    environment variables that LiteLLM expects. If you have more advanced needs, we recommend
    copy-pasting this class and making any modifications you need.
    """

    def get_model(self, model_name: str | None) -> Model:
        return LitellmModel(model_name or DEFAULT_MODEL)
144 changes: 144 additions & 0 deletions src/agents/models/multi_provider.py
@@ -0,0 +1,144 @@
from __future__ import annotations

from openai import AsyncOpenAI

from ..exceptions import UserError
from .interface import Model, ModelProvider
from .openai_provider import OpenAIProvider


class MultiProviderMap:
yihuang (Apr 18, 2025):
do we need this class to encapsulate these simple operations on a plain dict?

"""A map of model name prefixes to ModelProviders."""

def __init__(self):
self._mapping: dict[str, ModelProvider] = {}

def has_prefix(self, prefix: str) -> bool:
"""Returns True if the given prefix is in the mapping."""
return prefix in self._mapping

def get_mapping(self) -> dict[str, ModelProvider]:
"""Returns a copy of the current prefix -> ModelProvider mapping."""
return self._mapping.copy()

def set_mapping(self, mapping: dict[str, ModelProvider]):
"""Overwrites the current mapping with a new one."""
self._mapping = mapping

def get_provider(self, prefix: str) -> ModelProvider | None:
"""Returns the ModelProvider for the given prefix.

Args:
prefix: The prefix of the model name e.g. "openai" or "my_prefix".
"""
return self._mapping.get(prefix)

def add_provider(self, prefix: str, provider: ModelProvider):
"""Adds a new prefix -> ModelProvider mapping.

Args:
prefix: The prefix of the model name e.g. "openai" or "my_prefix".
provider: The ModelProvider to use for the given prefix.
"""
self._mapping[prefix] = provider

def remove_provider(self, prefix: str):
"""Removes the mapping for the given prefix.

Args:
prefix: The prefix of the model name e.g. "openai" or "my_prefix".
"""
del self._mapping[prefix]


class MultiProvider(ModelProvider):
    """This ModelProvider maps to a Model based on the prefix of the model name. By default, the
    mapping is:
    - "openai/" prefix or no prefix -> OpenAIProvider. e.g. "openai/gpt-4.1", "gpt-4.1"
    - "litellm/" prefix -> LitellmProvider. e.g. "litellm/openai/gpt-4.1"

    You can override or customize this mapping.
    """

    def __init__(
        self,
        *,
        provider_map: MultiProviderMap | None = None,
        openai_api_key: str | None = None,
        openai_base_url: str | None = None,
        openai_client: AsyncOpenAI | None = None,
        openai_organization: str | None = None,
        openai_project: str | None = None,
        openai_use_responses: bool | None = None,
    ) -> None:
        """Create a new MultiProvider.

        Args:
            provider_map: A MultiProviderMap that maps prefixes to ModelProviders. If not
                provided, we will use a default mapping. See the documentation for this class
                to see the default mapping.
            openai_api_key: The API key to use for the OpenAI provider. If not provided, we
                will use the default API key.
            openai_base_url: The base URL to use for the OpenAI provider. If not provided, we
                will use the default base URL.
            openai_client: An optional OpenAI client to use. If not provided, we will create a
                new OpenAI client using the api_key and base_url.
            openai_organization: The organization to use for the OpenAI provider.
            openai_project: The project to use for the OpenAI provider.
            openai_use_responses: Whether to use the OpenAI responses API.
        """
        self.provider_map = provider_map
        self.openai_provider = OpenAIProvider(
            api_key=openai_api_key,
            base_url=openai_base_url,
            openai_client=openai_client,
            organization=openai_organization,
            project=openai_project,
            use_responses=openai_use_responses,
        )

        self._fallback_providers: dict[str, ModelProvider] = {}

    def _get_prefix_and_model_name(self, model_name: str | None) -> tuple[str | None, str | None]:
        if model_name is None:
            return None, None
        elif "/" in model_name:
            prefix, model_name = model_name.split("/", 1)
            return prefix, model_name
        else:
            return None, model_name

    def _create_fallback_provider(self, prefix: str) -> ModelProvider:
        if prefix == "litellm":
            from ..extensions.models.litellm_provider import LitellmProvider
Member:
I like this dynamic import here!

nice-to-have: on the LitellmProvider side, a try/except clause around loading the litellm module that raises a more user-friendly error message (rather than exposing the missing litellm import) could make the dev experience better

Collaborator (author):

@seratch - LitellmProvider will try to import LitellmModel, which will display the nice error message: https://github.com/openai/openai-agents-python/blob/main/src/agents/extensions/models/litellm_model.py#L16-L19

Member:

Nice, you’re already ahead!
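
For reference, the import guard in litellm_model.py linked above follows roughly this pattern (a sketch, not the verbatim source; the exact message and install extra may differ):

```python
try:
    import litellm  # noqa: F401  # optional dependency
except ImportError as exc:
    # Hedged sketch: the real message lives at litellm_model.py#L16-L19.
    raise ImportError(
        "litellm is required to use LitellmModel. "
        'Install it with: pip install "openai-agents[litellm]"'
    ) from exc
```
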


            return LitellmProvider()
        else:
            raise UserError(f"Unknown prefix: {prefix}")

    def _get_fallback_provider(self, prefix: str | None) -> ModelProvider:
        if prefix is None or prefix == "openai":
            return self.openai_provider
        elif prefix in self._fallback_providers:
            return self._fallback_providers[prefix]
        else:
            self._fallback_providers[prefix] = self._create_fallback_provider(prefix)
            return self._fallback_providers[prefix]

    def get_model(self, model_name: str | None) -> Model:
        """Returns a Model based on the model name. The model name can have a prefix, ending
        with a "/", which will be used to look up the ModelProvider. If there is no prefix, we
        will use the OpenAI provider.

        Args:
            model_name: The name of the model to get.

        Returns:
            A Model.
        """
        prefix, model_name = self._get_prefix_and_model_name(model_name)

        if prefix and self.provider_map and (provider := self.provider_map.get_provider(prefix)):
            return provider.get_model(model_name)
        else:
            return self._get_fallback_provider(prefix).get_model(model_name)
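
To make the customization path the docstring mentions concrete, here is a hedged sketch of registering a custom prefix via MultiProviderMap (the "my_litellm" prefix is an arbitrary name chosen for illustration):

```python
from agents import Agent, RunConfig, Runner
from agents.extensions.models.litellm_provider import LitellmProvider
from agents.models.multi_provider import MultiProvider, MultiProviderMap

# Route an arbitrary "my_litellm/" prefix through LiteLLM. Only the first
# "/"-segment is stripped as the prefix, so the rest passes through intact.
provider_map = MultiProviderMap()
provider_map.add_provider("my_litellm", LitellmProvider())

provider = MultiProvider(provider_map=provider_map)

agent = Agent(
    name="Assistant",
    model="my_litellm/anthropic/claude-3-5-sonnet-20240620",
)

# The explicit map wins over the built-in fallbacks, so this resolves to
# LitellmProvider.get_model("anthropic/claude-3-5-sonnet-20240620").
result = Runner.run_sync(agent, "hello", run_config=RunConfig(model_provider=provider))
print(result.final_output)
```
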
4 changes: 2 additions & 2 deletions src/agents/run.py
@@ -34,7 +34,7 @@
 from .logger import logger
 from .model_settings import ModelSettings
 from .models.interface import Model, ModelProvider
-from .models.openai_provider import OpenAIProvider
+from .models.multi_provider import MultiProvider
 from .result import RunResult, RunResultStreaming
 from .run_context import RunContextWrapper, TContext
 from .stream_events import AgentUpdatedStreamEvent, RawResponsesStreamEvent
@@ -56,7 +56,7 @@ class RunConfig:
     agent. The model_provider passed in below must be able to resolve this model name.
     """

-    model_provider: ModelProvider = field(default_factory=OpenAIProvider)
+    model_provider: ModelProvider = field(default_factory=MultiProvider)
     """The model provider to use when looking up string model names. Defaults to OpenAI."""

     model_settings: ModelSettings | None = None