Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions src/webapp/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@
# The name of the deployed pipeline in Databricks. Must match directly.
PDP_INFERENCE_JOB_NAME = "edvise_github_sourced_pdp_inference_pipeline"

# Accepted bronze dataset filenames, matched case-insensitively:
#   <alnum>pdp_<alnum>_[course_level_]ar_<anything>.csv
# Used both to filter bronze-volume directory listings and to validate
# user-supplied names before downloading.
VALID_BRONZE_FILE_RE = re.compile(
    r"^[a-z0-9]+pdp_[a-z0-9]+_(course_level_)?ar_.*\.csv$",
    re.IGNORECASE,
)


class DatabricksInferenceRunRequest(BaseModel):
"""Databricks parameters for an inference run."""
Expand Down Expand Up @@ -181,6 +186,96 @@ def setup_new_inst(self, inst_name: str) -> None:
exist_ok=True,
)

def list_bronze_volume_csvs(self, inst_name: str) -> list[str]:
"""List `.csv` files directly under the institution's bronze volume root."""
if not databricks_vars.get("DATABRICKS_HOST_URL") or not databricks_vars.get(
"CATALOG_NAME"
):
raise ValueError("Databricks integration not configured.")
if not gcs_vars.get("GCP_SERVICE_ACCOUNT_EMAIL"):
raise ValueError("GCP service account email not configured.")

try:
w = WorkspaceClient(
host=databricks_vars["DATABRICKS_HOST_URL"],
google_service_account=gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"],
)
except Exception as e:
LOGGER.exception(
"Failed to create Databricks WorkspaceClient with host: %s and service account: %s",
databricks_vars.get("DATABRICKS_HOST_URL"),
gcs_vars.get("GCP_SERVICE_ACCOUNT_EMAIL"),
)
raise ValueError(f"Workspace client creation failed: {e}")

db_inst_name = databricksify_inst_name(inst_name)
volume_root = (
f"/Volumes/{databricks_vars['CATALOG_NAME']}/"
f"{db_inst_name}_bronze/bronze_volume"
)

try:
entries = list(w.dbfs.list(f"dbfs:{volume_root}") or [])
except Exception as e:
LOGGER.exception("Failed to list bronze volume directory: %s", volume_root)
raise ValueError(f"Failed to list bronze volume directory: {e}")

csvs: list[str] = []
for entry in entries:
entry_path = getattr(entry, "path", None)
is_dir = getattr(entry, "is_dir", False)
if not entry_path or is_dir:
continue
basename = os.path.basename(str(entry_path))
if not VALID_BRONZE_FILE_RE.match(basename):
continue
csvs.append(basename)
csvs.sort()
return csvs

def download_bronze_volume_file(self, inst_name: str, file_name: str) -> Any:
"""Download a file from the institution's bronze volume root and return a byte stream."""
if "/" in file_name:
raise ValueError("file_name must not contain '/'.")
if not VALID_BRONZE_FILE_RE.match(file_name):
raise ValueError("Invalid bronze dataset filename.")
if not databricks_vars.get("DATABRICKS_HOST_URL") or not databricks_vars.get(
"CATALOG_NAME"
):
raise ValueError("Databricks integration not configured.")
if not gcs_vars.get("GCP_SERVICE_ACCOUNT_EMAIL"):
raise ValueError("GCP service account email not configured.")

try:
w = WorkspaceClient(
host=databricks_vars["DATABRICKS_HOST_URL"],
google_service_account=gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"],
)
except Exception as e:
LOGGER.exception(
"Failed to create Databricks WorkspaceClient with host: %s and service account: %s",
databricks_vars.get("DATABRICKS_HOST_URL"),
gcs_vars.get("GCP_SERVICE_ACCOUNT_EMAIL"),
)
raise ValueError(f"Workspace client creation failed: {e}")

db_inst_name = databricksify_inst_name(inst_name)
volume_path = (
f"/Volumes/{databricks_vars['CATALOG_NAME']}/"
f"{db_inst_name}_bronze/bronze_volume/{file_name}"
)

try:
response = w.files.download(volume_path)
except Exception as e:
LOGGER.exception("Failed to download from %s", volume_path)
raise ValueError(f"Failed to download bronze dataset: {e}")

stream = getattr(response, "contents", None)
if stream is None:
raise ValueError("Databricks download returned no contents.")
return stream

# Note that for each unique PIPELINE, we'll need a new function, this is by nature of how unique pipelines
# may have unique parameters and would have a unique name (i.e. the name field specified in w.jobs.list()). But any run of a given pipeline (even across institutions) can use the same function.
# E.g. there is one PDP inference pipeline, so one PDP inference function here.
Expand Down
24 changes: 23 additions & 1 deletion src/webapp/gcsutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import datetime
import io
import logging
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, IO

import pandas as pd
from pydantic import BaseModel
Expand Down Expand Up @@ -120,6 +120,28 @@ def generate_download_signed_url(self, bucket_name: str, blob_name: str) -> Any:
)
return url

def upload_unvalidated_csv_from_file(
self, bucket_name: str, file_name: str, file_obj: IO[bytes]
) -> None:
"""Upload a CSV into unvalidated/ while enforcing no-overwrite semantics."""
if not file_name or not file_name.strip():
raise ValueError("file_name is required and must be non-empty.")
if "/" in file_name:
raise ValueError("file_name must not contain '/'.")

client = storage.Client()
bucket = client.bucket(bucket_name)
if not bucket.exists():
raise ValueError("Storage bucket not found.")

for prefix in ("unvalidated/", "validated/"):
blob = bucket.blob(prefix + file_name)
if blob.exists():
raise ValueError("File already exists.")

blob = bucket.blob("unvalidated/" + file_name)
blob.upload_from_file(file_obj, content_type="text/csv")

