Skip to content

Add granularity for DID Token Errors #50

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@

- <PR-#ISSUE> ...

## `0.2.0` - 1/04/2023

#### Added

- PR-#50: Split up DIDTokenError into DIDTokenExpired, DIDTokenMalformed, and DIDTokenInvalid.

## `0.1.0` - 11/30/2022

#### Added
Expand Down
10 changes: 9 additions & 1 deletion magic_admin/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,15 @@ def to_dict(self):
return {'message': str(self)}


class DIDTokenError(MagicError):
class DIDTokenInvalid(MagicError):
pass


class DIDTokenMalformed(MagicError):
pass


class DIDTokenExpired(MagicError):
pass


Expand Down
30 changes: 20 additions & 10 deletions magic_admin/resources/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from eth_account.messages import defunct_hash_message
from web3.auto import w3

from magic_admin.error import DIDTokenError
from magic_admin.error import DIDTokenExpired
from magic_admin.error import DIDTokenInvalid
from magic_admin.error import DIDTokenMalformed
from magic_admin.resources.base import ResourceComponent
from magic_admin.utils.did_token import parse_public_address_from_issuer
from magic_admin.utils.time import apply_did_token_nbf_grace_period
Expand Down Expand Up @@ -42,7 +44,7 @@ def _check_required_fields(cls, claim):
missing_fields.append(field)

