Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/policyengine_api/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
simulations,
tax_benefit_model_versions,
tax_benefit_models,
user_policies,
variables,
)

Expand All @@ -35,5 +36,6 @@
api_router.include_router(household.router)
api_router.include_router(analysis.router)
api_router.include_router(agent.router)
api_router.include_router(user_policies.router)

__all__ = ["api_router"]
27 changes: 21 additions & 6 deletions src/policyengine_api/api/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from typing import List
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlmodel import Session, select

from policyengine_api.models import (
Expand All @@ -40,6 +40,7 @@
Policy,
PolicyCreate,
PolicyRead,
TaxBenefitModel,
)
from policyengine_api.services.database import get_session

Expand Down Expand Up @@ -67,8 +68,17 @@ def create_policy(policy: PolicyCreate, session: Session = Depends(get_session))
]
}
"""
# Validate tax_benefit_model exists
tax_model = session.get(TaxBenefitModel, policy.tax_benefit_model_id)
if not tax_model:
raise HTTPException(status_code=404, detail="Tax benefit model not found")

# Create the policy
db_policy = Policy(name=policy.name, description=policy.description)
db_policy = Policy(
name=policy.name,
description=policy.description,
tax_benefit_model_id=policy.tax_benefit_model_id,
)
session.add(db_policy)
session.flush() # Get the policy ID before adding parameter values

Expand Down Expand Up @@ -112,10 +122,15 @@ def create_policy(policy: PolicyCreate, session: Session = Depends(get_session))


@router.get("/", response_model=List[PolicyRead])
def list_policies(session: Session = Depends(get_session)):
"""List all policies."""
policies = session.exec(select(Policy)).all()
return policies
def list_policies(
tax_benefit_model_id: UUID | None = Query(None, description="Filter by tax benefit model"),
session: Session = Depends(get_session),
):
"""List all policies, optionally filtered by tax benefit model."""
query = select(Policy)
if tax_benefit_model_id:
query = query.where(Policy.tax_benefit_model_id == tax_benefit_model_id)
return session.exec(query).all()


@router.get("/{policy_id}", response_model=PolicyRead)
Expand Down
130 changes: 130 additions & 0 deletions src/policyengine_api/api/user_policies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
"""User-policy association endpoints.

Associates users with policies they've saved/created. This enables users to
maintain a list of their policies across sessions without duplicating the
underlying policy data.
"""

from datetime import datetime, timezone
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException, Query
from sqlmodel import Session, select

from policyengine_api.models import (
Policy,
User,
UserPolicy,
UserPolicyCreate,
UserPolicyRead,
UserPolicyUpdate,
)
from policyengine_api.services.database import get_session

router = APIRouter(prefix="/user-policies", tags=["user-policies"])


@router.post("/", response_model=UserPolicyRead)
def create_user_policy(
user_policy: UserPolicyCreate,
session: Session = Depends(get_session),
):
"""Create a new user-policy association.

Associates a user with a policy, allowing them to save it to their list.
Duplicates are allowed - users can save the same policy multiple times
with different labels (matching FE localStorage behavior).
"""
# Validate user exists
user = session.get(User, user_policy.user_id)
if not user:
raise HTTPException(status_code=404, detail="User not found")

# Validate policy exists
policy = session.get(Policy, user_policy.policy_id)
if not policy:
raise HTTPException(status_code=404, detail="Policy not found")

# Create the association (duplicates allowed)
db_user_policy = UserPolicy.model_validate(user_policy)
session.add(db_user_policy)
session.commit()
session.refresh(db_user_policy)
return db_user_policy


@router.get("/", response_model=list[UserPolicyRead])
def list_user_policies(
user_id: UUID = Query(..., description="User ID to filter by"),
tax_benefit_model_id: UUID | None = Query(None, description="Filter by tax benefit model"),
session: Session = Depends(get_session),
):
"""List all policy associations for a user.