def delete_bucket(self, bucket_name: str) -> None:
"""Delete a given bucket."""
storage_client = storage.Client()
Expand Down
139 changes: 139 additions & 0 deletions src/webapp/routers/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import logging
from sqlalchemy.exc import IntegrityError
import re
import requests
from ..validation import HardValidationError
from ..validation_error_formatter import format_validation_error
import pandas as pd
Expand Down Expand Up @@ -179,6 +180,18 @@ class ValidationResult(BaseModel):
source: str


class BronzeImportRequest(BaseModel):
    """Request to import a dataset from the institution's bronze volume into GCS."""

    # Basename of the CSV in the bronze volume root. Matched case-insensitively
    # against the volume listing by the endpoint; must not contain '/'.
    name: str


class BronzeImportResponse(BaseModel):
    """Response for bronze import request."""

    # Canonical (as-listed) basename of the file that was imported.
    file_name: str


class DataOverview(BaseModel):
"""All data for a given institution (batches and files)."""

Expand Down Expand Up @@ -1659,6 +1672,132 @@ def get_upload_url(
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve))


@router.get("/{inst_id}/input/bronze-datasets", response_model=list[str])
def list_bronze_datasets(
    inst_id: str,
    current_user: Annotated[BaseUser, Depends(get_current_active_user)],
    sql_session: Annotated[Session, Depends(get_session)],
    databricks_control: Annotated[DatabricksControl, Depends(DatabricksControl)],
) -> Any:
    """List `.csv` files directly under the institution's Databricks bronze volume root.

    Raises:
        HTTPException: 404 if the institution does not exist; 501 if the
            Databricks/GCP integration is not configured; 500 for any other
            listing failure.
    """
    has_access_to_inst_or_err(inst_id, current_user)
    local_session.set(sql_session)

    inst = (
        local_session.get()
        .execute(select(InstTable).where(InstTable.id == str_to_uuid(inst_id)))
        .scalar_one_or_none()
    )
    if inst is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Institution not found.",
        )

    try:
        return databricks_control.list_bronze_volume_csvs(inst.name)
    except ValueError as ve:
        msg = str(ve)
        # Missing configuration is reported as "not implemented", not a server fault.
        # Chain the cause (`from ve`) so the original traceback is preserved.
        if "not configured" in msg.lower():
            raise HTTPException(
                status_code=status.HTTP_501_NOT_IMPLEMENTED, detail=msg
            ) from ve
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=msg
        ) from ve


@router.post(
    "/{inst_id}/input/upload-from-volume-to-gcs-bucket",
    response_model=BronzeImportResponse,
)
def upload_from_volume_to_gcs_bucket(
    inst_id: str,
    req: BronzeImportRequest,
    current_user: Annotated[BaseUser, Depends(get_current_active_user)],
    sql_session: Annotated[Session, Depends(get_session)],
    storage_control: Annotated[StorageControl, Depends(StorageControl)],
    databricks_control: Annotated[DatabricksControl, Depends(DatabricksControl)],
) -> Any:
    """Import a selected dataset from the institution's bronze volume into GCS unvalidated/.

    Flow: validate the requested name, confirm it exists in the bronze volume
    listing (case-insensitive match), stream it out of Databricks, and PUT it
    to a signed GCS upload URL.

    Raises:
        HTTPException: 404 (institution or dataset missing), 422 (bad name),
            501 (integration not configured), 400 (download/validation error),
            500 (upload or unexpected failure).
    """
    has_access_to_inst_or_err(inst_id, current_user)
    local_session.set(sql_session)

    inst = (
        local_session.get()
        .execute(select(InstTable).where(InstTable.id == str_to_uuid(inst_id)))
        .scalar_one_or_none()
    )
    if inst is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Institution not found.",
        )

    requested_name = (req.name or "").strip()
    if not requested_name:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail="Dataset name is required.",
        )
    if "/" in requested_name:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail="Dataset name can't contain '/'.",
        )

    # Ensure this is actually present in the bronze root (and matches naming rules).
    try:
        available = databricks_control.list_bronze_volume_csvs(inst.name)
    except ValueError as ve:
        msg = str(ve)
        # Chain the cause (`from ve`) so the original traceback is preserved.
        if "not configured" in msg.lower():
            raise HTTPException(
                status_code=status.HTTP_501_NOT_IMPLEMENTED, detail=msg
            ) from ve
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=msg
        ) from ve

    # Resolve the user's (possibly differently-cased) name to the listed one.
    available_map = {x.lower(): x for x in available}
    file_name = available_map.get(requested_name.lower())
    if not file_name:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Bronze dataset not found.",
        )

    stream = None
    try:
        stream = databricks_control.download_bronze_volume_file(inst.name, file_name)
        # NOTE(review): the destination path inside the bucket is determined by
        # generate_upload_signed_url — confirm it lands under unvalidated/.
        upload_url = storage_control.generate_upload_signed_url(
            get_external_bucket_name(inst_id), file_name
        )
        resp = requests.put(
            upload_url,
            data=stream,  # streamed; avoids buffering the whole file in memory
            headers={"Content-Type": "text/csv"},
            timeout=600,
        )
        resp.raise_for_status()
    except ValueError as ve:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve)
        ) from ve
    except requests.RequestException as rexc:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to upload dataset to GCS: {rexc}",
        ) from rexc
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Unexpected error importing dataset: {e}",
        ) from e
    finally:
        # Best-effort close of the Databricks stream; never mask the real error.
        if stream is not None and hasattr(stream, "close"):
            try:
                stream.close()
            except Exception:
                pass

    return {"file_name": file_name}


@router.post("/{inst_id}/add-custom-school-job/{job_run_id}")
def add_custom_school_job(
inst_id: str,
Expand Down
Loading
Loading