Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
@@ -1,2 +1,8 @@
DATABASE_URL="postgresql://localhost:5432/postgres"
APP_MODULE="tasktimer.main:app"

# AWS Cognito Configuration
COGNITO_REGION=us-east-1
COGNITO_USER_POOL_ID=us-east-1_xxxxxxxxx
COGNITO_APP_CLIENT_ID=xxxxxxxxxxxxxxxxxxxxxxxxxx
AUTH_DISABLED=false
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,17 @@ readme = "README.md"
requires-python = ">=3.12"
dependencies = [
"black>=26.1.0",
"cachetools>=5.3.0",
"databases[postgresql,sqlite]>=0.9.0",
"fastapi>=0.128.0",
"httpx>=0.28.1",
"langchain>=1.2.6",
"langchain-anthropic>=1.3.1",
"psycopg2-binary>=2.9.11",
"pydantic>=2.12.5",
"pydantic-settings>=2.0.0",
"python-dotenv>=1.2.1",
"python-jose[cryptography]>=3.3.0",
"requests>=2.32.5",
"sqlalchemy>=2.0.46",
"uvicorn>=0.40.0",
Expand Down
4 changes: 4 additions & 0 deletions tasktimer/auth/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .dependencies import get_current_user
from .models import AuthenticatedUser

__all__ = ["get_current_user", "AuthenticatedUser"]
88 changes: 88 additions & 0 deletions tasktimer/auth/cognito.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import time
from typing import Any

import requests
from cachetools import TTLCache
from jose import jwt, JWTError

from .config import settings

# Cache JWKS keys for 1 hour
_jwks_cache: TTLCache[str, dict[str, Any]] = TTLCache(maxsize=1, ttl=3600)
_JWKS_CACHE_KEY = "jwks"


def get_jwks() -> dict[str, Any]:
"""Fetch JWKS from Cognito, with 1-hour TTL caching."""
if _JWKS_CACHE_KEY in _jwks_cache:
return _jwks_cache[_JWKS_CACHE_KEY]

response = requests.get(settings.jwks_url, timeout=10)
response.raise_for_status()
jwks = response.json()
_jwks_cache[_JWKS_CACHE_KEY] = jwks
return jwks


def get_signing_key(token: str) -> dict[str, Any]:
"""Get the signing key for the given token from JWKS."""
jwks = get_jwks()
unverified_header = jwt.get_unverified_header(token)
kid = unverified_header.get("kid")

for key in jwks.get("keys", []):
if key.get("kid") == kid:
return key

raise JWTError("Unable to find matching key in JWKS")


def decode_and_validate_token(token: str) -> dict[str, Any]:
"""
Decode and validate a JWT token from Cognito.

Validates:
- Signature using JWKS
- Expiry (exp claim)
- Issuer (iss claim)
- Audience (client_id in token_use=access or aud in token_use=id)
"""
signing_key = get_signing_key(token)

try:
payload = jwt.decode(
token,
signing_key,
algorithms=["RS256"],
issuer=settings.issuer,
audience=settings.COGNITO_APP_CLIENT_ID,
options={
"verify_aud": True,
"verify_iss": True,
"verify_exp": True,
},
)
except JWTError:
# For access tokens, audience is in 'client_id' claim, not 'aud'
# Try again without audience verification, then manually check
payload = jwt.decode(
token,
signing_key,
algorithms=["RS256"],
issuer=settings.issuer,
options={
"verify_aud": False,
"verify_iss": True,
"verify_exp": True,
},
)
# Verify client_id for access tokens
if payload.get("client_id") != settings.COGNITO_APP_CLIENT_ID:
raise JWTError("Invalid audience/client_id")

return payload


def clear_jwks_cache() -> None:
"""Clear the JWKS cache. Useful for testing."""
_jwks_cache.clear()
30 changes: 30 additions & 0 deletions tasktimer/auth/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from functools import cached_property
from pydantic_settings import BaseSettings


class CognitoSettings(BaseSettings):
COGNITO_REGION: str = "us-east-1"
COGNITO_USER_POOL_ID: str = ""
COGNITO_APP_CLIENT_ID: str = ""
AUTH_DISABLED: bool = False

@cached_property
def jwks_url(self) -> str:
return (
f"https://cognito-idp.{self.COGNITO_REGION}.amazonaws.com/"
f"{self.COGNITO_USER_POOL_ID}/.well-known/jwks.json"
)

@cached_property
def issuer(self) -> str:
return (
f"https://cognito-idp.{self.COGNITO_REGION}.amazonaws.com/"
f"{self.COGNITO_USER_POOL_ID}"
)

class Config:
env_file = ".env"
extra = "ignore"


settings = CognitoSettings()
56 changes: 56 additions & 0 deletions tasktimer/auth/dependencies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jose import JWTError

from .config import settings
from .cognito import decode_and_validate_token
from .models import AuthenticatedUser

security = HTTPBearer(auto_error=False)


async def get_current_user(
credentials: HTTPAuthorizationCredentials | None = Depends(security),
) -> AuthenticatedUser:
"""
FastAPI dependency that extracts and validates JWT from Authorization header.

Returns mock user when AUTH_DISABLED=true for local development.
Raises 401 for missing or invalid tokens when auth is enabled.
"""
if settings.AUTH_DISABLED:
return AuthenticatedUser(
user_id="dev-user-001",
email="dev@localhost",
)

if credentials is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing authorization header",
headers={"WWW-Authenticate": "Bearer"},
)

