Skip to content

Commit dca6904

Browse files
JWT Auth - enforce_rbac support + UI team view, spend calc fix (#7863)
* fix(user_dashboard.tsx): fix spend calculation when team selected sum all team keys, not user keys * docs(admin_ui_sso.md): fix docs tabbing * feat(user_api_key_auth.py): introduce new 'enforce_rbac' param on jwt auth allows proxy admin to prevent any unmapped yet authenticated jwt tokens from calling proxy Fixes #6793 * test: more unit testing + refactoring * fix: fix returning id when obj not found in db * fix(user_api_key_auth.py): add end user id tracking from jwt auth * docs(token_auth.md): add doc on rbac with JWTs * fix: fix unused params * test: remove old test
1 parent c306c2e commit dca6904

File tree

12 files changed

+447
-195
lines changed

12 files changed

+447
-195
lines changed

docs/my-website/docs/proxy/admin_ui_sso.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
import Image from '@theme/IdealImage';
2+
import Tabs from '@theme/Tabs';
3+
import TabItem from '@theme/TabItem';
4+
15
# ✨ SSO for Admin UI
26

37
:::info

docs/my-website/docs/proxy/token_auth.md

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ general_settings:
114114
admin_jwt_scope: "litellm-proxy-admin"
115115
```
116116

117-
## Advanced - Spend Tracking (End-Users / Internal Users / Team / Org)
117+
## Tracking End-Users / Internal Users / Team / Org
118118

119119
Set the field in the jwt token, which corresponds to a litellm user / team / org.
120120

@@ -156,6 +156,33 @@ scope: ["litellm-proxy-admin",...]
156156
scope: "litellm-proxy-admin ..."
157157
```
158158
159+
## Enforce Role-Based Access Control (RBAC)
160+
161+
Reject a JWT token if it's valid but doesn't have the required scopes / fields.
162+
163+
Only tokens which with valid Admin (`admin_jwt_scope`), User (`user_id_jwt_field`), Team (`team_id_jwt_field`) are allowed.
164+
165+
```yaml
166+
general_settings:
167+
master_key: sk-1234
168+
enable_jwt_auth: True
169+
litellm_jwtauth:
170+
admin_jwt_scope: "litellm_proxy_endpoints_access"
171+
admin_allowed_routes:
172+
- openai_routes
173+
- info_routes
174+
public_key_ttl: 600
175+
enforce_rbac: true # 👈 Enforce RBAC
176+
```
177+
178+
Expected Scope in JWT:
179+
180+
```
181+
{
182+
"scope": "litellm_proxy_endpoints_access"
183+
}
184+
```
185+
159186
## Advanced - Allowed Routes
160187

161188
Configure which routes a JWT can access via the config.

litellm/model_prices_and_context_window_backup.json

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6587,13 +6587,41 @@
65876587
"litellm_provider": "bedrock",
65886588
"mode": "image_generation"
65896589
},
6590+
"stability.sd3-5-large-v1:0": {
6591+
"max_tokens": 77,
6592+
"max_input_tokens": 77,
6593+
"output_cost_per_image": 0.08,
6594+
"litellm_provider": "bedrock",
6595+
"mode": "image_generation"
6596+
},
6597+
"stability.stable-image-core-v1:0": {
6598+
"max_tokens": 77,
6599+
"max_input_tokens": 77,
6600+
"output_cost_per_image": 0.04,
6601+
"litellm_provider": "bedrock",
6602+
"mode": "image_generation"
6603+
},
6604+
"stability.stable-image-core-v1:1": {
6605+
"max_tokens": 77,
6606+
"max_input_tokens": 77,
6607+
"output_cost_per_image": 0.04,
6608+
"litellm_provider": "bedrock",
6609+
"mode": "image_generation"
6610+
},
65906611
"stability.stable-image-ultra-v1:0": {
65916612
"max_tokens": 77,
65926613
"max_input_tokens": 77,
65936614
"output_cost_per_image": 0.14,
65946615
"litellm_provider": "bedrock",
65956616
"mode": "image_generation"
65966617
},
6618+
"stability.stable-image-ultra-v1:1": {
6619+
"max_tokens": 77,
6620+
"max_input_tokens": 77,
6621+
"output_cost_per_image": 0.14,
6622+
"litellm_provider": "bedrock",
6623+
"mode": "image_generation"
6624+
},
65976625
"sagemaker/meta-textgeneration-llama-2-7b": {
65986626
"max_tokens": 4096,
65996627
"max_input_tokens": 4096,

litellm/proxy/_types.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,8 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
416416
- user_allowed_email_subdomain: If specified, only emails from specified subdomain will be allowed to access proxy.
417417
- end_user_id_jwt_field: The field in the JWT token that stores the end-user ID (maps to `LiteLLMEndUserTable`). Turn this off by setting to `None`. Enables end-user cost tracking. Use this for external customers.
418418
- public_key_ttl: Default - 600s. TTL for caching public JWT keys.
419+
- public_allowed_routes: list of allowed routes for authenticated but unknown litellm role jwt tokens.
420+
- enforce_rbac: If true, enforce RBAC for all routes.
419421
420422
See `auth_checks.py` for the specific routes
421423
"""
@@ -446,6 +448,8 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
446448
)
447449
end_user_id_jwt_field: Optional[str] = None
448450
public_key_ttl: float = 600
451+
public_allowed_routes: List[str] = ["public_routes"]
452+
enforce_rbac: bool = False
449453

450454
def __init__(self, **kwargs: Any) -> None:
451455
# get the attribute names for this Pydantic model
@@ -2284,6 +2288,19 @@ class ProxyStateVariables(TypedDict):
22842288
UI_TEAM_ID = "litellm-dashboard"
22852289

22862290

2291+
2292+
class JWTAuthBuilderResult(TypedDict):
2293+
is_proxy_admin: bool
2294+
team_object: Optional[LiteLLM_TeamTable]
2295+
user_object: Optional[LiteLLM_UserTable]
2296+
end_user_object: Optional[LiteLLM_EndUserTable]
2297+
org_object: Optional[LiteLLM_OrganizationTable]
2298+
token: str
2299+
team_id: Optional[str]
2300+
user_id: Optional[str]
2301+
end_user_id: Optional[str]
2302+
org_id: Optional[str]
2303+
22872304
class ClientSideFallbackModel(TypedDict, total=False):
22882305
"""
22892306
Dictionary passed when client configuring input

litellm/proxy/auth/auth_checks.py

Lines changed: 2 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
1010
"""
1111

12-
import inspect
1312
import time
1413
import traceback
1514
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
@@ -22,7 +21,6 @@
2221
from litellm.caching.dual_cache import LimitedSizeOrderedDict
2322
from litellm.proxy._types import (
2423
DB_CONNECTION_ERROR_TYPES,
25-
CommonProxyErrors,
2624
LiteLLM_EndUserTable,
2725
LiteLLM_JWTAuth,
2826
LiteLLM_OrganizationTable,
@@ -55,33 +53,6 @@
5553
all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value
5654

5755

58-
def _allowed_import_check() -> bool:
59-
from litellm.proxy.auth.user_api_key_auth import _user_api_key_auth_builder
60-
61-
# Get the calling frame
62-
caller_frame = inspect.stack()[2]
63-
caller_function = caller_frame.function
64-
caller_function_callable = caller_frame.frame.f_globals.get(caller_function)
65-
66-
allowed_function = "_user_api_key_auth_builder"
67-
allowed_signature = inspect.signature(_user_api_key_auth_builder)
68-
if caller_function_callable is None or not callable(caller_function_callable):
69-
raise Exception(f"Caller function {caller_function} is not callable")
70-
caller_signature = inspect.signature(caller_function_callable)
71-
72-
if caller_signature != allowed_signature:
73-
raise TypeError(
74-
f"The function '{caller_function}' does not match the required signature of 'user_api_key_auth'. {CommonProxyErrors.not_premium_user.value}"
75-
)
76-
# Check if the caller module is allowed
77-
if caller_function != allowed_function:
78-
raise ImportError(
79-
f"This function can only be imported by '{allowed_function}'. {CommonProxyErrors.not_premium_user.value}"
80-
)
81-
82-
return True
83-
84-
8556
def common_checks( # noqa: PLR0915
8657
request_body: dict,
8758
team_object: Optional[LiteLLM_TeamTable],
@@ -106,7 +77,6 @@ def common_checks( # noqa: PLR0915
10677
9. Check if request body is safe
10778
10. [OPTIONAL] Organization checks - is user_object.organization_id is set, run these checks
10879
"""
109-
_allowed_import_check()
11080
_model = request_body.get("model", None)
11181
if team_object is not None and team_object.blocked is True:
11282
raise Exception(
@@ -844,7 +814,7 @@ async def get_org_object(
844814
user_api_key_cache: DualCache,
845815
parent_otel_span: Optional[Span] = None,
846816
proxy_logging_obj: Optional[ProxyLogging] = None,
847-
):
817+
) -> Optional[LiteLLM_OrganizationTable]:
848818
"""
849819
- Check if org id in proxy Org Table
850820
- if valid, return LiteLLM_OrganizationTable object
@@ -859,7 +829,7 @@ async def get_org_object(
859829
cached_org_obj = user_api_key_cache.async_get_cache(key="org_id:{}".format(org_id))
860830
if cached_org_obj is not None:
861831
if isinstance(cached_org_obj, dict):
862-
return cached_org_obj
832+
return LiteLLM_OrganizationTable(**cached_org_obj)
863833
elif isinstance(cached_org_obj, LiteLLM_OrganizationTable):
864834
return cached_org_obj
865835
# else, check db

litellm/proxy/auth/handle_jwt.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,12 @@
1717
from litellm._logging import verbose_proxy_logger
1818
from litellm.caching.caching import DualCache
1919
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
20-
from litellm.proxy._types import JWKKeyValue, JWTKeyItem, LiteLLM_JWTAuth
20+
from litellm.proxy._types import (
21+
JWKKeyValue,
22+
JWTKeyItem,
23+
LiteLLM_JWTAuth,
24+
LitellmUserRoles,
25+
)
2126
from litellm.proxy.utils import PrismaClient
2227

2328

@@ -54,6 +59,34 @@ def is_jwt(self, token: str):
5459
parts = token.split(".")
5560
return len(parts) == 3
5661

62+
def get_rbac_role(self, token: dict) -> Optional[LitellmUserRoles]:
63+
"""
64+
Returns the RBAC role the token 'belongs' to.
65+
66+
RBAC roles allowed to make requests:
67+
- PROXY_ADMIN: can make requests to all routes
68+
- TEAM: can make requests to routes associated with a team
69+
- INTERNAL_USER: can make requests to routes associated with a user
70+
71+
Resolves: https://github.com/BerriAI/litellm/issues/6793
72+
73+
Returns:
74+
- PROXY_ADMIN: if token is admin
75+
- TEAM: if token is associated with a team
76+
- INTERNAL_USER: if token is associated with a user
77+
- None: if token is not associated with a team or user
78+
"""
79+
scopes = self.get_scopes(token=token)
80+
is_admin = self.is_admin(scopes=scopes)
81+
if is_admin:
82+
return LitellmUserRoles.PROXY_ADMIN
83+
elif self.get_team_id(token=token, default_value=None) is not None:
84+
return LitellmUserRoles.TEAM
85+
elif self.get_user_id(token=token, default_value=None) is not None:
86+
return LitellmUserRoles.INTERNAL_USER
87+
88+
return None
89+
5790
def is_admin(self, scopes: list) -> bool:
5891
if self.litellm_jwtauth.admin_jwt_scope in scopes:
5992
return True
@@ -68,12 +101,14 @@ def get_end_user_id(
68101
self, token: dict, default_value: Optional[str]
69102
) -> Optional[str]:
70103
try:
104+
71105
if self.litellm_jwtauth.end_user_id_jwt_field is not None:
72106
user_id = token[self.litellm_jwtauth.end_user_id_jwt_field]
73107
else:
74108
user_id = None
75109
except KeyError:
76110
user_id = default_value
111+
77112
return user_id
78113

79114
def is_required_team_id(self) -> bool:
@@ -169,6 +204,7 @@ def get_scopes(self, token: dict) -> list:
169204
return scopes
170205

171206
async def get_public_key(self, kid: Optional[str]) -> dict:
207+
172208
keys_url = os.getenv("JWT_PUBLIC_KEY_URL")
173209

174210
if keys_url is None:

0 commit comments

Comments
 (0)