Returns all policies saved by the specified user. Optionally filter by tax benefit model.
"""
query = select(UserPolicy).where(UserPolicy.user_id == user_id)

if tax_benefit_model_id:
query = (
query
.join(Policy, UserPolicy.policy_id == Policy.id)
.where(Policy.tax_benefit_model_id == tax_benefit_model_id)
)

user_policies = session.exec(query).all()
return user_policies


@router.get("/{user_policy_id}", response_model=UserPolicyRead)
def get_user_policy(
user_policy_id: UUID,
session: Session = Depends(get_session),
):
"""Get a specific user-policy association by ID."""
user_policy = session.get(UserPolicy, user_policy_id)
if not user_policy:
raise HTTPException(status_code=404, detail="User-policy association not found")
return user_policy


@router.patch("/{user_policy_id}", response_model=UserPolicyRead)
def update_user_policy(
user_policy_id: UUID,
updates: UserPolicyUpdate,
session: Session = Depends(get_session),
):
"""Update a user-policy association (e.g., rename label)."""
user_policy = session.get(UserPolicy, user_policy_id)
if not user_policy:
raise HTTPException(status_code=404, detail="User-policy association not found")

# Apply updates
update_data = updates.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(user_policy, key, value)

# Update timestamp
user_policy.updated_at = datetime.now(timezone.utc)

session.add(user_policy)
session.commit()
session.refresh(user_policy)
return user_policy


@router.delete("/{user_policy_id}", status_code=204)
def delete_user_policy(
user_policy_id: UUID,
session: Session = Depends(get_session),
):
"""Delete a user-policy association.

This only removes the association, not the underlying policy.
"""
user_policy = session.get(UserPolicy, user_policy_id)
if not user_policy:
raise HTTPException(status_code=404, detail="User-policy association not found")

session.delete(user_policy)
session.commit()
10 changes: 10 additions & 0 deletions src/policyengine_api/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@
TaxBenefitModelVersionRead,
)
from .user import User, UserCreate, UserRead
from .user_policy import (
UserPolicy,
UserPolicyCreate,
UserPolicyRead,
UserPolicyUpdate,
)
from .variable import Variable, VariableCreate, VariableRead

__all__ = [
Expand Down Expand Up @@ -111,6 +117,10 @@
"User",
"UserCreate",
"UserRead",
"UserPolicy",
"UserPolicyCreate",
"UserPolicyRead",
"UserPolicyUpdate",
"Variable",
"VariableCreate",
"VariableRead",
Expand Down
3 changes: 3 additions & 0 deletions src/policyengine_api/models/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@

if TYPE_CHECKING:
from .parameter_value import ParameterValue
from .tax_benefit_model import TaxBenefitModel


class PolicyBase(SQLModel):
"""Base policy fields."""

name: str
description: str | None = None
tax_benefit_model_id: UUID = Field(foreign_key="tax_benefit_models.id", index=True)


class Policy(PolicyBase, table=True):
Expand All @@ -26,6 +28,7 @@ class Policy(PolicyBase, table=True):

# Relationships
parameter_values: list["ParameterValue"] = Relationship(back_populates="policy")
tax_benefit_model: "TaxBenefitModel" = Relationship()


class PolicyCreate(PolicyBase):
Expand Down
51 changes: 51 additions & 0 deletions src/policyengine_api/models/user_policy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from datetime import datetime, timezone
from typing import TYPE_CHECKING
from uuid import UUID, uuid4

from sqlmodel import Field, Relationship, SQLModel

if TYPE_CHECKING:
from .policy import Policy
from .user import User


class UserPolicyBase(SQLModel):
"""Base user-policy association fields."""

user_id: UUID = Field(foreign_key="users.id", index=True)
policy_id: UUID = Field(foreign_key="policies.id", index=True)
label: str | None = None


class UserPolicy(UserPolicyBase, table=True):
"""User-policy association database model."""

__tablename__ = "user_policies"

id: UUID = Field(default_factory=uuid4, primary_key=True)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))

# Relationships
user: "User" = Relationship()
policy: "Policy" = Relationship()


class UserPolicyCreate(UserPolicyBase):
"""Schema for creating user-policy associations."""

pass


class UserPolicyRead(UserPolicyBase):
"""Schema for reading user-policy associations."""

id: UUID
created_at: datetime
updated_at: datetime


class UserPolicyUpdate(SQLModel):
"""Schema for updating user-policy associations."""

label: str | None = None
Loading