1
- from typing import Annotated , Union , Any
1
+ from typing import Annotated , Any , Union
2
2
3
+ from fastapi import Depends , HTTPException , Request
3
4
from sqlalchemy .ext .asyncio import AsyncSession
4
- from fastapi import (
5
- Depends ,
6
- HTTPException ,
7
- Request
8
- )
9
5
10
- from ..core .security import oauth2_scheme
11
6
from ..core .config import settings
12
- from ..core .exceptions .http_exceptions import UnauthorizedException , ForbiddenException , RateLimitException
13
7
from ..core .db .database import async_get_db
8
+ from ..core .exceptions .http_exceptions import ForbiddenException , RateLimitException , UnauthorizedException
14
9
from ..core .logger import logging
10
+ from ..core .security import oauth2_scheme , verify_token
15
11
from ..core .utils .rate_limit import is_rate_limited
16
- from ..core .security import verify_token
17
12
from ..crud .crud_rate_limit import crud_rate_limits
18
13
from ..crud .crud_tier import crud_tiers
19
14
from ..crud .crud_users import crud_users
25
20
DEFAULT_LIMIT = settings .DEFAULT_RATE_LIMIT_LIMIT
26
21
DEFAULT_PERIOD = settings .DEFAULT_RATE_LIMIT_PERIOD
27
22
23
+
28
24
async def get_current_user (
29
- token : Annotated [str , Depends (oauth2_scheme )],
30
- db : Annotated [AsyncSession , Depends (async_get_db )]
25
+ token : Annotated [str , Depends (oauth2_scheme )], db : Annotated [AsyncSession , Depends (async_get_db )]
31
26
) -> Union [dict [str , Any ], None ]:
32
27
token_data = await verify_token (token , db )
33
28
if token_data is None :
34
29
raise UnauthorizedException ("User not authenticated." )
35
30
36
31
if "@" in token_data .username_or_email :
37
32
user : dict | None = await crud_users .get (db = db , email = token_data .username_or_email , is_deleted = False )
38
- else :
33
+ else :
39
34
user = await crud_users .get (db = db , username = token_data .username_or_email , is_deleted = False )
40
-
35
+
41
36
if user :
42
37
return user
43
38
44
39
raise UnauthorizedException ("User not authenticated." )
45
40
46
41
47
- async def get_optional_user (
48
- request : Request ,
49
- db : AsyncSession = Depends (async_get_db )
50
- ) -> dict | None :
42
+ async def get_optional_user (request : Request , db : AsyncSession = Depends (async_get_db )) -> dict | None :
51
43
token = request .headers .get ("Authorization" )
52
44
if not token :
53
45
return None
54
46
55
47
try :
56
- token_type , _ , token_value = token .partition (' ' )
57
- if token_type .lower () != ' bearer' or not token_value :
48
+ token_type , _ , token_value = token .partition (" " )
49
+ if token_type .lower () != " bearer" or not token_value :
58
50
return None
59
51
60
52
token_data = await verify_token (token_value , db )
61
53
if token_data is None :
62
54
return None
63
55
64
56
return await get_current_user (token_value , db = db )
65
-
57
+
66
58
except HTTPException as http_exc :
67
59
if http_exc .status_code != 401 :
68
60
logger .error (f"Unexpected HTTPException in get_optional_user: { http_exc .detail } " )
69
61
return None
70
-
62
+
71
63
except Exception as exc :
72
64
logger .error (f"Unexpected error in get_optional_user: { exc } " )
73
65
return None
@@ -76,29 +68,26 @@ async def get_optional_user(
76
68
async def get_current_superuser (current_user : Annotated [dict , Depends (get_current_user )]) -> dict :
77
69
if not current_user ["is_superuser" ]:
78
70
raise ForbiddenException ("You do not have enough privileges." )
79
-
71
+
80
72
return current_user
81
73
82
74
83
75
async def rate_limiter (
84
- request : Request ,
85
- db : Annotated [AsyncSession , Depends (async_get_db )],
86
- user : User | None = Depends (get_optional_user )
76
+ request : Request , db : Annotated [AsyncSession , Depends (async_get_db )], user : User | None = Depends (get_optional_user )
87
77
) -> None :
88
78
path = sanitize_path (request .url .path )
89
79
if user :
90
80
user_id = user ["id" ]
91
81
tier = await crud_tiers .get (db , id = user ["tier_id" ])
92
82
if tier :
93
- rate_limit = await crud_rate_limits .get (
94
- db = db ,
95
- tier_id = tier ["id" ],
96
- path = path
97
- )
83
+ rate_limit = await crud_rate_limits .get (db = db , tier_id = tier ["id" ], path = path )
98
84
if rate_limit :
99
85
limit , period = rate_limit ["limit" ], rate_limit ["period" ]
100
86
else :
101
- logger .warning (f"User { user_id } with tier '{ tier ['name' ]} ' has no specific rate limit for path '{ path } '. Applying default rate limit." )
87
+ logger .warning (
88
+ f"User { user_id } with tier '{ tier ['name' ]} ' has no specific rate limit for path '{ path } '. \
89
+ Applying default rate limit."
90
+ )
102
91
limit , period = DEFAULT_LIMIT , DEFAULT_PERIOD
103
92
else :
104
93
logger .warning (f"User { user_id } has no assigned tier. Applying default rate limit." )
@@ -107,12 +96,6 @@ async def rate_limiter(
107
96
user_id = request .client .host
108
97
limit , period = DEFAULT_LIMIT , DEFAULT_PERIOD
109
98
110
- is_limited = await is_rate_limited (
111
- db = db ,
112
- user_id = user_id ,
113
- path = path ,
114
- limit = limit ,
115
- period = period
116
- )
99
+ is_limited = await is_rate_limited (db = db , user_id = user_id , path = path , limit = limit , period = period )
117
100
if is_limited :
118
101
raise RateLimitException ("Rate limit exceeded." )
0 commit comments