Skip to content
Merged
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -191,3 +191,6 @@ cython_debug/
*.txt
temp/

# MCP stuff:
.tools/
.vscode/mcp.json
8 changes: 3 additions & 5 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,11 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config):
@pytest.fixture(autouse=True)
def reset_auth_cache() -> Generator[None, None, None]:
"""Reset auth cache before each test to prevent test interference."""
# Clear the LRU cache for the load_openid_config_and_jwks function
if hasattr(utils.auth, "load_openid_config_and_jwks"):
utils.auth.load_openid_config_and_jwks.cache_clear()
# Use the public API to clear cache (Issue #143)
utils.auth.clear_jwks_cache()
yield
# Clean up after test
if hasattr(utils.auth, "load_openid_config_and_jwks"):
utils.auth.load_openid_config_and_jwks.cache_clear()
utils.auth.clear_jwks_cache()


@pytest.fixture
Expand Down
7 changes: 4 additions & 3 deletions run_api.bat
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
@echo off
REM activate your venv (adjust path as needed)
call C:\Users\Jeff\Desktop\core-api\venv\Scripts\activate.bat
call C:\Users\Jeff\Desktop\core-api\.venv\Scripts\activate.bat

REM Change to working directory
cd /d C:\Users\Jeff\Desktop\core-api

REM start Flask via python in the background, append output to a log
start "Flask" py app.py >> C:\Users\Jeff\Desktop\core-api\api.log 2>&1
REM start Flask in the foreground so Task Scheduler can manage it
REM Using start /B runs in same console but background (no new window)
start /B py app.py >> C:\Users\Jeff\Desktop\core-api\api.log 2>&1

REM Health check loop (max 60s)
setlocal ENABLEDELAYEDEXPANSION
Expand Down
146 changes: 133 additions & 13 deletions utils/auth.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import logging
import os
import threading
import time

# ----------------------------
# Dataclass for Azure settings
# ----------------------------
from dataclasses import dataclass
from functools import lru_cache, wraps
from typing import Any, Callable, Dict, Tuple
from functools import wraps
from typing import Any, Callable, Dict, Optional, Tuple

import jwt
import requests
Expand Down Expand Up @@ -98,24 +97,91 @@ def from_env(cls) -> "AzureConfig":
# ----------------------------
# OpenID / JWKS caching
# ----------------------------
# We cache the OpenID configuration and JWKS in-memory. In production, keys rotate infrequently,
# but you may want a TTL-based refresh. Here we simply cache once per process lifetime.
# We implement a TTL-based cache with automatic refresh on key ID mismatch.
# This prevents downtime during Azure AD key rotations (Issue #143).


# Module-level cache with TTL (thread-safe)
_jwks_cache: Optional[Tuple[Dict[str, Any], Dict[str, Any]]] = None
_jwks_cache_timestamp: float = 0
_jwks_cache_ttl: int = 3600 # 1 hour in seconds
_jwks_cache_lock = threading.Lock() # Protects cache reads/writes

@lru_cache(maxsize=1)
def load_openid_config_and_jwks() -> Tuple[Dict[str, Any], Dict[str, Any]]:

def clear_jwks_cache() -> None:
"""
Clear the JWKS cache, forcing a fresh fetch on next validation.
Useful for testing and manual cache invalidation.
"""
global _jwks_cache, _jwks_cache_timestamp
with _jwks_cache_lock:
_jwks_cache = None
_jwks_cache_timestamp = 0
logging.info("JWKS cache manually cleared")


