Skip to content

Commit 00cfd92

Browse files
committed
fix: Validate SetStatisticsUpdate correctly (fixes #2865)
Previously the pydantic @model_validator would fail because it assumed statistics was a model instance. In a "before" validator that is not necessarily the case. Check the type explicitly with isinstance instead, and handle the `dict` case too.
1 parent b0a7878 commit 00cfd92

File tree

2 files changed

+23
-4
lines changed

2 files changed

+23
-4
lines changed

pyiceberg/table/update/__init__.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from abc import ABC, abstractmethod
2222
from datetime import datetime
2323
from functools import singledispatch
24-
from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal, TypeVar, cast
24+
from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal, TypeVar
2525

2626
from pydantic import Field, field_validator, model_serializer, model_validator
2727

@@ -181,9 +181,15 @@ class SetStatisticsUpdate(IcebergBaseModel):
181181

182182
@model_validator(mode="before")
183183
def validate_snapshot_id(cls, data: dict[str, Any]) -> dict[str, Any]:
184-
stats = cast(StatisticsFile, data["statistics"])
184+
snapshot_id = None
185185

186-
data["snapshot_id"] = stats.snapshot_id
186+
stats = data["statistics"]
187+
if isinstance(stats, StatisticsFile):
188+
snapshot_id = stats.snapshot_id
189+
elif isinstance(stats, dict):
190+
snapshot_id = stats.get("snapshot_id")
191+
192+
data["snapshot_id"] = snapshot_id
187193

188194
return data
189195

tests/table/test_init.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from typing import Any
2222

2323
import pytest
24-
from pydantic import ValidationError
24+
from pydantic import ValidationError, BaseModel
2525
from sortedcontainers import SortedList
2626

2727
from pyiceberg.catalog.noop import NoopCatalog
@@ -1391,6 +1391,8 @@ def test_set_statistics_update(table_v2_with_statistics: Table) -> None:
13911391
statistics=statistics_file,
13921392
)
13931393

1394+
assert model_roundtrips(update)
1395+
13941396
new_metadata = update_table_metadata(
13951397
table_v2_with_statistics.metadata,
13961398
(update,),
@@ -1575,3 +1577,14 @@ def test_add_snapshot_update_updates_next_row_id(table_v3: Table) -> None:
15751577

15761578
new_metadata = update_table_metadata(table_v3.metadata, (AddSnapshotUpdate(snapshot=new_snapshot),))
15771579
assert new_metadata.next_row_id == 11
1580+
1581+
1582+
def model_roundtrips(model: BaseModel) -> bool:
1583+
"""Helper assertion that tests if a pydantic model roundtrips
1584+
successfully.
1585+
"""
1586+
__tracebackhide__ = True
1587+
model_data = model.model_dump()
1588+
if model != type(model).model_validate(model_data):
1589+
pytest.fail(f"model {type(model)} did not roundtrip successfully")
1590+
return True

0 commit comments

Comments
 (0)