Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 80 additions & 5 deletions src/policyengine_api/api/change_aggregates.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,101 @@
"""Change aggregate endpoints.

Change aggregates compare statistics between baseline and reform simulations
(e.g. change in tax revenue, change in poverty rate). These are typically
created automatically when processing economic impact analyses. Computation
is triggered on Modal.
"""

from typing import List
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException
import logfire
import modal
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from sqlmodel import Session, select

from policyengine_api.models import (
ChangeAggregate,
ChangeAggregateCreate,
ChangeAggregateRead,
Simulation,
TaxBenefitModel,
TaxBenefitModelVersion,
)
from policyengine_api.services.database import get_session

router = APIRouter(prefix="/outputs/change-aggregates", tags=["change-aggregates"])


def _get_traceparent() -> str | None:
"""Get W3C traceparent from current span context."""
try:
from opentelemetry import trace
from opentelemetry.trace.propagation.tracecontext import (
TraceContextTextMapPropagator,
)

carrier: dict[str, str] = {}
TraceContextTextMapPropagator().inject(
carrier, trace.set_span_in_context(trace.get_current_span())
)
return carrier.get("traceparent")
except Exception:
return None


def _trigger_change_aggregate_computation(
    change_aggregate_id: str, baseline_simulation_id: UUID, session: Session
) -> None:
    """Spawn the Modal function that computes one change aggregate.

    The baseline simulation is resolved to its model version and then to its
    tax-benefit model; the model's name decides whether the UK or the US
    Modal function is used. If any of the three lookups finds nothing, an
    error is logged and the function returns without spawning anything.

    Args:
        change_aggregate_id: Identifier forwarded to the Modal function.
        baseline_simulation_id: Primary key of the baseline ``Simulation``.
        session: Database session used for the three ``get`` lookups.
    """
    # Walk simulation -> model version -> model to find out UK vs US.
    baseline = session.get(Simulation, baseline_simulation_id)
    if not baseline:
        logfire.error("Simulation not found", simulation_id=str(baseline_simulation_id))
        return

    version_id = baseline.tax_benefit_model_version_id
    version = session.get(TaxBenefitModelVersion, version_id)
    if not version:
        logfire.error("Model version not found", version_id=str(version_id))
        return

    model_id = version.tax_benefit_model_id
    model = session.get(TaxBenefitModel, model_id)
    if not model:
        logfire.error("Model not found", model_id=str(model_id))
        return

    traceparent = _get_traceparent()

    # Any model name other than the two UK spellings is routed to the US
    # function.
    is_uk = model.name in ("uk", "policyengine_uk")
    function_name = (
        "compute_change_aggregate_uk" if is_uk else "compute_change_aggregate_us"
    )
    remote_fn = modal.Function.from_name("policyengine", function_name)

    # NOTE(review): spawn() is presumably fire-and-forget - confirm against
    # the Modal documentation.
    remote_fn.spawn(change_aggregate_id=change_aggregate_id, traceparent=traceparent)
    logfire.info(
        "Triggered change aggregate computation",
        change_aggregate_id=change_aggregate_id,
        model=model.name,
    )


@router.post("/", response_model=List[ChangeAggregateRead])
def create_change_aggregates(
outputs: List[ChangeAggregateCreate], session: Session = Depends(get_session)
outputs: List[ChangeAggregateCreate],
background_tasks: BackgroundTasks,
session: Session = Depends(get_session),
):
"""Create change aggregate specifications comparing baseline vs reform.
"""Create change aggregate specifications and trigger computation.

Change aggregates compute the difference in statistics between two simulations.
Computation happens asynchronously on Modal. Poll GET /outputs/change-aggregates/{id}
until status="completed" to get results.
"""
db_outputs = []
for output in outputs:
Expand All @@ -37,6 +105,13 @@ def create_change_aggregates(
session.commit()
for db_output in db_outputs:
session.refresh(db_output)

# Trigger computation for each change aggregate
for db_output in db_outputs:
_trigger_change_aggregate_computation(
str(db_output.id), db_output.baseline_simulation_id, session
)

return db_outputs


Expand Down
84 changes: 79 additions & 5 deletions src/policyengine_api/api/outputs.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,100 @@
"""Aggregate output endpoints.

