Skip to content

Commit 79f0372

Browse files
committed
feat: enhance database schema validation in DatabaseService
- Update schema validation logic to dynamically check for required tables and columns based on SQLModel metadata. - Implement a helper function to retrieve table columns, ensuring a 1-to-1 validation between model definitions and database structure. - Improve error messaging for missing tables and columns, guiding users on how to resolve issues effectively.
1 parent 88fd23b commit 79f0372

File tree

1 file changed

+33
-14
lines changed

1 file changed

+33
-14
lines changed

src/tux/database/service.py

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@
2121
import sqlalchemy.exc
2222
from loguru import logger
2323
from sqlalchemy import inspect, text
24+
from sqlalchemy.engine.interfaces import ReflectedColumn
2425
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
26+
from sqlmodel import SQLModel
2527

2628
from tux.shared.config import CONFIG
2729

@@ -334,31 +336,48 @@ async def validate_schema(self) -> dict[str, Any]:
334336

335337
try:
336338
# Get database inspector to reflect current schema
337-
async with self.engine.begin() as conn:
339+
# Type checker doesn't know engine is not None after is_connected() check
340+
assert self._engine is not None, "Engine should not be None after connection check"
341+
async with self._engine.begin() as conn:
338342
inspector = await conn.run_sync(lambda sync_conn: inspect(sync_conn))
339343

340344
# Check if required tables exist
341345
existing_tables = await conn.run_sync(lambda sync_conn: inspector.get_table_names())
342-
required_tables = {"guild", "guild_config", "cases"}
346+
# Get table names from SQLModel metadata (models with table=True)
347+
required_tables = set(SQLModel.metadata.tables.keys())
343348

344-
missing_tables = required_tables - set(existing_tables)
345-
if missing_tables:
349+
if missing_tables := required_tables - set(existing_tables):
346350
return {
347351
"status": "invalid",
348352
"error": f"Missing tables: {', '.join(missing_tables)}. Run 'uv run db reset' to fix.",
349353
}
350354

351-
# Check if critical columns exist
355+
# Helper function to get columns for a table
356+
def get_table_columns(sync_conn: Any, table_name: str) -> list[ReflectedColumn]:
357+
return inspector.get_columns(table_name)
358+
359+
# Check that all model columns exist in database (1-to-1 validation)
360+
missing_columns: list[str] = []
352361
for table_name in required_tables:
353-
columns = await conn.run_sync(lambda sync_conn, table=table_name: inspector.get_columns(table))
354-
column_names = {col["name"] for col in columns}
355-
356-
# Check for critical columns that have caused issues before
357-
if table_name == "cases" and "mod_log_message_id" not in column_names:
358-
return {
359-
"status": "invalid",
360-
"error": "Missing 'mod_log_message_id' column in 'cases' table. Run 'uv run db reset' to fix.",
361-
}
362+
# Get columns from database
363+
columns = await conn.run_sync(get_table_columns, table_name)
364+
db_column_names = {col["name"] for col in columns}
365+
366+
# Get columns from model metadata
367+
if table_name in SQLModel.metadata.tables:
368+
table_metadata = SQLModel.metadata.tables[table_name]
369+
model_column_names = {col.name for col in table_metadata.columns}
370+
371+
# Find missing columns
372+
missing_for_table = model_column_names - db_column_names
373+
if missing_for_table:
374+
missing_columns.extend([f"{table_name}.{col}" for col in missing_for_table])
375+
376+
if missing_columns:
377+
return {
378+
"status": "invalid",
379+
"error": f"Missing columns: {', '.join(missing_columns)}. Run 'uv run db reset' to fix.",
380+
}
362381

363382
return {"status": "valid", "mode": "async"}
364383

0 commit comments

Comments
 (0)