|
21 | 21 | import sqlalchemy.exc |
22 | 22 | from loguru import logger |
23 | 23 | from sqlalchemy import inspect, text |
| 24 | +from sqlalchemy.engine.interfaces import ReflectedColumn |
24 | 25 | from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine |
| 26 | +from sqlmodel import SQLModel |
25 | 27 |
|
26 | 28 | from tux.shared.config import CONFIG |
27 | 29 |
|
@@ -334,31 +336,48 @@ async def validate_schema(self) -> dict[str, Any]: |
334 | 336 |
|
335 | 337 | try: |
336 | 338 | # Get database inspector to reflect current schema |
337 | | - async with self.engine.begin() as conn: |
| 339 | + # Type checker doesn't know engine is not None after is_connected() check |
| 340 | + assert self._engine is not None, "Engine should not be None after connection check" |
| 341 | + async with self._engine.begin() as conn: |
338 | 342 | inspector = await conn.run_sync(lambda sync_conn: inspect(sync_conn)) |
339 | 343 |
|
340 | 344 | # Check if required tables exist |
341 | 345 | existing_tables = await conn.run_sync(lambda sync_conn: inspector.get_table_names()) |
342 | | - required_tables = {"guild", "guild_config", "cases"} |
| 346 | + # Get table names from SQLModel metadata (models with table=True) |
| 347 | + required_tables = set(SQLModel.metadata.tables.keys()) |
343 | 348 |
|
344 | | - missing_tables = required_tables - set(existing_tables) |
345 | | - if missing_tables: |
| 349 | + if missing_tables := required_tables - set(existing_tables): |
346 | 350 | return { |
347 | 351 | "status": "invalid", |
348 | 352 | "error": f"Missing tables: {', '.join(missing_tables)}. Run 'uv run db reset' to fix.", |
349 | 353 | } |
350 | 354 |
|
351 | | - # Check if critical columns exist |
| 355 | + # Helper function to get columns for a table |
| 356 | + def get_table_columns(sync_conn: Any, table_name: str) -> list[ReflectedColumn]: |
| 357 | + return inspector.get_columns(table_name) |
| 358 | + |
| 359 | + # Check that all model columns exist in database (1-to-1 validation) |
| 360 | + missing_columns: list[str] = [] |
352 | 361 | for table_name in required_tables: |
353 | | - columns = await conn.run_sync(lambda sync_conn, table=table_name: inspector.get_columns(table)) |
354 | | - column_names = {col["name"] for col in columns} |
355 | | - |
356 | | - # Check for critical columns that have caused issues before |
357 | | - if table_name == "cases" and "mod_log_message_id" not in column_names: |
358 | | - return { |
359 | | - "status": "invalid", |
360 | | - "error": "Missing 'mod_log_message_id' column in 'cases' table. Run 'uv run db reset' to fix.", |
361 | | - } |
| 362 | + # Get columns from database |
| 363 | + columns = await conn.run_sync(get_table_columns, table_name) |
| 364 | + db_column_names = {col["name"] for col in columns} |
| 365 | + |
| 366 | + # Get columns from model metadata |
| 367 | + if table_name in SQLModel.metadata.tables: |
| 368 | + table_metadata = SQLModel.metadata.tables[table_name] |
| 369 | + model_column_names = {col.name for col in table_metadata.columns} |
| 370 | + |
| 371 | + # Find missing columns |
| 372 | + missing_for_table = model_column_names - db_column_names |
| 373 | + if missing_for_table: |
| 374 | + missing_columns.extend([f"{table_name}.{col}" for col in missing_for_table]) |
| 375 | + |
| 376 | + if missing_columns: |
| 377 | + return { |
| 378 | + "status": "invalid", |
| 379 | + "error": f"Missing columns: {', '.join(missing_columns)}. Run 'uv run db reset' to fix.", |
| 380 | + } |
362 | 381 |
|
363 | 382 | return {"status": "valid", "mode": "async"} |
364 | 383 |
|
|
0 commit comments