if missing_fields:
raise DIDTokenError(
raise DIDTokenMalformed(
message='DID token is missing required field(s): {}'.format(
sorted(missing_fields),
),
Expand All @@ -55,7 +57,7 @@ def decode(cls, did_token):
did_token (base64.str): Base64 encoded string.

Raises:
DIDTokenError: If token format is invalid.
DIDTokenMalformed: If token format is invalid.

Returns:
proof (str): A signed message.
Expand All @@ -66,7 +68,7 @@ def decode(cls, did_token):
base64.urlsafe_b64decode(did_token).decode('utf-8'),
)
except Exception as e:
raise DIDTokenError(
raise DIDTokenMalformed(
message='DID token is malformed. It has to be a based64 encoded '
'JSON serialized string. {err} ({msg}).'.format(
err=e.__class__.__name__,
Expand All @@ -75,7 +77,7 @@ def decode(cls, did_token):
)

if len(decoded_did_token) != EXPECTED_DID_TOKEN_CONTENT_LENGTH:
raise DIDTokenError(
raise DIDTokenMalformed(
message='DID token is malformed. It has to have two parts '
'[proof, claim].',
)
Expand All @@ -85,7 +87,7 @@ def decode(cls, did_token):
try:
claim = json.loads(decoded_did_token[1])
except Exception as e:
raise DIDTokenError(
raise DIDTokenMalformed(
message='DID token is malformed. Given claim should be a JSON '
'serialized string. {err} ({msg}).'.format(
err=e.__class__.__name__,
Expand Down Expand Up @@ -130,12 +132,20 @@ def validate(cls, did_token):
did_token (base64.str): Base64 encoded string.

Raises:
DIDTokenError: If DID token fails the validation.
DIDTokenInvalid: If DID token fails the validation.
DIDTokenExpired: If DID token has expired.

Returns:
None.
"""
proof, claim = cls.decode(did_token)

if claim['ext'] is None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for adding this to the validate method!

raise DIDTokenInvalid(
message='Please check the "ext" field and regenerate a new token '
'with a suitable value.',
)

recovered_address = w3.eth.account.recoverHash(
defunct_hash_message(
text=json.dumps(claim, separators=(',', ':')),
Expand All @@ -144,20 +154,20 @@ def validate(cls, did_token):
)

if recovered_address != cls.get_public_address(did_token):
raise DIDTokenError(
raise DIDTokenInvalid(
message='Signature mismatch between "proof" and "claim". Please '
'generate a new token with an intended issuer.',
)

current_time_in_s = epoch_time_now()

if current_time_in_s > claim['ext']:
raise DIDTokenError(
raise DIDTokenExpired(
message='Given DID token has expired. Please generate a new one.',
)

if current_time_in_s < apply_did_token_nbf_grace_period(claim['nbf']):
raise DIDTokenError(
raise DIDTokenInvalid(
message='Given DID token cannot be used at this time. Please '
'check the "nbf" field and regenerate a new token with a suitable '
'value.',
Expand Down
4 changes: 2 additions & 2 deletions magic_admin/utils/did_token.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from magic_admin.error import DIDTokenError
from magic_admin.error import DIDTokenMalformed


def parse_public_address_from_issuer(issuer):
Expand All @@ -14,7 +14,7 @@ def parse_public_address_from_issuer(issuer):
try:
return issuer.split(':')[2]
except IndexError:
raise DIDTokenError(
raise DIDTokenMalformed(
'Given issuer ({}) is malformed. Please make sure it follows the '
'`did:method-name:method-specific-id` format.'.format(issuer),
)
Expand Down
6 changes: 3 additions & 3 deletions tests/unit/error_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from magic_admin.error import APIError
from magic_admin.error import AuthenticationError
from magic_admin.error import BadRequestError
from magic_admin.error import DIDTokenError
from magic_admin.error import DIDTokenInvalid
from magic_admin.error import ForbiddenError
from magic_admin.error import MagicError
from magic_admin.error import RateLimitingError
Expand Down Expand Up @@ -34,9 +34,9 @@ class TestMagicError(MagicErrorBase):
error_class = MagicError


class TestDIDTokenError(MagicErrorBase):
class TestDIDTokenInvalid(MagicErrorBase):

error_class = DIDTokenError
error_class = DIDTokenInvalid


class TestAPIConnectionError(MagicErrorBase):
Expand Down
27 changes: 19 additions & 8 deletions tests/unit/resources/token_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

import pytest

from magic_admin.error import DIDTokenError
from magic_admin.error import DIDTokenExpired
from magic_admin.error import DIDTokenInvalid
from magic_admin.error import DIDTokenMalformed
from magic_admin.resources.token import Token


Expand All @@ -24,7 +26,7 @@ def test_required_fields(self):
) == frozenset()

def test_check_required_fields_raises_error(self):
with pytest.raises(DIDTokenError) as e:
with pytest.raises(DIDTokenMalformed) as e:
Token._check_required_fields(
self._generate_claim(fields=['nbf', 'sub', 'aud', 'tid', 'iat']),
)
Expand Down Expand Up @@ -80,7 +82,7 @@ def setup_mocks(self):
def test_decode_raises_error_if_did_token_is_malformed(self, setup_mocks):
setup_mocks.urlsafe_b64decode.side_effect = Exception()

with pytest.raises(DIDTokenError) as e:
with pytest.raises(DIDTokenMalformed) as e:
Token.decode(self.did_token)

setup_mocks.urlsafe_b64decode.assert_called_once_with(self.did_token)
Expand All @@ -90,7 +92,7 @@ def test_decode_raises_error_if_did_token_is_malformed(self, setup_mocks):
def test_decode_raises_error_if_did_token_has_missing_parts(self, setup_mocks):
setup_mocks.json_loads.return_value = ('miss one part')

with pytest.raises(DIDTokenError) as e:
with pytest.raises(DIDTokenMalformed) as e:
Token.decode(self.did_token)

setup_mocks.urlsafe_b64decode.assert_called_once_with(self.did_token)
Expand All @@ -101,7 +103,7 @@ def test_decode_raises_error_if_did_token_has_missing_parts(self, setup_mocks):
'[proof, claim].'

def test_decode_raises_error_if_claim_is_not_json_serializable(self, setup_mocks):
with pytest.raises(DIDTokenError) as e:
with pytest.raises(DIDTokenMalformed) as e:
setup_mocks.json_loads.side_effect = [
('proof_in_str', 'claim_in_str'), # Succeeds the first time.
Exception(), # Fails the second time.
Expand Down Expand Up @@ -228,7 +230,7 @@ def _assert_validate_funcs_called(
def test_validate_raises_error_if_signature_mismatch(self, setup_mocks):
setup_mocks.get_public_address.return_value = 'random_public_address'

with pytest.raises(DIDTokenError) as e:
with pytest.raises(DIDTokenInvalid) as e:
Token.validate(self.did_token)

self._assert_validate_funcs_called(setup_mocks)
Expand All @@ -239,7 +241,7 @@ def test_validate_raises_error_if_did_token_expires(self, setup_mocks):
setup_mocks.epoch_time_now.return_value = \
setup_mocks.claim['ext'] + 1

with pytest.raises(DIDTokenError) as e:
with pytest.raises(DIDTokenExpired) as e:
Token.validate(self.did_token)

self._assert_validate_funcs_called(
Expand All @@ -249,11 +251,20 @@ def test_validate_raises_error_if_did_token_expires(self, setup_mocks):
assert str(e.value) == 'Given DID token has expired. Please generate a ' \
'new one.'

def test_validate_raises_error_if_did_token_has_no_expiration(self, setup_mocks):
setup_mocks.claim['ext'] = None

with pytest.raises(DIDTokenInvalid) as e:
Token.validate(self.did_token)

assert str(e.value) == 'Please check the "ext" field and regenerate a new' \
' token with a suitable value.'

def test_validate_raises_error_if_did_token_used_before_nbf(self, setup_mocks):
setup_mocks.epoch_time_now.return_value = \
setup_mocks.claim['nbf'] - 1

with pytest.raises(DIDTokenError) as e:
with pytest.raises(DIDTokenInvalid) as e:
Token.validate(self.did_token)

self._assert_validate_funcs_called(
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/utils/did_token_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from magic_admin.error import DIDTokenError
from magic_admin.error import DIDTokenMalformed
from magic_admin.utils.did_token import construct_issuer_with_public_address
from magic_admin.utils.did_token import parse_public_address_from_issuer
from testing.data.did_token import issuer
Expand All @@ -15,7 +15,7 @@ def test_parse_public_address_from_issuer(self):
assert parse_public_address_from_issuer(issuer) == public_address

def test_parse_public_address_from_issuer_raises_error(self):
with pytest.raises(DIDTokenError) as e:
with pytest.raises(DIDTokenMalformed) as e:
parse_public_address_from_issuer(self.malformed_issuer)

assert str(e.value) == \
Expand Down