Skip to content

fix: enhance AzureOpenAIResponsesAPIConfig to support different Azure authentication methods #11027

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 104 additions & 57 deletions litellm/llms/azure/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import httpx
from openai import AsyncAzureOpenAI, AzureOpenAI
from pydantic import BaseModel

import litellm
from litellm._logging import verbose_logger
Expand Down Expand Up @@ -38,6 +39,20 @@ def __init__(
)


class AzureAuthResponse(BaseModel):
    """
    Resolved authentication credentials for an Azure OpenAI client.

    Produced by ``get_azure_api_key_or_token``; exactly which field is
    populated depends on the auth flow that was selected, and all fields
    default to ``None``.

    Attributes:
        api_key: The API key for Azure OpenAI, if provided. Can be None.
            Per the resolution logic, an explicit api_key takes priority
            over the Azure AD flows.
        azure_ad_token_provider: A zero-argument callable that returns a
            fresh Azure AD token. Can be None.
        azure_ad_token: The Azure AD token, if available (e.g. after an
            "oidc/" token exchange). Can be None.
    """
    # Plain API-key auth; takes priority when present.
    api_key: Optional[str] = None
    # Callable that mints/refreshes an AAD token on demand.
    azure_ad_token_provider: Optional[Callable[[], str]] = None
    # Pre-resolved AAD token string.
    azure_ad_token: Optional[str] = None


def process_azure_headers(headers: Union[httpx.Headers, dict]) -> dict:
openai_headers = {}
if "x-ratelimit-limit-requests" in headers:
Expand Down Expand Up @@ -259,6 +274,83 @@ def select_azure_base_url_or_endpoint(azure_client_params: dict):
return azure_client_params


def get_azure_api_key_or_token(
    litellm_params: dict,
    api_key: Optional[str],
) -> AzureAuthResponse:
    """
    Resolve the credential to use for Azure OpenAI authentication.

    Flows are tried in order: an explicit API key short-circuits the AD
    flows; otherwise an Entra ID (tenant/client/secret) token provider,
    then a username/password token provider, then an "oidc/" token
    exchange, and finally litellm's generic Azure AD token-refresh
    provider (when ``litellm.enable_azure_ad_token_refresh`` is True).

    Args:
        litellm_params: LiteLLM call parameters; may carry
            azure_ad_token, azure_ad_token_provider, tenant_id,
            client_id, client_secret, azure_username, azure_password.
            Each credential falls back to its AZURE_* environment
            variable when absent from the dict.
        api_key: Optional API key for Azure OpenAI.

    Returns:
        AzureAuthResponse: the api_key, Azure AD token provider, and
        Azure AD token (any of which may be None).
    """

    def _param_or_env(param_name: str, env_name: str):
        # Value in litellm_params wins; fall back to the environment.
        return litellm_params.get(param_name, os.getenv(env_name))

    token_provider = litellm_params.get("azure_ad_token_provider")
    ad_token = litellm_params.get("azure_ad_token")
    tenant_id = _param_or_env("tenant_id", "AZURE_TENANT_ID")
    client_id = _param_or_env("client_id", "AZURE_CLIENT_ID")
    client_secret = _param_or_env("client_secret", "AZURE_CLIENT_SECRET")
    azure_username = _param_or_env("azure_username", "AZURE_USERNAME")
    azure_password = _param_or_env("azure_password", "AZURE_PASSWORD")

    no_api_key = not api_key

    # Entra ID (service principal) flow: only when there is no api_key,
    # no provider was passed in, and the full credential triple exists.
    if no_api_key and token_provider is None and all(
        (tenant_id, client_id, client_secret)
    ):
        verbose_logger.debug(
            "Using Azure AD Token Provider from Entra ID for Azure Auth"
        )
        token_provider = get_azure_ad_token_from_entra_id(
            tenant_id=tenant_id,
            client_id=client_id,
            client_secret=client_secret,
        )

    # Username/password flow, if nothing above produced a provider.
    if token_provider is None and all(
        (azure_username, azure_password, client_id)
    ):
        verbose_logger.debug("Using Azure Username and Password for Azure Auth")
        token_provider = get_azure_ad_token_from_username_password(
            azure_username=azure_username,
            azure_password=azure_password,
            client_id=client_id,
        )

    if ad_token is not None and ad_token.startswith("oidc/"):
        # Exchange the "oidc/..." reference for a real Azure AD token.
        verbose_logger.debug("Using Azure OIDC Token for Azure Auth")
        ad_token = get_azure_ad_token_from_oidc(
            azure_ad_token=ad_token,
            azure_client_id=client_id,
            azure_tenant_id=tenant_id,
        )
    elif (
        no_api_key
        and token_provider is None
        and litellm.enable_azure_ad_token_refresh is True
    ):
        verbose_logger.debug(
            "Using Azure AD token provider based on Service Principal with Secret workflow for Azure Auth"
        )
        try:
            token_provider = get_azure_ad_token_provider()
        except ValueError:
            # Best-effort fallback: leave the provider unset.
            verbose_logger.debug("Azure AD Token Provider could not be used.")

    return AzureAuthResponse(
        api_key=api_key,
        azure_ad_token_provider=token_provider,
        azure_ad_token=ad_token,
    )



