Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions litellm/integrations/SlackAlerting/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@ def process_slack_alerting_variables(

for alert_type, webhook_urls in alert_to_webhook_url.items():
if isinstance(webhook_urls, list):
# Optimization: match the "os.environ/" prefix with startswith() instead of a substring ("in") search
_webhook_values: List[str] = []
for webhook_url in webhook_urls:
if "os.environ/" in webhook_url:
if webhook_url.startswith("os.environ/"):
_env_value = get_secret(secret_name=webhook_url)
if not isinstance(_env_value, str):
raise ValueError(
Expand All @@ -44,7 +45,7 @@ def process_slack_alerting_variables(
alert_to_webhook_url[alert_type] = _webhook_values
else:
_webhook_value_str: str = webhook_urls
if "os.environ/" in webhook_urls:
if webhook_urls.startswith("os.environ/"):
_env_value = get_secret(secret_name=webhook_urls)
if not isinstance(_env_value, str):
raise ValueError(
Expand Down
96 changes: 37 additions & 59 deletions litellm/secret_managers/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@
from litellm._logging import print_verbose, verbose_logger
from litellm.caching.caching import DualCache
from litellm.llms.custom_httpx.http_handler import HTTPHandler
from litellm.secret_managers.get_azure_ad_token_provider import (
get_azure_ad_token_provider,
)
from litellm.secret_managers.get_azure_ad_token_provider import \
get_azure_ad_token_provider
from litellm.types.secret_managers.main import KeyManagementSystem

oidc_cache = DualCache()
Expand Down Expand Up @@ -98,22 +97,28 @@ def get_secret( # noqa: PLR0915
key_management_settings = litellm._key_management_settings
secret = None

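# Fast path: an "os.environ/<NAME>" reference is read straight from os.environ and returned here
# (as a bool when str_to_bool recognizes the value, None when <NAME> is unset), so none of the lookups below run for it.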
if secret_name.startswith("os.environ/"):
secret_name = secret_name.replace("os.environ/", "")

# Example: oidc/google/https://bedrock-runtime.us-east-1.amazonaws.com/model/stability.stable-diffusion-xl-v1/invoke
env_var = secret_name[11:]
secret = os.environ.get(env_var)
secret_value_as_bool = str_to_bool(secret) if secret is not None else None
if secret_value_as_bool is not None and isinstance(secret_value_as_bool, bool):
return secret_value_as_bool
return secret

# OIDC logic
if secret_name.startswith("oidc/"):
secret_name_split = secret_name.replace("oidc/", "")
oidc_provider, oidc_aud = secret_name_split.split("/", 1)
oidc_aud = "/".join(secret_name_split.split("/")[1:])
secret_name_split = secret_name[5:]
first_slash = secret_name_split.find("/")
oidc_provider = secret_name_split[:first_slash]
oidc_aud = secret_name_split[first_slash+1:]
# TODO: Add caching for HTTP requests
if oidc_provider == "google":
oidc_token = oidc_cache.get_cache(key=secret_name)
if oidc_token is not None:
return oidc_token

oidc_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
# https://cloud.google.com/compute/docs/instances/verifying-instance-identity#request_signature
response = oidc_client.get(
"http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/identity",
params={"audience": oidc_aud},
Expand All @@ -123,29 +128,24 @@ def get_secret( # noqa: PLR0915
oidc_token = response.text
oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=3600 - 60)
return oidc_token
else:
raise ValueError("Google OIDC provider failed")
raise ValueError("Google OIDC provider failed")
elif oidc_provider == "circleci":
# https://circleci.com/docs/openid-connect-tokens/
env_secret = os.getenv("CIRCLE_OIDC_TOKEN")
if env_secret is None:
raise ValueError("CIRCLE_OIDC_TOKEN not found in environment")
return env_secret
elif oidc_provider == "circleci_v2":
# https://circleci.com/docs/openid-connect-tokens/
env_secret = os.getenv("CIRCLE_OIDC_TOKEN_V2")
if env_secret is None:
raise ValueError("CIRCLE_OIDC_TOKEN_V2 not found in environment")
return env_secret
elif oidc_provider == "github":
# https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers#using-custom-actions
actions_id_token_request_url = os.getenv("ACTIONS_ID_TOKEN_REQUEST_URL")
actions_id_token_request_token = os.getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN")
if actions_id_token_request_url is None or actions_id_token_request_token is None:
raise ValueError(
"ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN not found in environment"
)

oidc_token = oidc_cache.get_cache(key=secret_name)
if oidc_token is not None:
return oidc_token
Expand All @@ -163,10 +163,8 @@ def get_secret( # noqa: PLR0915
oidc_token = response.json().get("value", None)
oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=300 - 5)
return oidc_token
else:
raise ValueError("Github OIDC provider failed")
raise ValueError("Github OIDC provider failed")
elif oidc_provider == "azure":
# https://azure.github.io/azure-workload-identity/docs/quick-start.html
azure_federated_token_file = os.getenv("AZURE_FEDERATED_TOKEN_FILE")
if azure_federated_token_file is None:
verbose_logger.warning(
Expand All @@ -183,29 +181,22 @@ def get_secret( # noqa: PLR0915
verbose_logger.error(error_msg)
raise ValueError(error_msg)
with open(azure_federated_token_file, "r") as f:
oidc_token = f.read()
return oidc_token
return f.read()
elif oidc_provider == "file":
# Load token from a file
with open(oidc_aud, "r") as f:
oidc_token = f.read()
return oidc_token
return f.read()
elif oidc_provider == "env":
# Load token directly from an environment variable
oidc_token = os.getenv(oidc_aud)
if oidc_token is None:
raise ValueError(f"Environment variable {oidc_aud} not found")
return oidc_token
elif oidc_provider == "env_path":
# Load token from a file path specified in an environment variable
token_file_path = os.getenv(oidc_aud)
if token_file_path is None:
raise ValueError(f"Environment variable {oidc_aud} not found")
with open(token_file_path, "r") as f:
oidc_token = f.read()
return oidc_token
else:
raise ValueError("Unsupported OIDC provider")
return f.read()
raise ValueError("Unsupported OIDC provider")

try:
if _should_read_secret_from_secret_manager() and litellm.secret_manager_client is not None:
Expand All @@ -216,17 +207,18 @@ def get_secret( # noqa: PLR0915
key_manager = key_management_system.value

if key_management_settings is not None:
hosted_keys = key_management_settings.hosted_keys
if (
key_management_settings.hosted_keys is not None
and secret_name not in key_management_settings.hosted_keys
): # allow user to specify which keys to check in hosted key manager
hosted_keys is not None
and secret_name not in hosted_keys
):
key_manager = "local"

if (
key_manager == KeyManagementSystem.AZURE_KEY_VAULT.value
or type(client).__module__ + "." + type(client).__name__
== "azure.keyvault.secrets._client.SecretClient"
): # support Azure Secret Client - from azure.keyvault.secrets import SecretClient
):
secret = client.get_secret(secret_name).value
elif (
key_manager == KeyManagementSystem.GOOGLE_KMS.value
Expand All @@ -236,44 +228,33 @@ def get_secret( # noqa: PLR0915
if encrypted_secret is None:
raise ValueError("Google KMS requires the encrypted secret to be in the environment!")
b64_flag = _is_base64(encrypted_secret)
if b64_flag is True: # if passed in as encoded b64 string
encrypted_secret = base64.b64decode(encrypted_secret)
ciphertext = encrypted_secret
if b64_flag:
ciphertext = base64.b64decode(encrypted_secret)
else:
raise ValueError(
"Google KMS requires the encrypted secret to be encoded in base64"
) # fix for this vulnerability https://huntr.com/bounties/ae623c2f-b64b-4245-9ed4-f13a0a5824ce
)
response = client.decrypt(
request={
"name": litellm._google_kms_resource_name,
"ciphertext": ciphertext,
}
)
secret = response.plaintext.decode("utf-8") # assumes the original value was encoded with utf-8
secret = response.plaintext.decode("utf-8")
elif key_manager == KeyManagementSystem.AWS_KMS.value:
"""
Only check the tokens which start with 'aws_kms/'. This prevents latency impact caused by checking all keys.
"""
encrypted_value = os.getenv(secret_name, None)
if encrypted_value is None:
raise Exception("AWS KMS - Encrypted Value of Key={} is None".format(secret_name))
# Decode the base64 encoded ciphertext
raise Exception(f"AWS KMS - Encrypted Value of Key={secret_name} is None")
ciphertext_blob = base64.b64decode(encrypted_value)

# Set up the parameters for the decrypt call
params = {"CiphertextBlob": ciphertext_blob}
# Perform the decryption
response = client.decrypt(**params)

# Extract and decode the plaintext
plaintext = response["Plaintext"]
secret = plaintext.decode("utf-8")
if isinstance(secret, str):
secret = secret.strip()
elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value:
from litellm.secret_managers.aws_secret_manager_v2 import (
AWSSecretsManagerV2,
)
from litellm.secret_managers.aws_secret_manager_v2 import \
AWSSecretsManagerV2

if isinstance(client, AWSSecretsManagerV2):
secret = client.sync_read_secret(
Expand Down Expand Up @@ -302,7 +283,7 @@ def get_secret( # noqa: PLR0915
secret = os.getenv(secret_name)
else: # assume the default is the Infisical client
secret = client.get_secret(secret_name).secret_value
except Exception as e: # check if it's in os.environ
except Exception as e:
verbose_logger.error(
f"Defaulting to os.environ value for key={secret_name}. An exception occurred - {str(e)}.\n\n{traceback.format_exc()}"
)
Expand All @@ -312,22 +293,19 @@ def get_secret( # noqa: PLR0915
secret_value_as_bool = ast.literal_eval(secret)
if isinstance(secret_value_as_bool, bool):
return secret_value_as_bool
else:
return secret
return secret
except Exception:
return secret
else:
secret = os.environ.get(secret_name)
secret_value_as_bool = str_to_bool(secret) if secret is not None else None
if secret_value_as_bool is not None and isinstance(secret_value_as_bool, bool):
return secret_value_as_bool
else:
return secret
return secret
except Exception as e:
if default_value is not None:
return default_value
else:
raise e
raise e


def _should_read_secret_from_secret_manager() -> bool:
Expand Down