Skip to content
Draft
3 changes: 3 additions & 0 deletions src/webapp/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@
"GCP_SERVICE_ACCOUNT_EMAIL": "",
}

# Maps deployment ENV to its Databricks volume schema, used to build
# /Volumes/{schema}/... paths. Environments absent here (LOCAL, PROD)
# have no volume schema and callers treat them as "no config available".
ENV_TO_VOLUME_SCHEMA = {
    "DEV": "dev_sst_02",
    "STAGING": "staging_sst_01",
}

# databricks vars needed for databricks integration
databricks_vars = {
# SECRET.
Expand Down
190 changes: 166 additions & 24 deletions src/webapp/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from google.cloud import storage
from google.api_core import exceptions as gcs_errors
from .validation_extension import generate_extension_schema
from .config import databricks_vars, gcs_vars
from .config import ENV_TO_VOLUME_SCHEMA, databricks_vars, env_vars, gcs_vars
from .utilities import databricksify_inst_name, SchemaType
from typing import List, Any, Dict, Optional
from fastapi import HTTPException
Expand Down Expand Up @@ -48,6 +48,8 @@ class DatabricksInferenceRunRequest(BaseModel):
# The email where notifications will get sent.
email: str
gcp_external_bucket_name: str
# Optional term filter (e.g. cohort labels); serialized as JSON for job params when set. Used for cohort/graduation models.
term_filter: list[str] | None = None


class DatabricksInferenceRunResponse(BaseModel):
Expand Down Expand Up @@ -83,6 +85,78 @@ def _sha256_json(obj: Any) -> str:
).hexdigest()


def _parse_config_toml_to_selection(raw: bytes) -> dict | None:
"""Parse TOML bytes and return the [preprocessing.selection] section, or None."""
try:
try:
import tomllib

try:
data = tomllib.loads(raw)
except TypeError:
data = tomllib.loads(raw.decode("utf-8"))
except ImportError:
import tomli as tomllib

data = tomllib.loads(raw.decode("utf-8"))
except (Exception, TypeError):
return None
if not isinstance(data, dict):
return None
preprocessing = data.get("preprocessing")
if not isinstance(preprocessing, dict):
return None
selection = preprocessing.get("selection")
if not isinstance(selection, dict):
return None
return selection


def _find_selection_in_toml_under(
    w: WorkspaceClient,
    directory_path: str,
    inst_name: str,
) -> dict | None:
    """Depth-first search for the first .toml file containing [preprocessing.selection].

    Walks directory_path recursively via the Databricks Files API. Listing or
    download failures are logged at debug level and treated as "nothing found
    here", so one bad entry never aborts the whole walk.

    Args:
        w: Workspace client used for directory listing and file download.
        directory_path: Volume directory to search under.
        inst_name: Institution name, used only for log context.

    Returns:
        The selection table from the first matching file, or None.
    """
    try:
        children = list(w.files.list_directory_contents(directory_path))
    except Exception as exc:
        LOGGER.debug(
            "read_volume_training_config: could not list %s for %s: %s",
            directory_path,
            inst_name,
            exc,
        )
        return None

    for child in children:
        if not child.path:
            continue
        if child.is_directory:
            # Recurse into subdirectories; bubble up the first hit.
            found = _find_selection_in_toml_under(w, child.path, inst_name)
            if found is not None:
                return found
        elif child.name and child.name.lower().endswith(".toml"):
            try:
                download = w.files.download(child.path)
                if download.contents is None:
                    continue
                payload = download.contents.read()
            except Exception as exc:
                LOGGER.debug(
                    "read_volume_training_config: could not read %s for %s: %s",
                    child.path,
                    inst_name,
                    exc,
                )
                continue
            if not isinstance(payload, bytes):
                payload = payload.encode("utf-8")
            found = _parse_config_toml_to_selection(payload)
            if found is not None:
                return found
    return None


# Cache TTLs in seconds. Plain integer literals — the previous int("600") /
# int("3600") string round trips added nothing but noise.
L1_RESP_CACHE_TTL = 600  # seconds
L1_VER_CACHE_TTL = 3600  # seconds
L1_RESP_CACHE: Any = TTLCache(maxsize=128, ttl=L1_RESP_CACHE_TTL)
Expand Down Expand Up @@ -225,44 +299,44 @@ def run_pdp_inference(
f"run_pdp_inference(): Job '{pipeline_type}' was not found or has no job_id for '{gcs_vars['GCP_SERVICE_ACCOUNT_EMAIL']}' and '{databricks_vars['DATABRICKS_HOST_URL']}'."
)
job_id = job.job_id
LOGGER.info(f"Resolved job ID for '{pipeline_type}': {job_id}")
LOGGER.info("Resolved job ID for '%s': %s", pipeline_type, job_id)
except Exception as e:
LOGGER.exception(f"Job lookup failed for '{pipeline_type}'.")
LOGGER.exception("Job lookup failed for '%s': %s", pipeline_type, e)
raise ValueError(f"run_pdp_inference(): Failed to find job: {e}")