token = credentials.credentials
try:
payload = decode_and_validate_token(token)
except JWTError as e:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"Invalid token: {str(e)}",
headers={"WWW-Authenticate": "Bearer"},
)

# Extract user info from token
# 'sub' is the Cognito user ID (UUID format)
user_id = payload.get("sub")
if not user_id:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token missing subject claim",
headers={"WWW-Authenticate": "Bearer"},
)

email = payload.get("email")

return AuthenticatedUser(user_id=user_id, email=email)
7 changes: 7 additions & 0 deletions tasktimer/auth/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from typing import Optional
from pydantic import BaseModel


class AuthenticatedUser(BaseModel):
user_id: str
email: Optional[str] = None
2 changes: 1 addition & 1 deletion tasktimer/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"tasks",
metadata,
sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True),
sqlalchemy.Column("user_id", sqlalchemy.Integer),
sqlalchemy.Column("user_id", sqlalchemy.String(36)),
sqlalchemy.Column("description", sqlalchemy.String),
sqlalchemy.Column("start_time", sqlalchemy.DateTime),
sqlalchemy.Column("end_time", sqlalchemy.DateTime),
Expand Down
45 changes: 31 additions & 14 deletions tasktimer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
from datetime import time, datetime
from .database import database, task_table
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, HTTPException, Depends
from .models import TaskOut, NewTaskItem, TaskItem
from .auth import get_current_user, AuthenticatedUser


@asynccontextmanager
Expand All @@ -17,20 +18,35 @@ async def lifespan(app: FastAPI):
app = FastAPI(lifespan=lifespan)


@app.get("/health")
async def health():
"""Health check endpoint for load balancer."""
return {"status": "healthy"}


@app.post("/track", response_model=TaskItem)
async def track(task: NewTaskItem):
print(task)
data = {**task.model_dump(), "start_time": datetime.now()}
async def track(
task: NewTaskItem,
current_user: AuthenticatedUser = Depends(get_current_user),
):
data = {
"description": task.description,
"user_id": current_user.user_id,
"start_time": datetime.now(),
}
query = task_table.insert().values(data)
last_record_id = await database.execute(query)

return {"id": last_record_id, "user_id": task.user_id}
return {"id": last_record_id, "user_id": current_user.user_id}


@app.post("/stop", response_model=TaskItem)
async def stop(task: TaskItem):
async def stop(
task_id: int,
current_user: AuthenticatedUser = Depends(get_current_user),
):
select_query = task_table.select().where(
(task_table.c.id == task.id) & (task_table.c.user_id == task.user_id)
(task_table.c.id == task_id) & (task_table.c.user_id == current_user.user_id)
)
existing_task = await database.fetch_one(select_query)
if existing_task is None:
Expand All @@ -39,24 +55,25 @@ async def stop(task: TaskItem):
end_time = datetime.now()
update_query = (
task_table.update()
.where(task_table.c.id == task.id)
.where(task_table.c.user_id == task.user_id)
.where(task_table.c.id == task_id)
.where(task_table.c.user_id == current_user.user_id)
.values(end_time=end_time)
)
await database.execute(update_query)
return task
return {"id": task_id, "user_id": current_user.user_id}


@app.get("/times", response_model=list[TaskOut])
async def get_times(user_id: int, date: str):
# Using `int` and `str` parameters instead of Pydantic models
# tells FastAPI we want to get these values from the query string.
async def get_times(
date: str,
current_user: AuthenticatedUser = Depends(get_current_user),
):
selected_date = datetime.strptime(date, "%Y-%m-%d").date()
start_of_day = datetime.combine(selected_date, time.min)
end_of_day = datetime.combine(selected_date, time.max)

query = task_table.select().where(
(task_table.c.user_id == user_id)
(task_table.c.user_id == current_user.user_id)
& (task_table.c.start_time <= end_of_day)
& ((task_table.c.end_time >= start_of_day) | (task_table.c.end_time.is_(None)))
)
Expand Down
3 changes: 1 addition & 2 deletions tasktimer/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@


class NewTaskItem(BaseModel):
user_id: int
description: str


class TaskItem(BaseModel):
id: int
user_id: int
user_id: str


class TaskOut(BaseModel):
Expand Down
Loading