class BaseAzureLLM(BaseOpenAILLM):
def get_azure_openai_client(
self,
Expand Down Expand Up @@ -321,68 +413,21 @@ def initialize_azure_sdk_client(
api_version: Optional[str],
is_async: bool,
) -> dict:
azure_ad_token_provider = litellm_params.get("azure_ad_token_provider")
# If we have api_key, then we have higher priority
azure_ad_token = litellm_params.get("azure_ad_token")
tenant_id = litellm_params.get("tenant_id", os.getenv("AZURE_TENANT_ID"))
client_id = litellm_params.get("client_id", os.getenv("AZURE_CLIENT_ID"))
client_secret = litellm_params.get(
"client_secret", os.getenv("AZURE_CLIENT_SECRET")
)
azure_username = litellm_params.get(
"azure_username", os.getenv("AZURE_USERNAME")
)
azure_password = litellm_params.get(
"azure_password", os.getenv("AZURE_PASSWORD")
)
max_retries = litellm_params.get("max_retries")
timeout = litellm_params.get("timeout")
if (
not api_key
and azure_ad_token_provider is None
and tenant_id and client_id and client_secret
):
verbose_logger.debug(
"Using Azure AD Token Provider from Entra ID for Azure Auth"
)
azure_ad_token_provider = get_azure_ad_token_from_entra_id(
tenant_id=tenant_id,
client_id=client_id,
client_secret=client_secret,
)
if azure_ad_token_provider is None and azure_username and azure_password and client_id:
verbose_logger.debug("Using Azure Username and Password for Azure Auth")
azure_ad_token_provider = get_azure_ad_token_from_username_password(
azure_username=azure_username,
azure_password=azure_password,
client_id=client_id,
)

if azure_ad_token is not None and azure_ad_token.startswith("oidc/"):
verbose_logger.debug("Using Azure OIDC Token for Azure Auth")
azure_ad_token = get_azure_ad_token_from_oidc(
azure_ad_token=azure_ad_token,
azure_client_id=client_id,
azure_tenant_id=tenant_id,
)
elif (
not api_key
and azure_ad_token_provider is None
and litellm.enable_azure_ad_token_refresh is True
):
verbose_logger.debug(
"Using Azure AD token provider based on Service Principal with Secret workflow for Azure Auth"
)
try:
azure_ad_token_provider = get_azure_ad_token_provider()
except ValueError:
verbose_logger.debug("Azure AD Token Provider could not be used.")
if api_version is None:
api_version = os.getenv(
"AZURE_API_VERSION", litellm.AZURE_DEFAULT_API_VERSION
)

_api_key = api_key
auth_response = get_azure_api_key_or_token(
litellm_params=litellm_params,
api_key=api_key,
)
_api_key = auth_response.api_key
azure_ad_token_provider = auth_response.azure_ad_token_provider
azure_ad_token = auth_response.azure_ad_token
if _api_key is not None and isinstance(_api_key, str):
# only show first 5 chars of api_key
_api_key = _api_key[:8] + "*" * 15
Expand All @@ -398,9 +443,11 @@ def initialize_azure_sdk_client(
}
# init http client + SSL Verification settings
if is_async is True:
azure_client_params["http_client"] = self._get_async_http_client()
# Type annotation to fix incompatible types error
azure_client_params["http_client"] = self._get_async_http_client() # type: ignore
else:
azure_client_params["http_client"] = self._get_sync_http_client()
# Type annotation to fix incompatible types error
azure_client_params["http_client"] = self._get_sync_http_client() # type: ignore

if max_retries is not None:
azure_client_params["max_retries"] = max_retries
Expand Down Expand Up @@ -450,7 +497,7 @@ def _init_azure_client_for_cloudflare_ai_gateway(
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
if isinstance(azure_ad_token, str) and azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(
azure_ad_token=azure_ad_token,
azure_client_id=client_id,
Expand Down
26 changes: 23 additions & 3 deletions litellm/llms/azure/responses/transformation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, cast

import httpx

import litellm
Expand All @@ -10,6 +9,9 @@
from litellm.types.responses.main import *
from litellm.types.router import GenericLiteLLMParams
from litellm.utils import _add_path_to_api_base
from litellm.llms.azure.common_utils import (
get_azure_api_key_or_token
)

if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
Expand All @@ -18,16 +20,34 @@
else:
LiteLLMLoggingObj = Any


class AzureOpenAIResponsesAPIConfig(OpenAIResponsesAPIConfig):
def validate_environment(
self,
headers: dict,
model: str,
litellm_params: dict,
api_key: Optional[str] = None,
) -> dict:

auth_response = get_azure_api_key_or_token(
litellm_params=litellm_params,
api_key=api_key,
)
api_key = auth_response.api_key
azure_ad_token_provider = auth_response.azure_ad_token_provider
azure_ad_token = auth_response.azure_ad_token

if azure_ad_token_provider:
azure_ad_token = azure_ad_token_provider()
elif azure_ad_token is not None:
if isinstance(azure_ad_token, str):
azure_ad_token = azure_ad_token
else:
azure_ad_token = azure_ad_token()

api_key = (
api_key
azure_ad_token
or api_key
or litellm.api_key
or litellm.azure_key
or get_secret_str("AZURE_OPENAI_API_KEY")
Expand Down
1 change: 1 addition & 0 deletions litellm/llms/base_llm/responses/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def validate_environment(
self,
headers: dict,
model: str,
litellm_params: dict,
api_key: Optional[str] = None,
) -> dict:
return {}
Expand Down
6 changes: 6 additions & 0 deletions litellm/llms/custom_httpx/llm_http_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1273,6 +1273,7 @@ def response_api_handler(
api_key=litellm_params.api_key,
headers=response_api_optional_request_params.get("extra_headers", {}) or {},
model=model,
litellm_params=dict(litellm_params)
)

if extra_headers:
Expand Down Expand Up @@ -1393,6 +1394,7 @@ async def async_response_api_handler(
api_key=litellm_params.api_key,
headers=response_api_optional_request_params.get("extra_headers", {}) or {},
model=model,
litellm_params=dict(litellm_params)
)

if extra_headers:
Expand Down Expand Up @@ -1514,6 +1516,7 @@ async def async_delete_response_api_handler(
api_key=litellm_params.api_key,
headers=extra_headers or {},
model="None",
litellm_params=dict(litellm_params)
)

if extra_headers:
Expand Down Expand Up @@ -1598,6 +1601,7 @@ def delete_response_api_handler(
api_key=litellm_params.api_key,
headers=extra_headers or {},
model="None",
litellm_params=dict(litellm_params)
)

if extra_headers:
Expand Down Expand Up @@ -1683,6 +1687,7 @@ def get_responses(
api_key=litellm_params.api_key,
headers=extra_headers or {},
model="None",
litellm_params=dict(litellm_params)
)

if extra_headers:
Expand Down Expand Up @@ -1751,6 +1756,7 @@ async def async_get_responses(
api_key=litellm_params.api_key,
headers=extra_headers or {},
model="None",
litellm_params=dict(litellm_params)
)

if extra_headers:
Expand Down
1 change: 1 addition & 0 deletions litellm/llms/openai/responses/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def validate_environment(
self,
headers: dict,
model: str,
litellm_params: dict,
api_key: Optional[str] = None,
) -> dict:
api_key = (
Expand Down
Loading