job_params: Dict[str, str] = {
"cohort_file_name": get_filepath_of_filetype(
req.filepath_to_type, SchemaType.STUDENT
),
"course_file_name": get_filepath_of_filetype(
req.filepath_to_type, SchemaType.COURSE
),
"databricks_institution_name": db_inst_name,
"DB_workspace": databricks_vars[
"DATABRICKS_WORKSPACE"
], # is this value the same PER environ? dev/staging/prod
"gcp_bucket_name": req.gcp_external_bucket_name,
"model_name": req.model_name,
"notification_email": req.email,
}
if req.term_filter is not None:
job_params["term_filter"] = json.dumps(req.term_filter)
try:
run_job: Any = w.jobs.run_now(
job_id,
job_parameters={
"cohort_file_name": get_filepath_of_filetype(
req.filepath_to_type, SchemaType.STUDENT
),
"course_file_name": get_filepath_of_filetype(
req.filepath_to_type, SchemaType.COURSE
),
"databricks_institution_name": db_inst_name,
"DB_workspace": databricks_vars[
"DATABRICKS_WORKSPACE"
], # is this value the same PER environ? dev/staging/prod
"gcp_bucket_name": req.gcp_external_bucket_name,
"model_name": req.model_name,
"notification_email": req.email,
},
job_parameters=job_params,
)
LOGGER.info(
f"Successfully triggered job run. Run ID: {run_job.response.run_id}"
"Successfully triggered job run. Run ID: %s", run_job.response.run_id
)
except Exception as e:
LOGGER.exception("Failed to run the PDP inference job.")
LOGGER.exception("Failed to run the PDP inference job: %s", e)
raise ValueError(f"run_pdp_inference(): Job could not be run: {e}")

if not run_job.response or run_job.response.run_id is None:
raise ValueError("run_pdp_inference(): Job did not return a valid run_id.")

run_id = run_job.response.run_id
LOGGER.info(f"Successfully triggered job run. Run ID: {run_id}")

return DatabricksInferenceRunResponse(job_run_id=run_id)
return DatabricksInferenceRunResponse(job_run_id=run_job.response.run_id)

def delete_inst(self, inst_name: str) -> None:
"""Cleanup tasks required on the Databricks side to delete an institution."""
Expand Down Expand Up @@ -530,6 +604,74 @@ def fetch_model_version(

return latest_version

def read_volume_training_config(
    self, inst_name: str, model_run_id: str
) -> dict | None:
    """Read the training-time [preprocessing.selection] config for a model run.

    Searches every .toml file under
    /Volumes/{env_schema}/{slug}_silver/silver_volume/{model_run_id}/ and
    returns the first [preprocessing.selection] table found.

    Args:
        inst_name: Institution display name (e.g. from inst.name); the path
            slug is derived with databricksify_inst_name().
        model_run_id: Training run identifier used as the directory name
            (e.g. 0b2e206732ce48f6b644149090c9614a).

    Returns:
        The selection table (dict with student_criteria, etc.) for use by
        latest-inference-cohort logic, or None when: either input is empty,
        ENV has no volume-schema mapping (only DEV -> dev_sst_02 and
        STAGING -> staging_sst_01 are allowed; LOCAL/PROD return None),
        the institution name cannot be slugified, the workspace client
        cannot be created, or no suitable file/section exists.
    """
    if not inst_name or not str(inst_name).strip():
        LOGGER.warning(
            "read_volume_training_config: empty inst_name; cannot build volume path.",
        )
        return None

    run_id = str(model_run_id).strip() if model_run_id else ""
    if not run_id:
        LOGGER.warning(
            "read_volume_training_config: empty model_run_id; cannot build volume path.",
        )
        return None

    environment = str(env_vars.get("ENV", "")).strip().upper()
    volume_schema = ENV_TO_VOLUME_SCHEMA.get(environment)
    if volume_schema is None:
        # Only environments with a mapped schema can resolve a volume path.
        LOGGER.warning(
            "read_volume_training_config: ENV %r not in %s; cannot read config for %s",
            env_vars.get("ENV"),
            list(ENV_TO_VOLUME_SCHEMA),
            inst_name,
        )
        return None

    try:
        slug = databricksify_inst_name(inst_name)
    except ValueError as exc:
        LOGGER.warning(
            "read_volume_training_config: cannot databricksify inst_name %r: %s",
            inst_name,
            exc,
        )
        return None

    search_root = (
        f"/Volumes/{volume_schema}/{slug}_silver/silver_volume/{run_id}"
    )

    try:
        client = WorkspaceClient(
            host=databricks_vars["DATABRICKS_HOST_URL"],
            google_service_account=gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"],
        )
    except Exception as exc:
        LOGGER.exception(
            "read_volume_training_config: WorkspaceClient failed for %s: %s",
            inst_name,
            exc,
        )
        return None

    return _find_selection_in_toml_under(client, search_root, inst_name)

def delete_model(self, catalog_name: str, inst_name: str, model_name: str) -> None:
schema = databricksify_inst_name(inst_name)
model_name_path = f"{catalog_name}.{schema}_gold.{model_name}"
Expand Down
Loading
Loading