Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Detect expired token #204

Merged
merged 2 commits into from
Apr 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/release-notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ Features

* Added functions for downloading and uploading samples: :meth:`Client.get_sample`, :meth:`Client.upload_new_sample_now`.
* Added :class:`transfer.link.LinkFileTransfer`.
* :class:`Client` and :class:`ScicatClient` now check whether a token has expired and raise and exception if it has.

Breaking changes
~~~~~~~~~~~~~~~~
Expand Down
32 changes: 32 additions & 0 deletions src/scitacean/_internal/jwt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2024 SciCat Project (https://github.com/SciCatProject/scitacean)
"""Tools for JSON web tokens."""

import base64
import json
from datetime import datetime, timezone
from typing import cast


def decode(token: str) -> tuple[dict[str, str | int], dict[str, str | int], str]:
"""Decode the components of a JSOn web token."""
h, p, signature = token.split(".")
header = _decode_part(h)
payload = _decode_part(p)
return header, payload, signature


def expiry(token: str) -> datetime:
"""Return the expiration time of a JWT in UTC."""
_, payload, _ = decode(token)
# 'exp' should always be given in UTC. Since we have no way of checking that,
# assume that it is the case.
return datetime.fromtimestamp(float(payload["exp"]), tz=timezone.utc)


def _decode_part(s: str) -> dict[str, str | int]:
# urlsafe_b64decode requires a properly padded input but SciCat
# doesn't pad its tokens.
padded = s + "=" * (len(s) % 4)
decoded_str = base64.urlsafe_b64decode(padded).decode("utf-8")
return cast(dict[str, str | int], json.loads(decoded_str))
6 changes: 4 additions & 2 deletions src/scitacean/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from .logging import get_logger
from .pid import PID
from .typing import DownloadConnection, FileTransfer, UploadConnection
from .util.credentials import SecretStr, StrStorage
from .util.credentials import ExpiringToken, SecretStr, StrStorage


class Client:
Expand Down Expand Up @@ -566,7 +566,9 @@ def __init__(
self._base_url = url[:-1] if url.endswith("/") else url
self._timeout = datetime.timedelta(seconds=10) if timeout is None else timeout
self._token: StrStorage | None = (
SecretStr(token) if isinstance(token, str) else token
ExpiringToken.from_jwt(SecretStr(token))
if isinstance(token, str)
else token
)

@classmethod
Expand Down
79 changes: 32 additions & 47 deletions src/scitacean/util/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@

from __future__ import annotations

import datetime
from datetime import datetime, timedelta, timezone
from typing import NoReturn

from .._internal.jwt import expiry


class StrStorage:
"""Base class for storing a string.
Expand Down Expand Up @@ -36,41 +38,6 @@ def __repr__(self) -> str:
return f"{type(self).__name__}({self._value!r})"


# TODO implement
# class KeyringStr(StrStorage):
# """Store a string in the user's keyring.
#
# Should be the bottom level StrStorage because it erases any nested
# StrStorage objects and stores and returns plain strs.
# """
#
# def __init__(self, *, key: str, value: Optional[Union[str, StrStorage]]):
# super().__init__(None)
# # TODO dummy implementation
# self._ring = {}
# self._key = key
# if value is not None:
# if isinstance(value, StrStorage):
# self._store(value.get_str())
# else:
# self._store(value)
#
# def _store(self, value: str):
# self._ring[self._key] = value
#
# def _retrieve(self) -> str:
# return self._ring[self._key]
#
# def get_str(self) -> str:
# return self._retrieve()
#
# def __str__(self) -> str:
# return "???"
#
# def __repr__(self) -> str:
# return f"KeyringStr(key='{self._key}', value={str(self)})"


class SecretStr(StrStorage):
"""Minimize the risk of exposing a secret.

Expand All @@ -97,29 +64,47 @@ def __reduce_ex__(self, protocol: object) -> NoReturn:
raise TypeError("SecretStr must not be pickled")


class TimeLimitedStr(StrStorage):
"""A string that expires after some time."""
class ExpiringToken(StrStorage):
"""A JWT token that expires after some time."""

def __init__(
self,
*,
value: str | StrStorage,
expires_at: datetime.datetime,
tolerance: datetime.timedelta | None = None,
expires_at: datetime,
denial_period: timedelta | None = None,
):
super().__init__(value)
if tolerance is None:
tolerance = datetime.timedelta(seconds=10)
self._expires_at = expires_at - tolerance
if denial_period is None:
denial_period = timedelta(seconds=2)
self._expires_at = expires_at - denial_period
self._check_expiry()

@classmethod
def from_jwt(cls, value: str | StrStorage) -> ExpiringToken:
"""Create a new ExpiringToken from a JSON web token."""
value_str = value if isinstance(value, str) else value.get_str()
try:
expires_at = expiry(value_str)
except ValueError:
expires_at = datetime.now(tz=timezone.utc) + timedelta(weeks=100)
return cls(
value=value,
expires_at=expires_at,
)

def get_str(self) -> str:
"""Return the stored plain str object."""
if self._is_expired():
raise RuntimeError("Login has expired")
self._check_expiry()
return super().get_str()

def _is_expired(self) -> bool:
return datetime.datetime.now() > self._expires_at
def _check_expiry(self) -> None:
if datetime.now(tz=self._expires_at.tzinfo) > self._expires_at:
raise RuntimeError(
"SciCat login has expired. You need to create a new client either by "
"logging in through `Client.from_credentials` or by getting a new "
"access token from the SciCat web interface."
)

def __repr__(self) -> str:
return (
Expand Down
53 changes: 50 additions & 3 deletions tests/client/client_test.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2024 SciCat Project (https://github.com/SciCatProject/scitacean)

import base64
import json
import pickle
import time
from datetime import datetime, timedelta, timezone
from typing import Any

import pytest

from scitacean import PID, Client
from scitacean.testing.backend.seed import INITIAL_DATASETS
from scitacean.testing.client import FakeClient
from scitacean.util.credentials import SecretStr

Expand All @@ -29,9 +35,7 @@ def test_from_credentials_fake():
)


def test_from_credentials_real(scicat_access, scicat_backend):
if not scicat_backend:
pytest.skip("No backend")
def test_from_credentials_real(scicat_access, require_scicat_backend):
Client.from_credentials(url=scicat_access.url, **scicat_access.user.credentials)


Expand Down Expand Up @@ -80,3 +84,46 @@ def test_fake_can_disable_functions():
client.scicat.get_dataset_model(PID(pid="some-pid"))
with pytest.raises(IndexError, match="custom index error"):
client.scicat.get_orig_datablocks(PID(pid="some-pid"))


def encode_jwt_part(part: dict[str, Any]) -> str:
return base64.urlsafe_b64encode(json.dumps(part).encode("utf-8")).decode("ascii")


def make_token(exp_in: timedelta) -> str:
now = datetime.now(tz=timezone.utc)
exp = now + exp_in

# This is what a SciCat token looks like as of 2024-04-19
header = {"alg": "HS256", "typ": "JWT"}
payload = {
"_id": "7fc0856e50a8",
"username": "Weatherwax",
"email": "g.weatherwax@wyrd.lancre",
"authStrategy": "ldap",
"id": "7fc0856e50a8",
"userId": "7fc0856e50a8",
"iat": now.timestamp(),
"exp": exp.timestamp(),
}
# Scitacean never validates the signature because it doesn't have the secret key,
# so it doesn't matter what we use here.
signature = "123abc"

return ".".join((encode_jwt_part(header), encode_jwt_part(payload), signature))


def test_detects_expired_token_init():
token = make_token(timedelta(milliseconds=0))
with pytest.raises(RuntimeError, match="SciCat login has expired"):
Client.from_token(url="scicat.com", token=token)


def test_detects_expired_token_get_dataset(scicat_access, require_scicat_backend):
# The token is invalid, but the expiration should be detected before
# even sending it to SciCat.
token = make_token(timedelta(milliseconds=2100)) # > than denial period = 2s
client = Client.from_token(url=scicat_access.url, token=token)
time.sleep(0.5)
with pytest.raises(RuntimeError, match="SciCat login has expired"):
client.get_dataset(INITIAL_DATASETS["public"].pid) # type: ignore[arg-type]
Loading