Skip to content

Commit

Permalink
Add expiration of unused refresh tokens (home-assistant#108428)
Browse files Browse the repository at this point in the history
Co-authored-by: J. Nick Koston <nick@koston.org>
  • Loading branch information
mib1185 and bdraco authored Jan 24, 2024
1 parent 0d22822 commit f5d4397
Show file tree
Hide file tree
Showing 6 changed files with 243 additions and 7 deletions.
71 changes: 67 additions & 4 deletions homeassistant/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,27 @@
import asyncio
from collections import OrderedDict
from collections.abc import Mapping
from datetime import timedelta
from datetime import datetime, timedelta
from functools import partial
import time
from typing import Any, cast

import jwt

from homeassistant import data_entry_flow
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
from homeassistant.core import (
CALLBACK_TYPE,
HassJob,
HassJobType,
HomeAssistant,
callback,
)
from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers.event import async_track_point_in_utc_time
from homeassistant.util import dt as dt_util

from . import auth_store, jwt_wrapper, models
from .const import ACCESS_TOKEN_EXPIRATION, GROUP_ID_ADMIN
from .const import ACCESS_TOKEN_EXPIRATION, GROUP_ID_ADMIN, REFRESH_TOKEN_EXPIRATION
from .mfa_modules import MultiFactorAuthModule, auth_mfa_module_from_config
from .providers import AuthProvider, LoginFlow, auth_provider_from_config

Expand Down Expand Up @@ -75,7 +83,9 @@ async def auth_manager_from_config(
for module in modules:
module_hash[module.id] = module

return AuthManager(hass, store, provider_hash, module_hash)
manager = AuthManager(hass, store, provider_hash, module_hash)
manager.async_setup()
return manager


class AuthManagerFlowManager(data_entry_flow.FlowManager):
Expand Down Expand Up @@ -159,6 +169,21 @@ def __init__(
self._mfa_modules = mfa_modules
self.login_flow = AuthManagerFlowManager(hass, self)
self._revoke_callbacks: dict[str, set[CALLBACK_TYPE]] = {}
self._expire_callback: CALLBACK_TYPE | None = None
self._remove_expired_job = HassJob(
self._async_remove_expired_refresh_tokens, job_type=HassJobType.Callback
)

@callback
def async_setup(self) -> None:
"""Set up the auth manager."""
hass = self.hass
hass.async_add_shutdown_job(
HassJob(
self._async_cancel_expiration_schedule, job_type=HassJobType.Callback
)
)
self._async_track_next_refresh_token_expiration()

@property
def auth_providers(self) -> list[AuthProvider]:
Expand Down Expand Up @@ -424,6 +449,11 @@ async def async_create_refresh_token(
else:
token_type = models.TOKEN_TYPE_NORMAL

if token_type is models.TOKEN_TYPE_NORMAL:
expire_at = time.time() + REFRESH_TOKEN_EXPIRATION
else:
expire_at = None

if user.system_generated != (token_type == models.TOKEN_TYPE_SYSTEM):
raise ValueError(
"System generated users can only have system type refresh tokens"
Expand Down Expand Up @@ -455,6 +485,7 @@ async def async_create_refresh_token(
client_icon,
token_type,
access_token_expiration,
expire_at,
credential,
)

Expand All @@ -479,6 +510,38 @@ def async_remove_refresh_token(self, refresh_token: models.RefreshToken) -> None
for revoke_callback in callbacks:
revoke_callback()

@callback
def _async_remove_expired_refresh_tokens(self, _: datetime | None = None) -> None:
"""Remove expired refresh tokens."""
now = time.time()
for token in self._store.async_get_refresh_tokens()[:]:
if (expire_at := token.expire_at) is not None and expire_at <= now:
self.async_remove_refresh_token(token)
self._async_track_next_refresh_token_expiration()

@callback
def _async_track_next_refresh_token_expiration(self) -> None:
"""Initialise all token expiration scheduled tasks."""
next_expiration = time.time() + REFRESH_TOKEN_EXPIRATION
for token in self._store.async_get_refresh_tokens():
if (
expire_at := token.expire_at
) is not None and expire_at < next_expiration:
next_expiration = expire_at

self._expire_callback = async_track_point_in_utc_time(
self.hass,
self._remove_expired_job,
dt_util.utc_from_timestamp(next_expiration),
)

@callback
def _async_cancel_expiration_schedule(self) -> None:
"""Cancel tracking of expired refresh tokens."""
if self._expire_callback:
self._expire_callback()
self._expire_callback = None

@callback
def _async_unregister(
self, callbacks: set[CALLBACK_TYPE], callback_: CALLBACK_TYPE
Expand Down
33 changes: 32 additions & 1 deletion homeassistant/auth/auth_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from datetime import timedelta
import hmac
import itertools
from logging import getLogger
from typing import Any

Expand All @@ -17,6 +18,7 @@
GROUP_ID_ADMIN,
GROUP_ID_READ_ONLY,
GROUP_ID_USER,
REFRESH_TOKEN_EXPIRATION,
)
from .permissions import system_policies
from .permissions.models import PermissionLookup
Expand Down Expand Up @@ -186,6 +188,7 @@ async def async_create_refresh_token(
client_icon: str | None = None,
token_type: str = models.TOKEN_TYPE_NORMAL,
access_token_expiration: timedelta = ACCESS_TOKEN_EXPIRATION,
expire_at: float | None = None,
credential: models.Credentials | None = None,
) -> models.RefreshToken:
"""Create a new token for a user."""
Expand All @@ -194,6 +197,7 @@ async def async_create_refresh_token(
"client_id": client_id,
"token_type": token_type,
"access_token_expiration": access_token_expiration,
"expire_at": expire_at,
"credential": credential,
}
if client_name:
Expand Down Expand Up @@ -239,16 +243,29 @@ def async_get_refresh_token_by_token(

return found

@callback
def async_get_refresh_tokens(self) -> list[models.RefreshToken]:
"""Get all refresh tokens."""
return list(
itertools.chain.from_iterable(
user.refresh_tokens.values() for user in self._users.values()
)
)

@callback
def async_log_refresh_token_usage(
self, refresh_token: models.RefreshToken, remote_ip: str | None = None
) -> None:
"""Update refresh token last used information."""
refresh_token.last_used_at = dt_util.utcnow()
refresh_token.last_used_ip = remote_ip
if refresh_token.expire_at:
refresh_token.expire_at = (
refresh_token.last_used_at.timestamp() + REFRESH_TOKEN_EXPIRATION
)
self._async_schedule_save()

async def async_load(self) -> None:
async def async_load(self) -> None: # noqa: C901
"""Load the users."""
if self._loaded:
raise RuntimeError("Auth storage is already loaded")
Expand All @@ -261,6 +278,8 @@ async def async_load(self) -> None:
perm_lookup = PermissionLookup(ent_reg, dev_reg)
self._perm_lookup = perm_lookup

now_ts = dt_util.utcnow().timestamp()

if data is None or not isinstance(data, dict):
self._set_defaults()
return
Expand Down Expand Up @@ -414,6 +433,14 @@ async def async_load(self) -> None:
else:
last_used_at = None

if (
expire_at := rt_dict.get("expire_at")
) is None and token_type == models.TOKEN_TYPE_NORMAL:
if last_used_at:
expire_at = last_used_at.timestamp() + REFRESH_TOKEN_EXPIRATION
else:
expire_at = now_ts + REFRESH_TOKEN_EXPIRATION

token = models.RefreshToken(
id=rt_dict["id"],
user=users[rt_dict["user_id"]],
Expand All @@ -430,6 +457,7 @@ async def async_load(self) -> None:
jwt_key=rt_dict["jwt_key"],
last_used_at=last_used_at,
last_used_ip=rt_dict.get("last_used_ip"),
expire_at=expire_at,
version=rt_dict.get("version"),
)
if "credential_id" in rt_dict:
Expand All @@ -439,6 +467,8 @@ async def async_load(self) -> None:
self._groups = groups
self._users = users

self._async_schedule_save()

@callback
def _async_schedule_save(self) -> None:
"""Save users."""
Expand Down Expand Up @@ -503,6 +533,7 @@ def _data_to_save(self) -> dict[str, list[dict[str, Any]]]:
if refresh_token.last_used_at
else None,
"last_used_ip": refresh_token.last_used_ip,
"expire_at": refresh_token.expire_at,
"credential_id": refresh_token.credential.id
if refresh_token.credential
else None,
Expand Down
1 change: 1 addition & 0 deletions homeassistant/auth/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

ACCESS_TOKEN_EXPIRATION = timedelta(minutes=30)
MFA_SESSION_EXPIRATION = timedelta(minutes=5)
REFRESH_TOKEN_EXPIRATION = timedelta(days=90).total_seconds()

GROUP_ID_ADMIN = "system-admin"
GROUP_ID_USER = "system-users"
Expand Down
2 changes: 2 additions & 0 deletions homeassistant/auth/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ class RefreshToken:
last_used_at: datetime | None = attr.ib(default=None)
last_used_ip: str | None = attr.ib(default=None)

expire_at: float | None = attr.ib(default=None)

credential: Credentials | None = attr.ib(default=None)

version: str | None = attr.ib(default=__version__)
Expand Down
67 changes: 67 additions & 0 deletions tests/auth/test_auth_store.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
"""Tests for the auth store."""
import asyncio
from datetime import timedelta
from typing import Any
from unittest.mock import patch

from freezegun import freeze_time
import pytest

from homeassistant.auth import auth_store
from homeassistant.core import HomeAssistant
from homeassistant.util import dt as dt_util


async def test_loading_no_group_data_format(
Expand Down Expand Up @@ -267,3 +270,67 @@ async def test_loading_only_once(hass: HomeAssistant) -> None:
mock_dev_registry.assert_called_once_with(hass)
mock_load.assert_called_once_with()
assert results[0] == results[1]


async def test_add_expire_at_property(
hass: HomeAssistant, hass_storage: dict[str, Any]
) -> None:
"""Test we correctly add expired_at property if not existing."""
now = dt_util.utcnow()
with freeze_time(now):
hass_storage[auth_store.STORAGE_KEY] = {
"version": 1,
"data": {
"credentials": [],
"users": [
{
"id": "user-id",
"is_active": True,
"is_owner": True,
"name": "Paulus",
"system_generated": False,
},
{
"id": "system-id",
"is_active": True,
"is_owner": True,
"name": "Hass.io",
"system_generated": True,
},
],
"refresh_tokens": [
{
"access_token_expiration": 1800.0,
"client_id": "http://localhost:8123/",
"created_at": "2018-10-03T13:43:19.774637+00:00",
"id": "user-token-id",
"jwt_key": "some-key",
"last_used_at": str(now - timedelta(days=10)),
"token": "some-token",
"user_id": "user-id",
"version": "1.2.3",
},
{
"access_token_expiration": 1800.0,
"client_id": "http://localhost:8123/",
"created_at": "2018-10-03T13:43:19.774637+00:00",
"id": "user-token-id2",
"jwt_key": "some-key2",
"token": "some-token",
"user_id": "user-id",
},
],
},
}

store = auth_store.AuthStore(hass)
await store.async_load()

users = await store.async_get_users()

assert len(users[0].refresh_tokens) == 2
token1, token2 = users[0].refresh_tokens.values()
assert token1.expire_at
assert token1.expire_at == now.timestamp() + timedelta(days=80).total_seconds()
assert token2.expire_at
assert token2.expire_at == now.timestamp() + timedelta(days=90).total_seconds()
Loading

0 comments on commit f5d4397

Please sign in to comment.