Skip to content

Commit ba3eb24

Browse files
committed
fix: Validate SetStatisticsUpdate correctly (fixes #2865)
Previously the pydantic @model_validator would fail because it assumed statistics was a model instance. In a "before"" validator that is not necessarily the case. Check type explicitly with isinstance instead, and handle `dict` case too.
1 parent b0a7878 commit ba3eb24

File tree

2 files changed

+75
-5
lines changed

2 files changed

+75
-5
lines changed

pyiceberg/table/update/__init__.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from abc import ABC, abstractmethod
2222
from datetime import datetime
2323
from functools import singledispatch
24-
from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal, TypeVar, cast
24+
from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal, TypeVar
2525

2626
from pydantic import Field, field_validator, model_serializer, model_validator
2727

@@ -181,9 +181,15 @@ class SetStatisticsUpdate(IcebergBaseModel):
181181

182182
@model_validator(mode="before")
183183
def validate_snapshot_id(cls, data: dict[str, Any]) -> dict[str, Any]:
184-
stats = cast(StatisticsFile, data["statistics"])
185-
186-
data["snapshot_id"] = stats.snapshot_id
184+
stats = data["statistics"]
185+
if isinstance(stats, StatisticsFile):
186+
snapshot_id = stats.snapshot_id
187+
elif isinstance(stats, dict):
188+
snapshot_id = stats.get("snapshot-id")
189+
else:
190+
snapshot_id = None
191+
192+
data["snapshot_id"] = snapshot_id
187193

188194
return data
189195

tests/table/test_init.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from typing import Any
2222

2323
import pytest
24-
from pydantic import ValidationError
24+
from pydantic import BaseModel, ValidationError
2525
from sortedcontainers import SortedList
2626

2727
from pyiceberg.catalog.noop import NoopCatalog
@@ -1391,6 +1391,8 @@ def test_set_statistics_update(table_v2_with_statistics: Table) -> None:
13911391
statistics=statistics_file,
13921392
)
13931393

1394+
assert model_roundtrips(update)
1395+
13941396
new_metadata = update_table_metadata(
13951397
table_v2_with_statistics.metadata,
13961398
(update,),
@@ -1425,6 +1427,57 @@ def test_set_statistics_update(table_v2_with_statistics: Table) -> None:
14251427
assert json.loads(updated_statistics[0].model_dump_json()) == json.loads(expected)
14261428

14271429

1430+
def test_set_statistics_update_handles_deprecated_snapshot_id(table_v2_with_statistics: Table) -> None:
1431+
snapshot_id = table_v2_with_statistics.metadata.current_snapshot_id
1432+
1433+
blob_metadata = BlobMetadata(
1434+
type="apache-datasketches-theta-v1",
1435+
snapshot_id=snapshot_id,
1436+
sequence_number=2,
1437+
fields=[1],
1438+
properties={"prop-key": "prop-value"},
1439+
)
1440+
1441+
statistics_file = StatisticsFile(
1442+
snapshot_id=snapshot_id,
1443+
statistics_path="s3://bucket/warehouse/stats.puffin",
1444+
file_size_in_bytes=124,
1445+
file_footer_size_in_bytes=27,
1446+
blob_metadata=[blob_metadata],
1447+
)
1448+
update_with_model = SetStatisticsUpdate(statistics=statistics_file)
1449+
assert model_roundtrips(update_with_model)
1450+
assert update_with_model.snapshot_id == snapshot_id
1451+
1452+
update_with_dict = SetStatisticsUpdate.model_validate({"statistics": statistics_file.model_dump()})
1453+
assert model_roundtrips(update_with_dict)
1454+
assert update_with_dict.snapshot_id == snapshot_id
1455+
1456+
update_json = """
1457+
{
1458+
"statistics":
1459+
{
1460+
"snapshot-id": 3051729675574597004,
1461+
"statistics-path": "s3://a/b/stats.puffin",
1462+
"file-size-in-bytes": 413,
1463+
"file-footer-size-in-bytes": 42,
1464+
"blob-metadata": [
1465+
{
1466+
"type": "apache-datasketches-theta-v1",
1467+
"snapshot-id": 3051729675574597004,
1468+
"sequence-number": 1,
1469+
"fields": [1]
1470+
}
1471+
]
1472+
}
1473+
}
1474+
"""
1475+
1476+
update_with_json = SetStatisticsUpdate.model_validate_json(update_json)
1477+
assert model_roundtrips(update_with_json)
1478+
assert update_with_json.snapshot_id == snapshot_id
1479+
1480+
14281481
def test_remove_statistics_update(table_v2_with_statistics: Table) -> None:
14291482
update = RemoveStatisticsUpdate(
14301483
snapshot_id=3055729675574597004,
@@ -1575,3 +1628,14 @@ def test_add_snapshot_update_updates_next_row_id(table_v3: Table) -> None:
15751628

15761629
new_metadata = update_table_metadata(table_v3.metadata, (AddSnapshotUpdate(snapshot=new_snapshot),))
15771630
assert new_metadata.next_row_id == 11
1631+
1632+
1633+
def model_roundtrips(model: BaseModel) -> bool:
1634+
"""Helper assertion that tests if a pydantic model roundtrips
1635+
successfully.
1636+
"""
1637+
__tracebackhide__ = True
1638+
model_data = model.model_dump()
1639+
if model != type(model).model_validate(model_data):
1640+
pytest.fail(f"model {type(model)} did not roundtrip successfully")
1641+
return True

0 commit comments

Comments
 (0)