172 changes: 156 additions & 16 deletions lib/crewai/src/crewai/llm.py
@@ -9,6 +9,7 @@
import os
import sys
import threading
import warnings
from typing import (
TYPE_CHECKING,
Any,
@@ -309,6 +310,7 @@ def writable(self) -> bool:
CONTEXT_WINDOW_USAGE_RATIO: Final[float] = 0.85
SUPPORTED_NATIVE_PROVIDERS: Final[list[str]] = [
"openai",
"openai_responses",
"anthropic",
"claude",
"azure",
@@ -346,23 +348,50 @@ class LLM(BaseLLM):
def __new__(cls, model: str, is_litellm: bool = False, **kwargs: Any) -> LLM:
"""Factory method that routes to native SDK or falls back to LiteLLM.

Supports both the legacy 'openai_responses' provider alias and the new 'provider' + 'api' parameters for clearer semantics.

Routing priority:
1. If 'provider' kwarg is present, resolve the provider + api combination
2. If only 'model' kwarg, use constants to infer provider
3. If "/" in model name:
- Check if prefix is a native provider (openai/anthropic/azure/bedrock/gemini)
- If yes, validate model against constants
- If valid, route to native SDK; otherwise route to LiteLLM

Args:
model: Model identifier (e.g., "gpt-4o", "claude-3-sonnet")
is_litellm: Force use of LiteLLM fallback
**kwargs: Additional parameters including:
- provider: Provider name ("openai", "anthropic", etc.)
- api: API type within provider ("chat", "responses", etc.)
- Other model parameters (temperature, etc.)
"""
if not model or not isinstance(model, str):
raise ValueError("Model must be a non-empty string")

explicit_provider = kwargs.get("provider")
explicit_api = kwargs.get("api")

# Resolve provider and API combination
if explicit_provider:
try:
resolved_provider = cls._resolve_provider_and_api(explicit_provider, explicit_api)
except ValueError as e:
raise ValueError(f"Invalid provider/api combination: {e}") from e

# Validate model compatibility with resolved provider
if not cls._validate_model_for_provider_api(model, resolved_provider, explicit_api):
supported_models = cls._get_supported_models_message(explicit_provider, explicit_api)
api_desc = f" with {explicit_api} API" if explicit_api else ""
raise ValueError(
f"Model '{model}' is not compatible with provider '{explicit_provider}'{api_desc}. "
f"Supported models: {supported_models}"
)

provider = resolved_provider
use_native = True
model_string = model

elif "/" in model:
prefix, _, model_part = model.partition("/")

@@ -378,18 +407,38 @@ def __new__(cls, model: str, is_litellm: bool = False, **kwargs: Any) -> LLM:
"aws": "bedrock",
}

if prefix.lower() == "openai_responses":
    # Backwards-compatibility: allow the old prefix for now, but steer
    # users to the explicit provider+api syntax.
    if explicit_api not in (None, "responses"):
        raise ValueError(
            "Model prefix 'openai_responses/' implies api='responses' but "
            f"got api='{explicit_api}'. Use provider='openai' with api='responses' instead."
        )
    warnings.warn(
        "Model prefix 'openai_responses/' is deprecated and will be removed in a future version. "
        "Use provider='openai' with api='responses' instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    provider = "openai_responses"
    use_native = True
    model_string = model_part
    canonical_provider = None
else:
    provider = prefix
    use_native = False
    model_string = model_part
    canonical_provider = provider_mapping.get(prefix.lower())

if canonical_provider is not None:
    if canonical_provider and cls._validate_model_in_constants(
        model_part, canonical_provider
    ):
        provider = canonical_provider
        use_native = True
        model_string = model_part
    else:
        provider = prefix
        use_native = False
        model_string = model_part
else:
provider = cls._infer_provider_from_model(model)
use_native = True
@@ -398,8 +447,8 @@ def __new__(cls, model: str, is_litellm: bool = False, **kwargs: Any) -> LLM:
native_class = cls._get_native_provider(provider) if use_native else None
if native_class and not is_litellm and provider in SUPPORTED_NATIVE_PROVIDERS:
try:
# Remove 'provider' and 'api' from kwargs to avoid duplicate keyword arguments
kwargs_copy = {k: v for k, v in kwargs.items() if k not in ("provider", "api")}
return cast(
Self,
native_class(model=model_string, provider=provider, **kwargs_copy),
@@ -419,6 +468,75 @@
instance.is_litellm = True
return instance

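For reference, a minimal usage sketch of the routing paths implemented above (assuming the class is exported as from crewai import LLM; model names are illustrative):

    from crewai import LLM

    # Explicit syntax: provider + api resolve to the Responses implementation.
    llm = LLM("gpt-4o", provider="openai", api="responses")

    # Default chat path: api=None keeps the standard OpenAI provider.
    llm = LLM("gpt-4o", provider="openai")

    # Prefix routing: a native provider prefix is validated against constants.
    llm = LLM("anthropic/claude-3-5-sonnet-20241022")

    # Legacy prefix: still routed, but emits a DeprecationWarning.
    llm = LLM("openai_responses/gpt-4o")
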
@classmethod
def _resolve_provider_and_api(cls, provider: str, api: str | None) -> str:
"""Resolve logical provider + api combination to actual provider implementation.

Currently only handles OpenAI's multiple APIs. The legacy provider="openai_responses" alias is still accepted for backwards compatibility, but it emits a DeprecationWarning; use provider="openai" with api="responses" instead.

Args:
provider: Logical provider name (e.g., "openai")
api: API type (e.g., "responses", "chat", None)

Returns:
Actual provider implementation name
"""
if provider == "openai":
if api == "responses":
return "openai_responses"
elif api == "chat" or api is None:
return "openai"
else:
raise ValueError(f"Unsupported API '{api}' for provider 'openai'. Supported: 'chat', 'responses'")

if provider == "openai_responses":
# Backwards-compatibility: allow the legacy alias for now, but steer
# users to the explicit provider+api syntax.
if api not in (None, "responses"):
raise ValueError(
"provider='openai_responses' implies api='responses' but "
f"got api='{api}'. Use provider='openai' with api='responses' instead."
)
warnings.warn(
"provider='openai_responses' is deprecated and will be removed in a future version. "
"Use provider='openai' with api='responses' instead.",
DeprecationWarning,
stacklevel=2,
)
return "openai_responses"

return provider

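The resolution rules above, sketched as expected outcomes (hypothetical direct calls for illustration only; the method is private and normally reached via __new__):

    LLM._resolve_provider_and_api("openai", "responses")     # -> "openai_responses"
    LLM._resolve_provider_and_api("openai", "chat")          # -> "openai"
    LLM._resolve_provider_and_api("openai", None)            # -> "openai"
    LLM._resolve_provider_and_api("openai", "batch")         # raises ValueError
    LLM._resolve_provider_and_api("openai_responses", None)  # -> "openai_responses" + DeprecationWarning
    LLM._resolve_provider_and_api("anthropic", None)         # -> "anthropic" (passed through)
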
@classmethod
def _validate_model_for_provider_api(cls, model: str, provider: str, api: str | None) -> bool:
"""Validate if a model is compatible with the given provider and API combination.

Currently focused on OpenAI API validation to prevent using incompatible models.

Args:
model: Model name to validate
provider: Resolved provider name
api: API type (for additional context)

Returns:
True if compatible, False otherwise
"""
return cls._matches_provider_pattern(model, provider)

@classmethod
def _get_supported_models_message(cls, provider: str, api: str | None) -> str:
"""Get a human-readable message about supported models for a provider/API combination."""
if provider in ["openai_responses", "openai"] and api == "responses":
return ("gpt-4o, gpt-4o-mini, gpt-4.1, gpt-4.1-mini, gpt-4.1-nano, "
"o1, o1-mini, o1-preview, o3, o3-mini, o4-mini")
elif provider == "openai" or (provider == "openai" and api in [None, "chat"]):
return ("gpt-3.5-turbo, gpt-4, gpt-4-turbo, gpt-4o, gpt-4o-mini, "
"o1, o1-mini, o1-preview, o3, o3-mini, o4-mini, whisper-1, and other OpenAI models")
else:
return f"models supported by provider '{provider}'"

@classmethod
def _matches_provider_pattern(cls, model: str, provider: str) -> bool:
"""Check if a model name matches provider-specific patterns.
@@ -441,6 +559,13 @@ def _matches_provider_pattern(cls, model: str, provider: str) -> bool:
for prefix in ["gpt-", "o1", "o3", "o4", "whisper-"]
)

if provider == "openai_responses":
# Responses API supports GPT-4o, GPT-4.1, and o-series models
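# Illustrative: "gpt-4o-2024-08-06" and "o3-mini" match, while chat-only
# models such as "gpt-3.5-turbo" or "whisper-1" do not.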
return any(
model_lower.startswith(prefix)
for prefix in ["gpt-4o", "gpt-4.1", "o1", "o3", "o4"]
)

if provider == "anthropic" or provider == "claude":
return any(
model_lower.startswith(prefix) for prefix in ["claude-", "anthropic."]
@@ -481,6 +606,15 @@ def _validate_model_in_constants(cls, model: str, provider: str) -> bool:
if provider == "openai" and model in OPENAI_MODELS:
return True

if provider == "openai_responses":
# Responses API supports subset of OpenAI models
responses_models = {"gpt-4o", "gpt-4o-mini", "gpt-4.1", "gpt-4.1-mini", "gpt-4.1-nano",
"o1", "o1-mini", "o1-preview", "o3", "o3-mini", "o4-mini"}
if model in responses_models:
return True
# Also check pattern matching for newer models with these prefixes
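# e.g. "gpt-4o" is caught by the set above, while a dated snapshot like
# "gpt-4o-2024-11-20" misses the set but still matches the prefix pattern.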
return cls._matches_provider_pattern(model, provider)

if (
provider == "anthropic" or provider == "claude"
) and model in ANTHROPIC_MODELS:
Expand All @@ -489,11 +623,10 @@ def _validate_model_in_constants(cls, model: str, provider: str) -> bool:
if (provider == "gemini" or provider == "google") and model in GEMINI_MODELS:
return True

if provider == "bedrock" and model in BEDROCK_MODELS:
if (provider == "bedrock" or provider == "aws") and model in BEDROCK_MODELS:
return True

if provider == "azure":
# azure does not provide a list of available models, determine a better way to handle this
if provider == "azure" and model in AZURE_MODELS:
return True

# Fallback to pattern matching for models not in constants
@@ -538,6 +671,13 @@ def _get_native_provider(cls, provider: str) -> type | None:

return OpenAICompletion

if provider == "openai_responses":
from crewai.llms.providers.openai_responses.completion import (
OpenAIResponsesCompletion,
)

return OpenAIResponsesCompletion

if provider == "anthropic" or provider == "claude":
from crewai.llms.providers.anthropic.completion import (
AnthropicCompletion,
Expand Down
7 changes: 7 additions & 0 deletions lib/crewai/src/crewai/llms/providers/openai_responses/__init__.py
@@ -0,0 +1,7 @@
"""OpenAI Responses API provider for CrewAI."""

from crewai.llms.providers.openai_responses.completion import (
OpenAIResponsesCompletion,
)

__all__ = ["OpenAIResponsesCompletion"]