Skip to content

Commit ca9fcf3

Browse files
authored
fix: enroll mfa totp (#693)
1 parent b8ad40a commit ca9fcf3

File tree

5 files changed

+220
-8
lines changed

5 files changed

+220
-8
lines changed

supabase_auth/_async/gotrue_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -878,7 +878,7 @@ async def _unenroll(self, params: MFAUnenrollParams) -> AuthMFAUnenrollResponse:
878878
"DELETE",
879879
f"factors/{params.get('factor_id')}",
880880
jwt=session.access_token,
881-
xform=partial(AuthMFAUnenrollResponse, model_validate),
881+
xform=partial(model_validate, AuthMFAUnenrollResponse),
882882
)
883883

884884
async def _list_factors(self) -> AuthMFAListFactorsResponse:

supabase_auth/_sync/gotrue_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -874,7 +874,7 @@ def _unenroll(self, params: MFAUnenrollParams) -> AuthMFAUnenrollResponse:
874874
"DELETE",
875875
f"factors/{params.get('factor_id')}",
876876
jwt=session.access_token,
877-
xform=partial(AuthMFAUnenrollResponse, model_validate),
877+
xform=partial(model_validate, AuthMFAUnenrollResponse),
878878
)
879879

880880
def _list_factors(self) -> AuthMFAListFactorsResponse:

supabase_auth/types.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -520,13 +520,21 @@ class GenerateEmailChangeLinkParams(TypedDict):
520520
]
521521

522522

523-
class MFAEnrollParams(TypedDict):
524-
factor_type: Literal["totp", "phone"]
523+
class MFAEnrollTOTPParams(TypedDict):
524+
factor_type: Literal["totp"]
525525
issuer: NotRequired[str]
526526
friendly_name: NotRequired[str]
527+
528+
529+
class MFAEnrollPhoneParams(TypedDict):
530+
factor_type: Literal["phone"]
531+
friendly_name: NotRequired[str]
527532
phone: str
528533

529534

