
Commit f5a284f

Merge pull request #13 from basnijholt/fix-datetime-schema-mismatch
Fix datetime schema mismatch
2 parents 60321f7 + d0c24e7 commit f5a284f


2 files changed: +125, -5 lines changed


slurm_usage.py

Lines changed: 52 additions & 5 deletions
@@ -24,6 +24,8 @@
 import os
 import re
 import subprocess
+import types
+import typing
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from datetime import datetime, timedelta, timezone
@@ -42,6 +44,8 @@
 from rich.table import Table

 UTC = timezone.utc
+# Preserve the original datetime class for type mapping even when patched in tests
+_DATETIME_TYPE = datetime

 app = typer.Typer(help="SLURM Job Monitor - Collect and analyze job efficiency metrics")
 console = Console()
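
The `_DATETIME_TYPE` alias is worth a brief illustration: the schema mapping introduced below keys on the class object itself, so it must hold the class the Pydantic fields were annotated with, not whatever `datetime` happens to name after a test patches it. A minimal sketch of the distinction (the `FakeDatetime` class is hypothetical, standing in for whatever a test fixture might install):

from datetime import datetime

_DATETIME_TYPE = datetime  # bound once at import time, as in slurm_usage


class FakeDatetime(datetime):
    """Stand-in for a class a test fixture might patch over `slurm_usage.datetime`."""


annotation = datetime  # the Pydantic field annotations still refer to the real class

# A mapping keyed on the patched name would miss the real annotation ...
patched_mapping = {FakeDatetime: "Datetime(us, UTC)"}
assert patched_mapping.get(annotation) is None  # would silently fall back to Utf8

# ... while keying on the preserved alias keeps the lookup stable.
stable_mapping = {_DATETIME_TYPE: "Datetime(us, UTC)"}
assert stable_mapping.get(annotation) == "Datetime(us, UTC)"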
@@ -672,6 +676,39 @@ def to_dict(self) -> dict[str, Any]:
         """Convert to dictionary for DataFrame creation."""
         return self.model_dump()

+    @classmethod
+    def get_polars_schema(cls) -> dict[str, pl.DataType]:
+        """Get Polars schema derived from Pydantic model fields."""
+        mapping: dict[type[Any], pl.DataType] = {
+            str: pl.Utf8,
+            int: pl.Int64,
+            float: pl.Float64,
+            bool: pl.Boolean,
+            # All datetime fields should be UTC
+            _DATETIME_TYPE: pl.Datetime("us", "UTC"),
+        }
+
+        schema: dict[str, pl.DataType] = {}
+        for field_name, field_info in cls.model_fields.items():
+            annotation = field_info.annotation
+
+            # Handle Optional types (Union[T, None] or T | None)
+            origin = typing.get_origin(annotation)
+            if origin in (typing.Union, types.UnionType):
+                args = typing.get_args(annotation)
+                non_none_args = [arg for arg in args if arg is not type(None)]
+                if non_none_args:
+                    annotation = non_none_args[0]
+
+            mapped_type: pl.DataType | None = None
+            if isinstance(annotation, type):
+                mapped_type = mapping.get(annotation)
+
+            # Map Python types to Polars types (default to Utf8 for unknown types)
+            schema[field_name] = mapped_type or pl.Utf8
+
+        return schema
+

 class DateCompletionTracker(BaseModel):
     """Tracks which dates have been fully processed and don't need re-collection."""
@@ -780,10 +817,16 @@ def _parse_datetime(date_str: str | None) -> datetime | None:
         return None
     try:
         # SLURM uses ISO format: 2025-08-19T10:30:00
-        return datetime.fromisoformat(date_str)
+        dt = datetime.fromisoformat(date_str)
     except (ValueError, AttributeError):
         return None

+    # Ensure timezone-aware (assume UTC if naive)
+    if dt.tzinfo is None:
+        dt = dt.replace(tzinfo=UTC)
+
+    return dt
+

 def _parse_gpu_count(alloc_tres: str) -> int:
     """Parse GPU count from AllocTRES string.
@@ -1330,10 +1373,14 @@ def _processed_jobs_to_dataframe(
        DataFrame with job data

    """
-    return pl.DataFrame(
-        [j.to_dict() for j in processed_jobs],
-        infer_schema_length=None,
-    )
+    # Create DataFrame with explicit schema to prevent Null type inference
+    schema = ProcessedJob.get_polars_schema()
+
+    if not processed_jobs:
+        return pl.DataFrame(schema=schema)
+
+    data_dicts = [j.to_dict() for j in processed_jobs]
+    return pl.DataFrame(data_dicts, schema=schema)


def _save_processed_jobs_to_parquet(
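
To make the "Null type inference" comment concrete, here is a minimal sketch with a hypothetical two-column frame (not the real `ProcessedJob` schema): a column containing only `None`, or an empty job list, would otherwise be inferred as the Polars `Null` dtype, which no longer lines up with parquet files written with a `Datetime` column.

import polars as pl

rows = [{"job_id": "1", "end_time": None}]  # e.g. a job that is still running

inferred = pl.DataFrame(rows, infer_schema_length=None)
print(inferred.schema)  # end_time is inferred as Null

schema = {"job_id": pl.Utf8, "end_time": pl.Datetime("us", "UTC")}
explicit = pl.DataFrame(rows, schema=schema)
print(explicit.schema)  # end_time is Datetime(us, UTC); the value stays null

# An empty frame keeps the declared dtypes as well, matching what is on disk
empty = pl.DataFrame(schema=schema)
print(empty.schema)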

tests/test_data_processing.py

Lines changed: 73 additions & 0 deletions
@@ -5,6 +5,8 @@
 import os
 import re
 import sys
+import tempfile
+from datetime import datetime, timezone
 from pathlib import Path

 import pytest
@@ -402,3 +404,74 @@ def test_parse_gres_multiple_sockets(self) -> None:
         cleaned_gres = re.sub(r"\(S:[0-9-]+\)", "", gres)
         gpu_parts = cleaned_gres.split(":")
         assert int(gpu_parts[-1]) == expected_count
+
+
+class TestDatetimeSchemaConsistency:
+    """Test datetime schema consistency when saving and loading data."""
+
+    def test_parse_datetime_returns_utc(self) -> None:
+        """Test that _parse_datetime returns UTC timezone-aware datetimes."""
+        # Test with ISO format string
+        dt = slurm_usage._parse_datetime("2025-09-20T10:30:00")
+        assert dt is not None
+        assert dt.tzinfo is not None
+        assert dt.tzinfo == timezone.utc
+
+        # Test with None
+        assert slurm_usage._parse_datetime(None) is None
+        assert slurm_usage._parse_datetime("Unknown") is None
+
+    def test_processed_jobs_to_dataframe(self) -> None:
+        """Test that processed jobs are correctly converted to DataFrame."""
+        # Create test ProcessedJob with datetime fields
+        from slurm_usage import ProcessedJob
+
+        job = ProcessedJob(
+            job_id="test123",
+            user="alice",
+            job_name="test_job",
+            partition="gpus",
+            state="COMPLETED",
+            submit_time=datetime(2025, 9, 20, 9, 0, 0, tzinfo=timezone.utc),
+            start_time=datetime(2025, 9, 20, 10, 0, 0, tzinfo=timezone.utc),
+            end_time=datetime(2025, 9, 20, 11, 0, 0, tzinfo=timezone.utc),
+            node_list="node-001",
+            elapsed_seconds=3600,
+            alloc_cpus=4,
+            req_mem_mb=4096,
+            max_rss_mb=2048,
+            total_cpu_seconds=7200,
+            alloc_gpus=1,
+            cpu_efficiency=50.0,
+            memory_efficiency=50.0,
+            cpu_hours_wasted=1.0,
+            memory_gb_hours_wasted=2.0,
+            cpu_hours_reserved=2.0,
+            memory_gb_hours_reserved=4.0,
+            gpu_hours_reserved=1.0,
+            is_complete=True,
+        )
+
+        # Convert to DataFrame
+        df = slurm_usage._processed_jobs_to_dataframe([job])
+
+        # Check DataFrame was created correctly
+        assert len(df) == 1
+        assert df["job_id"][0] == "test123"
+        assert df["user"][0] == "alice"
+
+    def test_load_recent_data_handles_empty_directory(self) -> None:
+        """Test that _load_recent_data handles empty directory gracefully."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            config = slurm_usage.Config(
+                data_dir=Path(tmpdir),
+                groups={},
+                user_to_group={},
+            )
+
+            processed_dir = Path(tmpdir) / "processed"
+            processed_dir.mkdir(parents=True, exist_ok=True)
+
+            # Should return None for empty directory
+            result = slurm_usage._load_recent_data(config, days=1)
+            assert result is None
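
None of the new tests exercise an actual parquet round trip, so as a purely illustrative extension (hypothetical, not part of this commit), a check that the UTC datetime dtype survives saving and loading might look like:

import tempfile
from pathlib import Path

import polars as pl

schema = {"end_time": pl.Datetime("us", "UTC")}
df = pl.DataFrame(schema=schema)  # empty frame with the declared dtype

with tempfile.TemporaryDirectory() as tmpdir:
    path = Path(tmpdir) / "jobs.parquet"
    df.write_parquet(path)
    loaded = pl.read_parquet(path)

# The timezone-aware dtype should come back unchanged
assert loaded.schema["end_time"] == pl.Datetime("us", "UTC")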