Aggregates are computed statistics from simulations (e.g. total tax revenue,
benefit spending, poverty rates). These are typically created automatically
by the worker when processing economic impact analyses. Computation is
triggered on Modal.
"""

from typing import List
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException
import logfire
import modal
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from sqlmodel import Session, select

from policyengine_api.models import (
AggregateOutput,
AggregateOutputCreate,
AggregateOutputRead,
Simulation,
TaxBenefitModel,
TaxBenefitModelVersion,
)
from policyengine_api.services.database import get_session

router = APIRouter(prefix="/outputs/aggregates", tags=["aggregates"])


def _get_traceparent() -> str | None:
"""Get W3C traceparent from current span context."""
try:
from opentelemetry import trace
from opentelemetry.trace.propagation.tracecontext import (
TraceContextTextMapPropagator,
)

carrier: dict[str, str] = {}
TraceContextTextMapPropagator().inject(
carrier, trace.set_span_in_context(trace.get_current_span())
)
return carrier.get("traceparent")
except Exception:
return None


def _trigger_aggregate_computation(
    aggregate_id: str, simulation_id: UUID, session: Session
) -> None:
    """Spawn the Modal function that computes one aggregate output.

    The simulation is resolved to its model version and then to its
    tax-benefit model; the model's name decides whether the UK or the US
    Modal function is used. If any of the three lookups finds nothing, an
    error is logged and the function returns without spawning anything.

    Args:
        aggregate_id: Identifier forwarded to the Modal function.
        simulation_id: Primary key of the ``Simulation`` being aggregated.
        session: Database session used for the three ``get`` lookups.
    """
    # Walk simulation -> model version -> model to find out UK vs US.
    sim = session.get(Simulation, simulation_id)
    if not sim:
        logfire.error("Simulation not found", simulation_id=str(simulation_id))
        return

    version_id = sim.tax_benefit_model_version_id
    version = session.get(TaxBenefitModelVersion, version_id)
    if not version:
        logfire.error("Model version not found", version_id=str(version_id))
        return

    model_id = version.tax_benefit_model_id
    model = session.get(TaxBenefitModel, model_id)
    if not model:
        logfire.error("Model not found", model_id=str(model_id))
        return

    traceparent = _get_traceparent()

    # Any model name other than the two UK spellings is routed to the US
    # function.
    is_uk = model.name in ("uk", "policyengine_uk")
    function_name = "compute_aggregate_uk" if is_uk else "compute_aggregate_us"
    remote_fn = modal.Function.from_name("policyengine", function_name)

    # NOTE(review): spawn() is presumably fire-and-forget - confirm against
    # the Modal documentation.
    remote_fn.spawn(aggregate_id=aggregate_id, traceparent=traceparent)
    logfire.info(
        "Triggered aggregate computation",
        aggregate_id=aggregate_id,
        model=model.name,
    )


@router.post("/", response_model=List[AggregateOutputRead])
def create_aggregate_outputs(
outputs: List[AggregateOutputCreate], session: Session = Depends(get_session)
outputs: List[AggregateOutputCreate],
background_tasks: BackgroundTasks,
session: Session = Depends(get_session),
):
"""Create aggregate output specifications for the worker to compute.
"""Create aggregate output specifications and trigger computation.

Aggregates are statistics like sums, means, or counts of simulation variables.
Computation happens asynchronously on Modal. Poll GET /outputs/aggregates/{id}
until status="completed" to get results.
"""
db_outputs = []
for output in outputs:
Expand All @@ -37,6 +104,13 @@ def create_aggregate_outputs(
session.commit()
for db_output in db_outputs:
session.refresh(db_output)

# Trigger computation for each aggregate
for db_output in db_outputs:
_trigger_aggregate_computation(
str(db_output.id), db_output.simulation_id, session
)

return db_outputs


Expand Down
Loading