535+
MFAEnrollParams = Union[MFAEnrollTOTPParams, MFAEnrollPhoneParams]
536+
537+
530538
class MFAUnenrollParams(TypedDict):
531539
factor_id: str
532540
"""
@@ -644,11 +652,17 @@ class AuthMFAEnrollResponse(BaseModel):
644652
"""
645653
Friendly name of the factor, useful for distinguishing between factors
646654
"""
647-
phone: str
655+
phone: Optional[str] = None
648656
"""
649657
Phone number of the MFA factor in E.164 format. Used to send messages
650658
"""
651659

660+
@model_validator_v1_v2_compat
661+
def validate_phone_required_for_phone_type(cls, values: dict) -> dict:
662+
if values.get("type") == "phone" and not values.get("phone"):
663+
raise ValueError("phone is required when type is 'phone'")
664+
return values
665+
652666

653667
class AuthMFAUnenrollResponse(BaseModel):
654668
id: str
@@ -666,7 +680,7 @@ class AuthMFAChallengeResponse(BaseModel):
666680
"""
667681
Timestamp in UNIX seconds when this challenge will no longer be usable.
668682
"""
669-
factor_type: Literal["totp", "phone"]
683+
factor_type: Optional[Literal["totp", "phone"]] = None
670684
"""
671685
Factor Type which generated the challenge
672686
"""

tests/_async/test_gotrue.py

Lines changed: 100 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,12 @@
77
from supabase_auth.errors import AuthInvalidJwtError, AuthSessionMissingError
88
from supabase_auth.helpers import decode_jwt
99

10-
from .clients import GOTRUE_JWT_SECRET, auth_client, auth_client_with_asymmetric_session
10+
from .clients import (
11+
GOTRUE_JWT_SECRET,
12+
auth_client,
13+
auth_client_with_asymmetric_session,
14+
auth_client_with_session,
15+
)
1116
from .utils import mock_user_credentials
1217

1318

@@ -189,3 +194,97 @@ async def test_set_session_with_invalid_token():
189194
# Try to set the session with invalid tokens
190195
with pytest.raises(AuthInvalidJwtError):
191196
await client.set_session("invalid.token.here", "invalid_refresh_token")
197+
198+
199+
async def test_mfa_enroll():
200+
client = auth_client_with_session()
201+
202+
credentials = mock_user_credentials()
203+
204+
# First sign up to get a valid session
205+
await client.sign_up(
206+
{
207+
"email": credentials.get("email"),
208+
"password": credentials.get("password"),
209+
}
210+
)
211+
212+
# Test MFA enrollment
213+
enroll_response = await client.mfa.enroll(
214+
{"issuer": "test-issuer", "factor_type": "totp", "friendly_name": "test-factor"}
215+
)
216+
217+
assert enroll_response.id is not None
218+
assert enroll_response.type == "totp"
219+
assert enroll_response.friendly_name == "test-factor"
220+
assert enroll_response.totp.qr_code is not None
221+
222+
223+
async def test_mfa_challenge():
224+
client = auth_client()
225+
credentials = mock_user_credentials()
226+
227+
# First sign up to get a valid session
228+
signup_response = await client.sign_up(
229+
{
230+
"email": credentials.get("email"),
231+
"password": credentials.get("password"),
232+
}
233+
)
234+
assert signup_response.session is not None
235+
236+
# Enroll a factor first
237+
enroll_response = await client.mfa.enroll(
238+
{"factor_type": "totp", "issuer": "test-issuer", "friendly_name": "test-factor"}
239+
)
240+
241+
# Test MFA challenge
242+
challenge_response = await client.mfa.challenge({"factor_id": enroll_response.id})
243+
assert challenge_response.id is not None
244+
assert challenge_response.expires_at is not None
245+
246+
247+
async def test_mfa_unenroll():
248+
client = auth_client()
249+
credentials = mock_user_credentials()
250+
251+
# First sign up to get a valid session
252+
signup_response = await client.sign_up(
253+
{
254+
"email": credentials.get("email"),
255+
"password": credentials.get("password"),
256+
}
257+
)
258+
assert signup_response.session is not None
259+
260+
# Enroll a factor first
261+
enroll_response = await client.mfa.enroll(
262+
{"factor_type": "totp", "issuer": "test-issuer", "friendly_name": "test-factor"}
263+
)
264+
265+
# Test MFA unenroll
266+
unenroll_response = await client.mfa.unenroll({"factor_id": enroll_response.id})
267+
assert unenroll_response.id == enroll_response.id
268+
269+
270+
async def test_mfa_list_factors():
271+
client = auth_client()
272+
credentials = mock_user_credentials()
273+
274+
# First sign up to get a valid session
275+
signup_response = await client.sign_up(
276+
{
277+
"email": credentials.get("email"),
278+
"password": credentials.get("password"),
279+
}
280+
)
281+
assert signup_response.session is not None
282+
283+
# Enroll a factor first
284+
await client.mfa.enroll(
285+
{"factor_type": "totp", "issuer": "test-issuer", "friendly_name": "test-factor"}
286+
)
287+
288+
# Test MFA list factors
289+
list_response = await client.mfa.list_factors()
290+
assert len(list_response.all) == 1

tests/_sync/test_gotrue.py

Lines changed: 100 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,12 @@
77
from supabase_auth.errors import AuthInvalidJwtError, AuthSessionMissingError
88
from supabase_auth.helpers import decode_jwt
99

10-
from .clients import GOTRUE_JWT_SECRET, auth_client, auth_client_with_asymmetric_session
10+
from .clients import (
11+
GOTRUE_JWT_SECRET,
12+
auth_client,
13+
auth_client_with_asymmetric_session,
14+
auth_client_with_session,
15+
)
1116
from .utils import mock_user_credentials
1217

1318

@@ -189,3 +194,97 @@ def test_set_session_with_invalid_token():
189194
# Try to set the session with invalid tokens
190195
with pytest.raises(AuthInvalidJwtError):
191196
client.set_session("invalid.token.here", "invalid_refresh_token")
197+
198+
199+
def test_mfa_enroll():
200+
client = auth_client_with_session()
201+
202+
credentials = mock_user_credentials()
203+
204+
# First sign up to get a valid session
205+
client.sign_up(
206+
{
207+
"email": credentials.get("email"),
208+
"password": credentials.get("password"),
209+
}
210+
)
211+
212+
# Test MFA enrollment
213+
enroll_response = client.mfa.enroll(
214+
{"issuer": "test-issuer", "factor_type": "totp", "friendly_name": "test-factor"}
215+
)
216+
217+
assert enroll_response.id is not None
218+
assert enroll_response.type == "totp"
219+
assert enroll_response.friendly_name == "test-factor"
220+
assert enroll_response.totp.qr_code is not None
221+
222+
223+
def test_mfa_challenge():
224+
client = auth_client()
225+
credentials = mock_user_credentials()
226+
227+
# First sign up to get a valid session
228+
signup_response = client.sign_up(
229+
{
230+
"email": credentials.get("email"),
231+
"password": credentials.get("password"),
232+
}
233+
)
234+
assert signup_response.session is not None
235+
236+
# Enroll a factor first
237+
enroll_response = client.mfa.enroll(
238+
{"factor_type": "totp", "issuer": "test-issuer", "friendly_name": "test-factor"}
239+
)
240+
241+
# Test MFA challenge
242+
challenge_response = client.mfa.challenge({"factor_id": enroll_response.id})
243+
assert challenge_response.id is not None
244+
assert challenge_response.expires_at is not None
245+
246+
247+
def test_mfa_unenroll():
248+
client = auth_client()
249+
credentials = mock_user_credentials()
250+
251+
# First sign up to get a valid session
252+
signup_response = client.sign_up(
253+
{
254+
"email": credentials.get("email"),
255+
"password": credentials.get("password"),
256+
}
257+
)
258+
assert signup_response.session is not None
259+
260+
# Enroll a factor first
261+
enroll_response = client.mfa.enroll(
262+
{"factor_type": "totp", "issuer": "test-issuer", "friendly_name": "test-factor"}
263+
)
264+
265+
# Test MFA unenroll
266+
unenroll_response = client.mfa.unenroll({"factor_id": enroll_response.id})
267+
assert unenroll_response.id == enroll_response.id
268+
269+
270+
def test_mfa_list_factors():
271+
client = auth_client()
272+
credentials = mock_user_credentials()
273+
274+
# First sign up to get a valid session
275+
signup_response = client.sign_up(
276+
{
277+
"email": credentials.get("email"),
278+
"password": credentials.get("password"),
279+
}
280+
)
281+
assert signup_response.session is not None
282+
283+
# Enroll a factor first
284+
client.mfa.enroll(
285+
{"factor_type": "totp", "issuer": "test-issuer", "friendly_name": "test-factor"}
286+
)
287+
288+
# Test MFA list factors
289+
list_response = client.mfa.list_factors()
290+
assert len(list_response.all) == 1

0 commit comments

Comments
 (0)