Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions rapidfireai/backend/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def _create_models(
chunk_offset = clone_modify_info.get("chunk_offset", 0) if clone_modify_info else 0

run_id = self.db.create_run(
experiment_id=self.experiment_id,
config_leaf=config_leaf,
status=RunStatus.NEW,
completed_steps=0,
Expand Down
50 changes: 50 additions & 0 deletions rapidfireai/db/migrate_add_experiment_id.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
-- Migration: Add experiment_id column to runs table
-- This migration adds an experiment_id foreign key so that runs are scoped to experiments.
--
-- The whole migration runs inside one transaction. SQLite DDL is transactional, so if
-- any step fails (for example the copy in Step 3) the caller can roll back instead of
-- being left with a half-migrated schema that cannot be re-run.
-- NOTE(review): assumes the migration runner has not already opened a transaction on
-- this connection -- confirm against the caller.
BEGIN TRANSACTION;

-- Step 1: Add experiment_id as a nullable column (no default; it is backfilled in Step 2)
ALTER TABLE runs ADD COLUMN experiment_id INTEGER;

-- Step 2: Populate experiment_id for existing runs
-- Assign the first (lowest) experiment_id to every existing run.
-- This assumes existing databases only have one experiment.
-- If the experiments table is empty, MIN() yields NULL, the rows stay NULL, and the
-- copy in Step 3 fails on the NOT NULL constraint; nothing is committed in that case.
UPDATE runs
SET experiment_id = (SELECT MIN(experiment_id) FROM experiments)
WHERE experiment_id IS NULL;

-- Step 3: Make the column NOT NULL now that it's populated
-- Note: SQLite doesn't support ALTER COLUMN, so we need to recreate the table

-- Create new runs table with experiment_id constraint
CREATE TABLE runs_new (
    run_id INTEGER PRIMARY KEY AUTOINCREMENT,
    experiment_id INTEGER NOT NULL,
    status TEXT NOT NULL,
    mlflow_run_id TEXT,
    flattened_config TEXT DEFAULT '{}',
    config_leaf TEXT DEFAULT '{}',
    completed_steps INTEGER DEFAULT 0,
    total_steps INTEGER DEFAULT 0,
    num_chunks_visited_curr_epoch INTEGER DEFAULT 0,
    num_epochs_completed INTEGER DEFAULT 0,
    chunk_offset INTEGER DEFAULT 0,
    error TEXT DEFAULT '',
    source TEXT DEFAULT '',
    ended_by TEXT DEFAULT '',
    warm_started_from INTEGER DEFAULT NULL,
    cloned_from INTEGER DEFAULT NULL,
    FOREIGN KEY (experiment_id) REFERENCES experiments (experiment_id)
);

-- Copy data from old table to new table (explicit column lists keep run_id values stable)
INSERT INTO runs_new (run_id, experiment_id, status, mlflow_run_id, flattened_config, config_leaf,
    completed_steps, total_steps, num_chunks_visited_curr_epoch, num_epochs_completed, chunk_offset,
    error, source, ended_by, warm_started_from, cloned_from)
SELECT run_id, experiment_id, status, mlflow_run_id, flattened_config, config_leaf,
    completed_steps, total_steps, num_chunks_visited_curr_epoch, num_epochs_completed, chunk_offset,
    error, source, ended_by, warm_started_from, cloned_from
FROM runs;

-- Drop old table and rename new table
DROP TABLE runs;
ALTER TABLE runs_new RENAME TO runs;

COMMIT;
190 changes: 115 additions & 75 deletions rapidfireai/db/rf_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ def get_experiments_path(self, experiment_name: str) -> Path:
# Runs Table
def create_run(
self,
experiment_id: int,
config_leaf: dict[str, Any],
status: RunStatus,
mlflow_run_id: str | None = None,
Expand All @@ -286,14 +287,15 @@ def create_run(
) -> int:
"""Create a new run"""
query = """
INSERT INTO runs (status, mlflow_run_id, flattened_config, config_leaf,
INSERT INTO runs (experiment_id, status, mlflow_run_id, flattened_config, config_leaf,
completed_steps, total_steps, num_chunks_visited_curr_epoch,
num_epochs_completed, chunk_offset, error, source, ended_by, warm_started_from, cloned_from)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
"""
self.db.execute(
query,
(
experiment_id,
status.value,
mlflow_run_id,
json.dumps(flattened_config) if flattened_config else "{}",
Expand Down Expand Up @@ -373,104 +375,142 @@ def set_run_details(
# Execute the query
self.db.execute(query, tuple(values), commit=True)

def get_run(self, run_id: int, experiment_id: int | None = None) -> dict[str, Any]:
    """Get a run's details, optionally filtering by experiment_id.

    Args:
        run_id: Primary key of the run to fetch.
        experiment_id: When not None, the run must also belong to this
            experiment; otherwise it is treated as not found.

    Returns:
        A dict of the run's columns. ``status`` is a ``RunStatus``;
        ``source`` and ``ended_by`` are ``RunSource`` / ``RunEndedBy`` or
        None when the stored value is empty; ``flattened_config`` is parsed
        from JSON and ``config_leaf`` is decoded with ``decode_db_payload``
        (an empty or ``"{}"`` payload becomes ``{}``).

    Raises:
        DBException: If no matching run exists.
    """
    # Build the WHERE clause once so a single column list serves both the
    # filtered and unfiltered lookup. Only fixed SQL fragments are
    # interpolated; all values are bound as parameters.
    where_clause = "run_id = ?"
    params: tuple[int, ...] = (run_id,)
    if experiment_id is not None:
        where_clause += " AND experiment_id = ?"
        params += (experiment_id,)

    query = f"""
        SELECT experiment_id, status, mlflow_run_id, flattened_config, config_leaf, completed_steps, total_steps,
        num_chunks_visited_curr_epoch, num_epochs_completed, chunk_offset, error, source, ended_by,
        warm_started_from, cloned_from
        FROM runs
        WHERE {where_clause}
    """
    rows = self.db.execute(query, params, fetch=True)
    if not rows:
        raise DBException("No run found")

    # Unpack by name (same order as the SELECT list) instead of using
    # positional indexes, which silently shift when a column is added.
    (
        row_experiment_id,
        status,
        mlflow_run_id,
        flattened_config,
        config_leaf,
        completed_steps,
        total_steps,
        num_chunks_visited_curr_epoch,
        num_epochs_completed,
        chunk_offset,
        error,
        source,
        ended_by,
        warm_started_from,
        cloned_from,
    ) = rows[0]

    return {
        "experiment_id": row_experiment_id,
        "status": RunStatus(status),
        "mlflow_run_id": mlflow_run_id,
        "flattened_config": json.loads(flattened_config),
        "config_leaf": decode_db_payload(config_leaf) if config_leaf and config_leaf != "{}" else {},
        "completed_steps": completed_steps,
        "total_steps": total_steps,
        "num_chunks_visited_curr_epoch": num_chunks_visited_curr_epoch,
        "num_epochs_completed": num_epochs_completed,
        "chunk_offset": chunk_offset,
        "error": error,
        "source": RunSource(source) if source else None,
        "ended_by": RunEndedBy(ended_by) if ended_by else None,
        "warm_started_from": warm_started_from,
        "cloned_from": cloned_from,
    }

def get_runs_by_status(
    self, statuses: list[RunStatus], experiment_id: int | None = None
) -> dict[int, dict[str, Any]]:
    """Get all runs by statuses, optionally filtering by experiment_id.

    Args:
        statuses: Statuses to match; an empty list returns ``{}`` without
            querying the database.
        experiment_id: When not None, only runs of this experiment are
            returned.

    Returns:
        A mapping of ``run_id`` to a dict of that run's columns. ``status``
        is a ``RunStatus``; ``source`` and ``ended_by`` are ``RunSource`` /
        ``RunEndedBy`` or None when the stored value is empty.
    """
    if not statuses:
        return {}

    # One "?" per status for the SQL IN clause; values are always bound as
    # parameters, only fixed SQL fragments are interpolated into the query.
    placeholders = ",".join(["?"] * len(statuses))
    where_clause = f"status IN ({placeholders})"
    params: list[Any] = [status.value for status in statuses]
    if experiment_id is not None:
        where_clause += " AND experiment_id = ?"
        params.append(experiment_id)

    query = f"""
        SELECT run_id, experiment_id, status, mlflow_run_id, flattened_config, config_leaf, completed_steps, total_steps,
        num_chunks_visited_curr_epoch, num_epochs_completed, chunk_offset, error, source, ended_by,
        warm_started_from, cloned_from
        FROM runs
        WHERE {where_clause}
    """
    rows = self.db.execute(query, params, fetch=True)

    formatted_details: dict[int, dict[str, Any]] = {}
    # Unpack by name (same order as the SELECT list) instead of using
    # positional indexes, which silently shift when a column is added.
    for (
        run_id,
        row_experiment_id,
        status,
        mlflow_run_id,
        flattened_config,
        config_leaf,
        completed_steps,
        total_steps,
        num_chunks_visited_curr_epoch,
        num_epochs_completed,
        chunk_offset,
        error,
        source,
        ended_by,
        warm_started_from,
        cloned_from,
    ) in rows or ():
        formatted_details[run_id] = {
            "experiment_id": row_experiment_id,
            "status": RunStatus(status),
            "mlflow_run_id": mlflow_run_id,
            "flattened_config": json.loads(flattened_config),
            "config_leaf": decode_db_payload(config_leaf) if config_leaf and config_leaf != "{}" else {},
            "completed_steps": completed_steps,
            "total_steps": total_steps,
            "num_chunks_visited_curr_epoch": num_chunks_visited_curr_epoch,
            "num_epochs_completed": num_epochs_completed,
            "chunk_offset": chunk_offset,
            "error": error,
            "source": RunSource(source) if source else None,
            "ended_by": RunEndedBy(ended_by) if ended_by else None,
            "warm_started_from": warm_started_from,
            "cloned_from": cloned_from,
        }
    return formatted_details

def get_all_runs(self, experiment_id: int | None = None) -> dict[int, dict[str, Any]]:
    """Get all runs for UI display, optionally filtering by experiment_id.

    Args:
        experiment_id: When not None, only runs of this experiment are
            returned; otherwise every row of the runs table is returned.

    Returns:
        A mapping of ``run_id`` to a dict of that run's columns. ``status``
        is a ``RunStatus``; ``source`` and ``ended_by`` are ``RunSource`` /
        ``RunEndedBy`` or None when the stored value is empty.
    """
    # A single column list serves both variants; only the fixed WHERE
    # fragment differs and the experiment id itself is bound as a parameter.
    where_clause = "" if experiment_id is None else "WHERE experiment_id = ?"
    query = f"""
        SELECT run_id, experiment_id, status, mlflow_run_id, flattened_config, config_leaf, completed_steps, total_steps,
        num_chunks_visited_curr_epoch, num_epochs_completed, chunk_offset, error, source, ended_by,
        warm_started_from, cloned_from
        FROM runs
        {where_clause}
    """
    if experiment_id is None:
        rows = self.db.execute(query, fetch=True)
    else:
        rows = self.db.execute(query, (experiment_id,), fetch=True)

    formatted_details: dict[int, dict[str, Any]] = {}
    # Unpack by name (same order as the SELECT list) instead of using
    # positional indexes, which silently shift when a column is added.
    for (
        run_id,
        row_experiment_id,
        status,
        mlflow_run_id,
        flattened_config,
        config_leaf,
        completed_steps,
        total_steps,
        num_chunks_visited_curr_epoch,
        num_epochs_completed,
        chunk_offset,
        error,
        source,
        ended_by,
        warm_started_from,
        cloned_from,
    ) in rows or ():
        formatted_details[run_id] = {
            "experiment_id": row_experiment_id,
            "status": RunStatus(status),
            "mlflow_run_id": mlflow_run_id,
            "flattened_config": json.loads(flattened_config),
            "config_leaf": decode_db_payload(config_leaf) if config_leaf and config_leaf != "{}" else {},
            "completed_steps": completed_steps,
            "total_steps": total_steps,
            "num_chunks_visited_curr_epoch": num_chunks_visited_curr_epoch,
            "num_epochs_completed": num_epochs_completed,
            "chunk_offset": chunk_offset,
            "error": error,
            "source": RunSource(source) if source else None,
            "ended_by": RunEndedBy(ended_by) if ended_by else None,
            "warm_started_from": warm_started_from,
            "cloned_from": cloned_from,
        }
    return formatted_details

Expand Down
4 changes: 3 additions & 1 deletion rapidfireai/db/tables.sql
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ CREATE TABLE IF NOT EXISTS experiments (
-- Runs table
CREATE TABLE IF NOT EXISTS runs (
run_id INTEGER PRIMARY KEY AUTOINCREMENT,
experiment_id INTEGER NOT NULL,
status TEXT NOT NULL,
mlflow_run_id TEXT,
flattened_config TEXT DEFAULT '{}',
Expand All @@ -25,7 +26,8 @@ CREATE TABLE IF NOT EXISTS runs (
source TEXT DEFAULT '',
ended_by TEXT DEFAULT '',
warm_started_from INTEGER DEFAULT NULL,
cloned_from INTEGER DEFAULT NULL
cloned_from INTEGER DEFAULT NULL,
FOREIGN KEY (experiment_id) REFERENCES experiments (experiment_id)
);

-- Interactive Control table
Expand Down
18 changes: 16 additions & 2 deletions rapidfireai/dispatcher/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,21 @@ def get_run(self) -> tuple[Response, int]:
"""Get a run for the UI"""
try:
data = request.get_json()
result = self.db.get_run(data["run_id"])
run_id = data["run_id"]

# Get experiment_id from experiment_name if provided, otherwise use running experiment
experiment_id = None
if "experiment_name" in data and data["experiment_name"]:
# Look up experiment_id from experiment_name
exp_info = self.db.get_running_experiment() # Fallback
# TODO: Add a DB method to look up an experiment by name.
# NOTE: Until that exists, the provided experiment_name is not used for the lookup;
# the running experiment fetched above is the fallback.
else:
# Use the running experiment by default
exp_info = self.db.get_running_experiment()
experiment_id = exp_info["experiment_id"]

result = self.db.get_run(run_id, experiment_id=experiment_id)
if not result:
return jsonify({"error": "Run not found"}), 404

Expand All @@ -157,7 +171,7 @@ def get_run(self) -> tuple[Response, int]:
result["config_leaf"].pop("reward_funcs", None)

safe_result = {
"run_id": data["run_id"],
"run_id": run_id,
"status": result["status"].value,
"mlflow_run_id": result["mlflow_run_id"],
"config": result["config_leaf"],
Expand Down