-
Notifications
You must be signed in to change notification settings - Fork 1.3k
RFC: automatically use litellm if possible #534
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
from __future__ import annotations | ||
|
||
import asyncio | ||
|
||
from agents import Agent, Runner, function_tool, set_tracing_disabled | ||
|
||
"""This example uses the built-in support for LiteLLM. To use this, ensure you have the | ||
ANTHROPIC_API_KEY environment variable set. | ||
""" | ||
|
||
set_tracing_disabled(disabled=True) | ||
|
||
|
||
@function_tool | ||
def get_weather(city: str): | ||
print(f"[debug] getting weather for {city}") | ||
return f"The weather in {city} is sunny." | ||
|
||
|
||
async def main(): | ||
agent = Agent( | ||
name="Assistant", | ||
instructions="You only respond in haikus.", | ||
# We prefix with litellm/ to tell the Runner to use the LitellmModel | ||
model="litellm/anthropic/claude-3-5-sonnet-20240620", | ||
tools=[get_weather], | ||
) | ||
|
||
result = await Runner.run(agent, "What's the weather in Tokyo?") | ||
print(result.final_output) | ||
|
||
|
||
if __name__ == "__main__": | ||
import os | ||
|
||
if os.getenv("ANTHROPIC_API_KEY") is None: | ||
raise ValueError( | ||
"ANTHROPIC_API_KEY is not set. Please set it the environment variable and try again." | ||
) | ||
|
||
asyncio.run(main()) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from ...models.interface import Model, ModelProvider | ||
from .litellm_model import LitellmModel | ||
|
||
DEFAULT_MODEL: str = "gpt-4.1" | ||
|
||
|
||
class LitellmProvider(ModelProvider): | ||
"""A ModelProvider that uses LiteLLM to route to any model provider. You can use it via: | ||
```python | ||
Runner.run(agent, input, run_config=RunConfig(model_provider=LitellmProvider())) | ||
``` | ||
See supported models here: [litellm models](https://docs.litellm.ai/docs/providers). | ||
NOTE: API keys must be set via environment variables. If you're using models that require | ||
additional configuration (e.g. Azure API base or version), those must also be set via the | ||
environment variables that LiteLLM expects. If you have more advanced needs, we recommend | ||
copy-pasting this class and making any modifications you need. | ||
""" | ||
|
||
def get_model(self, model_name: str | None) -> Model: | ||
return LitellmModel(model_name or DEFAULT_MODEL) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
from __future__ import annotations | ||
|
||
from openai import AsyncOpenAI | ||
|
||
from ..exceptions import UserError | ||
from .interface import Model, ModelProvider | ||
from .openai_provider import OpenAIProvider | ||
|
||
|
||
class MultiProviderMap: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do we need this class to encapsulate the simple operations on plain dict? |
||
"""A map of model name prefixes to ModelProviders.""" | ||
|
||
def __init__(self): | ||
self._mapping: dict[str, ModelProvider] = {} | ||
|
||
def has_prefix(self, prefix: str) -> bool: | ||
"""Returns True if the given prefix is in the mapping.""" | ||
return prefix in self._mapping | ||
|
||
def get_mapping(self) -> dict[str, ModelProvider]: | ||
"""Returns a copy of the current prefix -> ModelProvider mapping.""" | ||
return self._mapping.copy() | ||
|
||
def set_mapping(self, mapping: dict[str, ModelProvider]): | ||
"""Overwrites the current mapping with a new one.""" | ||
self._mapping = mapping | ||
|
||
def get_provider(self, prefix: str) -> ModelProvider | None: | ||
"""Returns the ModelProvider for the given prefix. | ||
|
||
Args: | ||
prefix: The prefix of the model name e.g. "openai" or "my_prefix". | ||
""" | ||
return self._mapping.get(prefix) | ||
|
||
def add_provider(self, prefix: str, provider: ModelProvider): | ||
"""Adds a new prefix -> ModelProvider mapping. | ||
|
||
Args: | ||
prefix: The prefix of the model name e.g. "openai" or "my_prefix". | ||
provider: The ModelProvider to use for the given prefix. | ||
""" | ||
self._mapping[prefix] = provider | ||
|
||
def remove_provider(self, prefix: str): | ||
"""Removes the mapping for the given prefix. | ||
|
||
Args: | ||
prefix: The prefix of the model name e.g. "openai" or "my_prefix". | ||
""" | ||
del self._mapping[prefix] | ||
|
||
|
||
class MultiProvider(ModelProvider): | ||
"""This ModelProvider maps to a Model based on the prefix of the model name. By default, the | ||
mapping is: | ||
- "openai/" prefix or no prefix -> OpenAIProvider. e.g. "openai/gpt-4.1", "gpt-4.1" | ||
- "litellm/" prefix -> LitellmProvider. e.g. "litellm/openai/gpt-4.1" | ||
|
||
You can override or customize this mapping. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
*, | ||
provider_map: MultiProviderMap | None = None, | ||
openai_api_key: str | None = None, | ||
openai_base_url: str | None = None, | ||
openai_client: AsyncOpenAI | None = None, | ||
openai_organization: str | None = None, | ||
openai_project: str | None = None, | ||
openai_use_responses: bool | None = None, | ||
) -> None: | ||
"""Create a new OpenAI provider. | ||
|
||
Args: | ||
provider_map: A MultiProviderMap that maps prefixes to ModelProviders. If not provided, | ||
we will use a default mapping. See the documentation for this class to see the | ||
default mapping. | ||
openai_api_key: The API key to use for the OpenAI provider. If not provided, we will use | ||
the default API key. | ||
openai_base_url: The base URL to use for the OpenAI provider. If not provided, we will | ||
use the default base URL. | ||
openai_client: An optional OpenAI client to use. If not provided, we will create a new | ||
OpenAI client using the api_key and base_url. | ||
openai_organization: The organization to use for the OpenAI provider. | ||
openai_project: The project to use for the OpenAI provider. | ||
openai_use_responses: Whether to use the OpenAI responses API. | ||
""" | ||
self.provider_map = provider_map | ||
self.openai_provider = OpenAIProvider( | ||
api_key=openai_api_key, | ||
base_url=openai_base_url, | ||
openai_client=openai_client, | ||
organization=openai_organization, | ||
project=openai_project, | ||
use_responses=openai_use_responses, | ||
) | ||
|
||
self._fallback_providers: dict[str, ModelProvider] = {} | ||
|
||
def _get_prefix_and_model_name(self, model_name: str | None) -> tuple[str | None, str | None]: | ||
if model_name is None: | ||
return None, None | ||
elif "/" in model_name: | ||
prefix, model_name = model_name.split("/", 1) | ||
return prefix, model_name | ||
else: | ||
return None, model_name | ||
|
||
def _create_fallback_provider(self, prefix: str) -> ModelProvider: | ||
if prefix == "litellm": | ||
from ..extensions.models.litellm_provider import LitellmProvider | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I like this dynamic import here! nice-to-have: on the LitellmProvider side, having try/except clause for loading litellm module and raising a more user-friendly error message than exposing the missing litellm could make dev experience better There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @seratch - LitellmProvider will try to import LitellmModel, which will display the nice error message: https://github.com/openai/openai-agents-python/blob/main/src/agents/extensions/models/litellm_model.py#L16-L19 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice, you’re already ahead! |
||
|
||
return LitellmProvider() | ||
else: | ||
raise UserError(f"Unknown prefix: {prefix}") | ||
|
||
def _get_fallback_provider(self, prefix: str | None) -> ModelProvider: | ||
if prefix is None or prefix == "openai": | ||
return self.openai_provider | ||
elif prefix in self._fallback_providers: | ||
return self._fallback_providers[prefix] | ||
else: | ||
self._fallback_providers[prefix] = self._create_fallback_provider(prefix) | ||
return self._fallback_providers[prefix] | ||
|
||
def get_model(self, model_name: str | None) -> Model: | ||
"""Returns a Model based on the model name. The model name can have a prefix, ending with | ||
a "/", which will be used to look up the ModelProvider. If there is no prefix, we will use | ||
the OpenAI provider. | ||
|
||
Args: | ||
model_name: The name of the model to get. | ||
|
||
Returns: | ||
A Model. | ||
""" | ||
prefix, model_name = self._get_prefix_and_model_name(model_name) | ||
|
||
if prefix and self.provider_map and (provider := self.provider_map.get_provider(prefix)): | ||
return provider.get_model(model_name) | ||
else: | ||
return self._get_fallback_provider(prefix).get_model(model_name) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would importing https://github.com/openai/openai-agents-python/blob/v0.0.11/src/agents/models/openai_provider.py#L11 instead be better? Also, huge 👍 to switching from gtp-4o to gpt-4.1
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Feels like a minor breaking change. Though probably fine! Good call.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@rm-openai Perhaps you're already aware of this, but switching to gpt-4.1 might break existing CUA apps (the tool is not yet available with 4.1 while web_search_preview works with 4.1), so indeed switching the default model could be a breaking change