def load_openid_config_and_jwks(
force_refresh: bool = False,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""
Load the OpenID Connect configuration and JWKS for Azure AD.
Cached for process lifetime. Raises on errors.
Cached for TTL duration (default 1 hour). Can be force refreshed.
Thread-safe implementation with lock protection.

Args:
force_refresh: If True, bypass cache and fetch fresh JWKS

Returns:
Tuple of (openid_config, jwks) dictionaries

Raises:
RuntimeError: If unable to fetch configuration and no cache available
"""
global _jwks_cache, _jwks_cache_timestamp

# Check cache with read lock
with _jwks_cache_lock:
current_time = time.time()
cache_age = current_time - _jwks_cache_timestamp

# Return cached value if valid and not force refresh
if (
not force_refresh
and _jwks_cache is not None
and cache_age < _jwks_cache_ttl
):
logging.debug(f"Using cached JWKS (age: {cache_age:.0f}s)")
return _jwks_cache

# Note the reason for refresh (for logging outside lock)
refresh_reason = (
"forced"
if force_refresh
else f"TTL expired ({cache_age:.0f}s > {_jwks_cache_ttl}s)"
)

# Fetch fresh JWKS (outside lock to avoid blocking during network I/O)
logging.info(
f"Fetching fresh OpenID configuration and JWKS from Azure AD ({refresh_reason})"
)

cfg = AzureConfig.from_env()
well_known = f"https://login.microsoftonline.com/{cfg.tenant_id}/v2.0/.well-known/openid-configuration"

try:
resp = requests.get(well_known, timeout=5)
resp.raise_for_status()
openid_cfg = resp.json()
except Exception as e:
logging.error(f"Failed to fetch OpenID configuration from {well_known}: {e}")
# If we have stale cache, return it as fallback
with _jwks_cache_lock:
if _jwks_cache is not None:
cache_age = time.time() - _jwks_cache_timestamp
logging.warning(
f"Using stale JWKS cache (age: {cache_age:.0f}s) due to fetch failure"
)
return _jwks_cache
raise RuntimeError(f"Failed to load OpenID configuration: {e}")

jwks_uri = openid_cfg.get("jwks_uri")
Expand All @@ -129,19 +195,63 @@ def load_openid_config_and_jwks() -> Tuple[Dict[str, Any], Dict[str, Any]]:
jwks = resp2.json()
except Exception as e:
logging.error(f"Failed to fetch JWKS from {jwks_uri}: {e}")
# If we have stale cache, return it as fallback
with _jwks_cache_lock:
if _jwks_cache is not None:
cache_age = time.time() - _jwks_cache_timestamp
logging.warning(
f"Using stale JWKS cache (age: {cache_age:.0f}s) due to fetch failure"
)
return _jwks_cache
raise RuntimeError(f"Failed to load JWKS: {e}")

# Update cache with write lock
with _jwks_cache_lock:
_jwks_cache = (openid_cfg, jwks)
_jwks_cache_timestamp = time.time()

key_count = len(jwks.get("keys", []))
logging.info(f"JWKS cache refreshed with {key_count} keys")

return openid_cfg, jwks


# ----------------------------
# Token validation
# ----------------------------
def validate_token(token: str) -> Dict[str, Any]:
def _find_jwk_by_kid(jwks: Dict[str, Any], kid: str) -> Optional[Dict[str, Any]]:
"""
Find a JWK by key ID in the JWKS.

Args:
jwks: The JSON Web Key Set
kid: The key ID to find

Returns:
The matching JWK dict, or None if not found
"""
return next((k for k in jwks.get("keys", []) if k.get("kid") == kid), None)


def validate_token(token: str, retry_on_kid_mismatch: bool = True) -> Dict[str, Any]:
"""
Validate the JWT access token using Azure AD JWKS.
Returns the decoded payload (claims) if valid.
Raises an exception (InvalidTokenError or RuntimeError) on failure.

Args:
token: The JWT access token to validate
retry_on_kid_mismatch: If True, automatically refresh JWKS once if kid not found.
Defaults to True to handle Azure AD key rotations gracefully (Issue #143).
This is the desired behavior for production - only set to False for testing
specific error conditions.

Returns:
Dictionary of JWT claims

Raises:
InvalidTokenError: If token is invalid, expired, or missing required claims
RuntimeError: If unable to load OpenID configuration
"""
openid_cfg, jwks = load_openid_config_and_jwks()
cfg = AzureConfig.from_env()
Expand All @@ -157,9 +267,19 @@ def validate_token(token: str) -> Dict[str, Any]:
raise InvalidTokenError("JWT header missing 'kid'")

# Find matching key in JWKS
jwk_key = next((k for k in jwks.get("keys", []) if k.get("kid") == kid), None)
jwk_key = _find_jwk_by_kid(jwks, kid)

# If key not found and retry is enabled, refresh JWKS and try again (Issue #143)
if jwk_key is None and retry_on_kid_mismatch:
logging.warning(f"Key ID {kid} not found in cached JWKS, refreshing...")
openid_cfg, jwks = load_openid_config_and_jwks(force_refresh=True)
jwk_key = _find_jwk_by_kid(jwks, kid)

if jwk_key is None:
raise InvalidTokenError(f"Key ID {kid} not found in JWKS")
available_kids = [k.get("kid") for k in jwks.get("keys", [])]
raise InvalidTokenError(
f"Key ID {kid} not found in JWKS. Available keys: {available_kids}"
)

# PyJWT's PyJWK can convert JWK dict to a key object
try:
Expand Down
2 changes: 1 addition & 1 deletion utils/get_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,4 +79,4 @@ def get_access_token() -> str:

if __name__ == "__main__":
token = get_access_token()
print(f"Access Token: {token}")
print(token)
Loading