|
24 | 24 | import os |
25 | 25 | import re |
26 | 26 | import subprocess |
| 27 | +import types |
| 28 | +import typing |
27 | 29 | from collections import defaultdict |
28 | 30 | from concurrent.futures import ThreadPoolExecutor, as_completed |
29 | 31 | from datetime import datetime, timedelta, timezone |
|
42 | 44 | from rich.table import Table |
43 | 45 |
|
44 | 46 | UTC = timezone.utc |
| 47 | +# Preserve the original datetime class for type mapping even when patched in tests |
| 48 | +_DATETIME_TYPE = datetime |
45 | 49 |
|
46 | 50 | app = typer.Typer(help="SLURM Job Monitor - Collect and analyze job efficiency metrics") |
47 | 51 | console = Console() |
@@ -672,6 +676,39 @@ def to_dict(self) -> dict[str, Any]: |
672 | 676 | """Convert to dictionary for DataFrame creation.""" |
673 | 677 | return self.model_dump() |
674 | 678 |
|
| 679 | + @classmethod |
| 680 | + def get_polars_schema(cls) -> dict[str, pl.DataType]: |
| 681 | + """Get Polars schema derived from Pydantic model fields.""" |
| 682 | + mapping: dict[type[Any], pl.DataType] = { |
| 683 | + str: pl.Utf8, |
| 684 | + int: pl.Int64, |
| 685 | + float: pl.Float64, |
| 686 | + bool: pl.Boolean, |
| 687 | + # All datetime fields should be UTC |
| 688 | + _DATETIME_TYPE: pl.Datetime("us", "UTC"), |
| 689 | + } |
| 690 | + |
| 691 | + schema: dict[str, pl.DataType] = {} |
| 692 | + for field_name, field_info in cls.model_fields.items(): |
| 693 | + annotation = field_info.annotation |
| 694 | + |
| 695 | + # Handle Optional types (Union[T, None] or T | None) |
| 696 | + origin = typing.get_origin(annotation) |
| 697 | + if origin in (typing.Union, types.UnionType): |
| 698 | + args = typing.get_args(annotation) |
| 699 | + non_none_args = [arg for arg in args if arg is not type(None)] |
| 700 | + if non_none_args: |
| 701 | + annotation = non_none_args[0] |
| 702 | + |
| 703 | + mapped_type: pl.DataType | None = None |
| 704 | + if isinstance(annotation, type): |
| 705 | + mapped_type = mapping.get(annotation) |
| 706 | + |
| 707 | + # Map Python types to Polars types (default to Utf8 for unknown types) |
| 708 | + schema[field_name] = mapped_type or pl.Utf8 |
| 709 | + |
| 710 | + return schema |
| 711 | + |
675 | 712 |
|
676 | 713 | class DateCompletionTracker(BaseModel): |
677 | 714 | """Tracks which dates have been fully processed and don't need re-collection.""" |
@@ -780,10 +817,16 @@ def _parse_datetime(date_str: str | None) -> datetime | None: |
780 | 817 | return None |
781 | 818 | try: |
782 | 819 | # SLURM uses ISO format: 2025-08-19T10:30:00 |
783 | | - return datetime.fromisoformat(date_str) |
| 820 | + dt = datetime.fromisoformat(date_str) |
784 | 821 | except (ValueError, AttributeError): |
785 | 822 | return None |
786 | 823 |
|
| 824 | + # Ensure timezone-aware (assume UTC if naive) |
| 825 | + if dt.tzinfo is None: |
| 826 | + dt = dt.replace(tzinfo=UTC) |
| 827 | + |
| 828 | + return dt |
| 829 | + |
787 | 830 |
|
788 | 831 | def _parse_gpu_count(alloc_tres: str) -> int: |
789 | 832 | """Parse GPU count from AllocTRES string. |
@@ -1330,10 +1373,14 @@ def _processed_jobs_to_dataframe( |
1330 | 1373 | DataFrame with job data |
1331 | 1374 |
|
1332 | 1375 | """ |
1333 | | - return pl.DataFrame( |
1334 | | - [j.to_dict() for j in processed_jobs], |
1335 | | - infer_schema_length=None, |
1336 | | - ) |
| 1376 | + # Create DataFrame with explicit schema to prevent Null type inference |
| 1377 | + schema = ProcessedJob.get_polars_schema() |
| 1378 | + |
| 1379 | + if not processed_jobs: |
| 1380 | + return pl.DataFrame(schema=schema) |
| 1381 | + |
| 1382 | + data_dicts = [j.to_dict() for j in processed_jobs] |
| 1383 | + return pl.DataFrame(data_dicts, schema=schema) |
1337 | 1384 |
|
1338 | 1385 |
|
1339 | 1386 | def _save_processed_jobs_to_parquet( |
|
0 commit comments