Skip to content

Commit 3d92584

Browse files
committed
add typed dicts for convenience
1 parent 8683823 commit 3d92584

File tree

5 files changed

+101
-79
lines changed

5 files changed

+101
-79
lines changed

supabase_auth/_async/gotrue_client.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from ..http_clients import AsyncClient
4444
from ..timer import Timer
4545
from ..types import (
46+
JWK,
4647
AuthChangeEvent,
4748
AuthenticatorAssuranceLevels,
4849
AuthFlowType,
@@ -58,6 +59,7 @@
5859
CodeExchangeParams,
5960
DecodedJWTDict,
6061
IdentitiesResponse,
62+
JWKSet,
6163
MFAChallengeAndVerifyParams,
6264
MFAChallengeParams,
6365
MFAEnrollParams,
@@ -111,7 +113,7 @@ def __init__(
111113
verify=verify,
112114
proxy=proxy,
113115
)
114-
self._jwks = {"keys": []}
116+
self._jwks: JWKSet = {"keys": []}
115117
self._storage_key = storage_key or STORAGE_KEY
116118
self._auto_refresh_token = auto_refresh_token
117119
self._persist_session = persist_session
@@ -1158,19 +1160,19 @@ async def exchange_code_for_session(self, params: CodeExchangeParams):
11581160
self._notify_all_subscribers("SIGNED_IN", response.session)
11591161
return response
11601162

1161-
async def _fetch_jwks(self, kid: str, jwks: Dict[str, list]) -> Dict[str, Any]:
1162-
jwk: Dict[str, Any] = {}
1163+
async def _fetch_jwks(self, kid: str, jwks: JWKSet) -> JWK:
1164+
jwk: Optional[JWK] = None
11631165

11641166
# try fetching from the suplied keys.
1165-
jwk = next((jwk for jwk in jwks.get("keys", []) if jwk.get("kid") == kid), None)
1167+
jwk = next((jwk for jwk in jwks["keys"] if jwk["kid"] == kid), None)
11661168

11671169
if jwk:
11681170
return jwk
11691171

11701172
if self._jwks:
11711173
# try fetching from the cache.
11721174
jwk = next(
1173-
(jwk for jwk in self._jwks.get("keys", []) if jwk.get("kid") == kid),
1175+
(jwk for jwk in self._jwks["keys"] if jwk["kid"] == kid),
11741176
None,
11751177
)
11761178
if jwk:
@@ -1182,9 +1184,7 @@ async def _fetch_jwks(self, kid: str, jwks: Dict[str, list]) -> Dict[str, Any]:
11821184
self._jwks = response
11831185

11841186
# find the signing key
1185-
jwk = next(
1186-
(jwk for jwk in response.get("keys", []) if jwk.get("kid") == kid), None
1187-
)
1187+
jwk = next((jwk for jwk in response["keys"] if jwk["kid"] == kid), None)
11881188
if not jwk:
11891189
raise AuthInvalidJwtError("No matching signing key found in JWKS")
11901190

@@ -1193,7 +1193,7 @@ async def _fetch_jwks(self, kid: str, jwks: Dict[str, list]) -> Dict[str, Any]:
11931193
raise AuthInvalidJwtError("JWT has no valid kid")
11941194

11951195
async def get_claims(
1196-
self, jwt: Optional[str] = None, jwks: Optional[Dict[str, list]] = None
1196+
self, jwt: Optional[str] = None, jwks: Optional[JWKSet] = None
11971197
) -> Optional[ClaimsResponse]:
11981198
token = jwt
11991199
if not token:
@@ -1204,12 +1204,16 @@ async def get_claims(
12041204
token = session.access_token
12051205

12061206
decoded_jwt = decode_jwt(token)
1207-
payload = decoded_jwt["payload"]
1208-
header = decoded_jwt["header"]
1209-
signature = decoded_jwt["signature"]
12101207

1211-
raw_header = decoded_jwt["raw"]["header"]
1212-
raw_payload = decoded_jwt["raw"]["payload"]
1208+
payload, header, signature = (
1209+
decoded_jwt["payload"],
1210+
decoded_jwt["header"],
1211+
decoded_jwt["signature"],
1212+
)
1213+
raw_header, raw_payload = (
1214+
decoded_jwt["raw"]["header"],
1215+
decoded_jwt["raw"]["payload"],
1216+
)
12131217

12141218
validate_exp(payload["exp"])
12151219

@@ -1220,7 +1224,7 @@ async def get_claims(
12201224

12211225
algorithm = get_algorithm_by_name(header["alg"])
12221226
signing_key = algorithm.from_jwk(
1223-
await self._fetch_jwks(header["kid"], jwks or {})
1227+
await self._fetch_jwks(header["kid"], jwks or {"keys": []})
12241228
)
12251229

12261230
# verify the signature
@@ -1235,8 +1239,8 @@ async def get_claims(
12351239
return ClaimsResponse(claims=payload, headers=header, signature=signature)
12361240

12371241

1238-
def parse_jwks(response: Any) -> Dict[str, list]:
1242+
def parse_jwks(response: Any) -> JWKSet:
12391243
if "keys" not in response or len(response["keys"]) == 0:
12401244
raise AuthInvalidJwtError("JWKS is empty")
12411245

1242-
return response
1246+
return {"keys": response["keys"]}

supabase_auth/_sync/gotrue_client.py

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from ..http_clients import SyncClient
4444
from ..timer import Timer
4545
from ..types import (
46+
JWK,
4647
AuthChangeEvent,
4748
AuthenticatorAssuranceLevels,
4849
AuthFlowType,
@@ -58,6 +59,7 @@
5859
CodeExchangeParams,
5960
DecodedJWTDict,
6061
IdentitiesResponse,
62+
JWKSet,
6163
MFAChallengeAndVerifyParams,
6264
MFAChallengeParams,
6365
MFAEnrollParams,
@@ -111,7 +113,7 @@ def __init__(
111113
verify=verify,
112114
proxy=proxy,
113115
)
114-
self._jwks = {"keys": []}
116+
self._jwks: JWKSet = {"keys": []}
115117
self._storage_key = storage_key or STORAGE_KEY
116118
self._auto_refresh_token = auto_refresh_token
117119
self._persist_session = persist_session
@@ -421,7 +423,9 @@ def sign_in_with_oauth(
421423
)
422424
return OAuthResponse(provider=provider, url=url_with_qs)
423425

424-
def link_identity(self, credentials: SignInWithOAuthCredentials) -> OAuthResponse:
426+
def link_identity(
427+
self, credentials: SignInWithOAuthCredentials
428+
) -> OAuthResponse:
425429
provider = credentials.get("provider")
426430
options = credentials.get("options", {})
427431
redirect_to = options.get("redirect_to")
@@ -704,7 +708,9 @@ def set_session(self, access_token: str, refresh_token: str) -> AuthResponse:
704708
self._notify_all_subscribers("TOKEN_REFRESHED", session)
705709
return AuthResponse(session=session, user=response.user)
706710

707-
def refresh_session(self, refresh_token: Optional[str] = None) -> AuthResponse:
711+
def refresh_session(
712+
self, refresh_token: Optional[str] = None
713+
) -> AuthResponse:
708714
"""
709715
Returns a new session, regardless of expiry status.
710716
@@ -1113,7 +1119,9 @@ def _get_url_for_provider(
11131119
if self._flow_type == "pkce":
11141120
code_verifier = generate_pkce_verifier()
11151121
code_challenge = generate_pkce_challenge(code_verifier)
1116-
self._storage.set_item(f"{self._storage_key}-code-verifier", code_verifier)
1122+
self._storage.set_item(
1123+
f"{self._storage_key}-code-verifier", code_verifier
1124+
)
11171125
code_challenge_method = (
11181126
"plain" if code_verifier == code_challenge else "s256"
11191127
)
@@ -1152,19 +1160,19 @@ def exchange_code_for_session(self, params: CodeExchangeParams):
11521160
self._notify_all_subscribers("SIGNED_IN", response.session)
11531161
return response
11541162

1155-
def _fetch_jwks(self, kid: str, jwks: Dict[str, list]) -> Dict[str, Any]:
1156-
jwk: Dict[str, Any] = {}
1163+
def _fetch_jwks(self, kid: str, jwks: JWKSet) -> JWK:
1164+
jwk: Optional[JWK] = None
11571165

11581166
# try fetching from the suplied keys.
1159-
jwk = next((jwk for jwk in jwks.get("keys", []) if jwk.get("kid") == kid), None)
1167+
jwk = next((jwk for jwk in jwks["keys"] if jwk["kid"] == kid), None)
11601168

11611169
if jwk:
11621170
return jwk
11631171

11641172
if self._jwks:
11651173
# try fetching from the cache.
11661174
jwk = next(
1167-
(jwk for jwk in self._jwks.get("keys", []) if jwk.get("kid") == kid),
1175+
(jwk for jwk in self._jwks["keys"] if jwk["kid"] == kid),
11681176
None,
11691177
)
11701178
if jwk:
@@ -1176,9 +1184,7 @@ def _fetch_jwks(self, kid: str, jwks: Dict[str, list]) -> Dict[str, Any]:
11761184
self._jwks = response
11771185

11781186
# find the signing key
1179-
jwk = next(
1180-
(jwk for jwk in response.get("keys", []) if jwk.get("kid") == kid), None
1181-
)
1187+
jwk = next((jwk for jwk in response["keys"] if jwk["kid"] == kid), None)
11821188
if not jwk:
11831189
raise AuthInvalidJwtError("No matching signing key found in JWKS")
11841190

@@ -1187,7 +1193,7 @@ def _fetch_jwks(self, kid: str, jwks: Dict[str, list]) -> Dict[str, Any]:
11871193
raise AuthInvalidJwtError("JWT has no valid kid")
11881194

11891195
def get_claims(
1190-
self, jwt: Optional[str] = None, jwks: Optional[Dict[str, list]] = None
1196+
self, jwt: Optional[str] = None, jwks: Optional[JWKSet] = None
11911197
) -> Optional[ClaimsResponse]:
11921198
token = jwt
11931199
if not token:
@@ -1198,12 +1204,16 @@ def get_claims(
11981204
token = session.access_token
11991205

12001206
decoded_jwt = decode_jwt(token)
1201-
payload = decoded_jwt["payload"]
1202-
header = decoded_jwt["header"]
1203-
signature = decoded_jwt["signature"]
12041207

1205-
raw_header = decoded_jwt["raw"]["header"]
1206-
raw_payload = decoded_jwt["raw"]["payload"]
1208+
payload, header, signature = (
1209+
decoded_jwt["payload"],
1210+
decoded_jwt["header"],
1211+
decoded_jwt["signature"],
1212+
)
1213+
raw_header, raw_payload = (
1214+
decoded_jwt["raw"]["header"],
1215+
decoded_jwt["raw"]["payload"],
1216+
)
12071217

12081218
validate_exp(payload["exp"])
12091219

@@ -1213,7 +1223,9 @@ def get_claims(
12131223
return ClaimsResponse(claims=payload, headers=header, signature=signature)
12141224

12151225
algorithm = get_algorithm_by_name(header["alg"])
1216-
signing_key = algorithm.from_jwk(self._fetch_jwks(header["kid"], jwks or {}))
1226+
signing_key = algorithm.from_jwk(
1227+
self._fetch_jwks(header["kid"], jwks or {"keys": []})
1228+
)
12171229

12181230
# verify the signature
12191231
is_valid = algorithm.verify(
@@ -1227,8 +1239,8 @@ def get_claims(
12271239
return ClaimsResponse(claims=payload, headers=header, signature=signature)
12281240

12291241

1230-
def parse_jwks(response: Any) -> Dict[str, list]:
1242+
def parse_jwks(response: Any) -> JWKSet:
12311243
if "keys" not in response or len(response["keys"]) == 0:
12321244
raise AuthInvalidJwtError("JWKS is empty")
12331245

1234-
return response
1246+
return {"keys": response["keys"]}

supabase_auth/types.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -816,6 +816,17 @@ class ClaimsResponse(TypedDict):
816816
signature: bytes
817817

818818

819+
class JWK(TypedDict, total=False):
820+
kty: Literal["RSA", "EC", "oct"]
821+
key_ops: List[str]
822+
alg: Optional[str]
823+
kid: Optional[str]
824+
825+
826+
class JWKSet(TypedDict):
827+
keys: List[JWK]
828+
829+
819830
for model in [
820831
AMREntry,
821832
AuthResponse,

tests/_sync/test_gotrue.py

Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,38 @@
11
import unittest
2-
import pytest
3-
from pytest_mock import mocker
42

3+
from .clients import auth_client, auth_client_with_asymmetric_session
54
from .utils import mock_user_credentials
65

7-
from .clients import auth_client, auth_client_with_asymmetric_session
86

97
def test_get_claims_returns_none_when_session_is_none():
10-
claims = auth_client().get_claims()
11-
assert claims is None
8+
claims = auth_client().get_claims()
9+
assert claims is None
10+
1211

1312
def test_get_claims_calls_get_user_if_symmetric_jwt(mocker):
14-
client = auth_client()
15-
spy = mocker.spy(client, 'get_user')
13+
client = auth_client()
14+
spy = mocker.spy(client, "get_user")
1615

17-
user = (client.sign_up(mock_user_credentials())).user
18-
assert user is not None
16+
user = (client.sign_up(mock_user_credentials())).user
17+
assert user is not None
1918

20-
claims = (client.get_claims())["claims"]
21-
assert claims["email"] == user.email
22-
spy.assert_called_once()
23-
19+
claims = (client.get_claims())["claims"]
20+
assert claims["email"] == user.email
21+
spy.assert_called_once()
2422

25-
def test_get_claims_fetches_jwks_to_verify_asymmetric_jwt(mocker):
26-
client = auth_client_with_asymmetric_session()
2723

28-
user = (client.sign_up(mock_user_credentials())).user
29-
assert user is not None
24+
def test_get_claims_fetches_jwks_to_verify_asymmetric_jwt(mocker):
25+
client = auth_client_with_asymmetric_session()
3026

31-
spy = mocker.spy(client, "_request")
27+
user = (client.sign_up(mock_user_credentials())).user
28+
assert user is not None
3229

33-
claims = (client.get_claims())["claims"]
34-
assert claims["email"] == user.email
30+
spy = mocker.spy(client, "_request")
3531

36-
spy.assert_called_once()
37-
spy.assert_called_with("GET", ".well-known/jwks.json", xform=unittest.mock.ANY)
32+
claims = (client.get_claims())["claims"]
33+
assert claims["email"] == user.email
3834

39-
assert len(spy.spy_return.get("keys")) > 0
35+
spy.assert_called_once()
36+
spy.assert_called_with("GET", ".well-known/jwks.json", xform=unittest.mock.ANY)
4037

41-
38+
assert len(spy.spy_return.get("keys")) > 0

tests/_sync/test_gotrue_admin_api.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -156,27 +156,23 @@ def test_modify_confirm_email_using_update_user_by_id():
156156

157157
def test_invalid_credential_sign_in_with_phone():
158158
try:
159-
response = (
160-
client_api_auto_confirm_off_signups_enabled_client().sign_in_with_password(
161-
{
162-
"phone": "+123456789",
163-
"password": "strong_pwd",
164-
}
165-
)
159+
response = client_api_auto_confirm_off_signups_enabled_client().sign_in_with_password(
160+
{
161+
"phone": "+123456789",
162+
"password": "strong_pwd",
163+
}
166164
)
167165
except AuthApiError as e:
168166
assert e.to_dict()
169167

170168

171169
def test_invalid_credential_sign_in_with_email():
172170
try:
173-
response = (
174-
client_api_auto_confirm_off_signups_enabled_client().sign_in_with_password(
175-
{
176-
"email": "unknown_user@unknowndomain.com",
177-
"password": "strong_pwd",
178-
}
179-
)
171+
response = client_api_auto_confirm_off_signups_enabled_client().sign_in_with_password(
172+
{
173+
"email": "unknown_user@unknowndomain.com",
174+
"password": "strong_pwd",
175+
}
180176
)
181177
except AuthApiError as e:
182178
assert e.to_dict()
@@ -390,10 +386,12 @@ def test_sign_in_with_sso():
390386

391387

392388
def test_sign_in_with_oauth():
393-
assert client_api_auto_confirm_off_signups_enabled_client().sign_in_with_oauth(
394-
{
395-
"provider": "google",
396-
}
389+
assert (
390+
client_api_auto_confirm_off_signups_enabled_client().sign_in_with_oauth(
391+
{
392+
"provider": "google",
393+
}
394+
)
397395
)
398396

399397

0 commit comments

Comments
 (0)