Skip to content

Commit ba18a02

Browse files
geruh and chinmay-bhat committed
feat: Add support for setting the current snapshot
Co-authored-by: Chinmay Bhat <12948588+chinmay-bhat@users.noreply.github.com>
1 parent b0a7878 commit ba18a02

File tree

4 files changed

+319
-2
lines changed

4 files changed

+319
-2
lines changed

pyiceberg/table/__init__.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,12 @@ def __exit__(self, exctype: type[BaseException] | None, excinst: BaseException |
275275
if exctype is None and excinst is None and exctb is None:
276276
self.commit_transaction()
277277

278-
def _apply(self, updates: tuple[TableUpdate, ...], requirements: tuple[TableRequirement, ...] = ()) -> Transaction:
278+
def _apply(
279+
self,
280+
updates: tuple[TableUpdate, ...],
281+
requirements: tuple[TableRequirement, ...] = (),
282+
commit_transaction_if_autocommit: bool = True,
283+
) -> Transaction:
279284
"""Check if the requirements are met, and applies the updates to the metadata."""
280285
for requirement in requirements:
281286
requirement.validate(self.table_metadata)
@@ -289,7 +294,7 @@ def _apply(self, updates: tuple[TableUpdate, ...], requirements: tuple[TableRequ
289294
if type(new_requirement) not in existing_requirements:
290295
self._requirements = self._requirements + (new_requirement,)
291296

292-
if self._autocommit:
297+
if self._autocommit and commit_transaction_if_autocommit:
293298
self.commit_transaction()
294299

295300
return self

pyiceberg/table/update/snapshot.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -843,6 +843,13 @@ def _commit(self) -> UpdatesAndRequirements:
843843
"""Apply the pending changes and commit."""
844844
return self._updates, self._requirements
845845

846+
def _commit_if_ref_updates_exist(self) -> None:
847+
"""Commit any pending ref updates to the transaction."""
848+
if self._updates:
849+
self._transaction._apply(*self._commit(), commit_transaction_if_autocommit=False)
850+
self._updates = ()
851+
self._requirements = ()
852+
846853
def _remove_ref_snapshot(self, ref_name: str) -> ManageSnapshots:
847854
"""Remove a snapshot ref.
848855
@@ -941,6 +948,44 @@ def remove_branch(self, branch_name: str) -> ManageSnapshots:
941948
"""
942949
return self._remove_ref_snapshot(ref_name=branch_name)
943950

951+
def set_current_snapshot(self, snapshot_id: int | None = None, ref_name: str | None = None) -> ManageSnapshots:
    """Set the current snapshot to a specific snapshot ID or ref.

    Points the main branch at the requested snapshot. Exactly one of
    ``snapshot_id`` and ``ref_name`` must be supplied.

    Args:
        snapshot_id: The ID of the snapshot to set as current.
        ref_name: The snapshot reference (branch or tag) to set as current.

    Returns:
        This for method chaining.

    Raises:
        ValueError: If neither or both arguments are provided, or if the snapshot/ref does not exist.
    """
    # Reject a malformed call before touching any state, so an invalid
    # invocation cannot push earlier chained updates into the transaction.
    if (snapshot_id is None) == (ref_name is None):
        raise ValueError("Either snapshot_id or ref_name must be provided, not both")

    # Apply ref updates queued earlier in the chain (e.g. a create_tag) so the
    # ref and snapshot lookups below can see them.
    self._commit_if_ref_updates_exist()
    metadata = self._transaction.table_metadata

    target_snapshot_id: int
    if snapshot_id is not None:
        target_snapshot_id = snapshot_id
    else:
        if ref_name not in metadata.refs:
            raise ValueError(f"Cannot find matching snapshot ID for ref: {ref_name}")
        target_snapshot_id = metadata.refs[ref_name].snapshot_id

    if metadata.snapshot_by_id(target_snapshot_id) is None:
        raise ValueError(f"Cannot set current snapshot to unknown snapshot id: {target_snapshot_id}")

    update, requirement = self._transaction._set_ref_snapshot(
        snapshot_id=target_snapshot_id,
        ref_name=MAIN_BRANCH,
        type="branch",
    )
    self._updates += update
    self._requirements += requirement
    return self
988+
944989

945990
class ExpireSnapshots(UpdateTableMetadata["ExpireSnapshots"]):
946991
"""Expire snapshots by ID.

tests/integration/test_snapshot_operations.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,91 @@ def test_remove_branch(catalog: Catalog) -> None:
7272
# now, remove the branch
7373
tbl.manage_snapshots().remove_branch(branch_name=branch_name).commit()
7474
assert tbl.metadata.refs.get(branch_name, None) is None
75+
76+
77+
@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_set_current_snapshot(catalog: Catalog) -> None:
    """Roll the table back to its previous snapshot by ID, then restore it."""
    identifier = "default.test_table_snapshot_operations"
    tbl = catalog.load_table(identifier)
    assert len(tbl.history()) > 2

    # first get the current snapshot and an older one
    current_snapshot_id = tbl.history()[-1].snapshot_id
    older_snapshot_id = tbl.history()[-2].snapshot_id
    # If both entries were the same snapshot the rollback below would prove nothing.
    assert older_snapshot_id != current_snapshot_id

    # set the current snapshot to the older one
    tbl.manage_snapshots().set_current_snapshot(snapshot_id=older_snapshot_id).commit()

    try:
        tbl = catalog.load_table(identifier)
        updated_snapshot = tbl.current_snapshot()
        assert updated_snapshot and updated_snapshot.snapshot_id == older_snapshot_id
    finally:
        # restore table even if the assertion above fails: other tests load the same identifier
        catalog.load_table(identifier).manage_snapshots().set_current_snapshot(snapshot_id=current_snapshot_id).commit()

    tbl = catalog.load_table(identifier)
    restored_snapshot = tbl.current_snapshot()
    assert restored_snapshot and restored_snapshot.snapshot_id == current_snapshot_id
100+
101+
102+
@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_set_current_snapshot_by_ref(catalog: Catalog) -> None:
    """Roll the table back via a tag name, then restore it and drop the tag."""
    identifier = "default.test_table_snapshot_operations"
    tbl = catalog.load_table(identifier)
    assert len(tbl.history()) > 2

    # first get the current snapshot and an older one
    current_snapshot_id = tbl.history()[-1].snapshot_id
    older_snapshot_id = tbl.history()[-2].snapshot_id
    assert older_snapshot_id != current_snapshot_id

    # create a tag pointing to the older snapshot
    tag_name = "my-tag"
    tbl.manage_snapshots().create_tag(snapshot_id=older_snapshot_id, tag_name=tag_name).commit()

    try:
        # set current snapshot using the tag name
        tbl = catalog.load_table(identifier)
        tbl.manage_snapshots().set_current_snapshot(ref_name=tag_name).commit()

        tbl = catalog.load_table(identifier)
        updated_snapshot = tbl.current_snapshot()
        assert updated_snapshot and updated_snapshot.snapshot_id == older_snapshot_id
    finally:
        # restore table and drop the tag even on failure: other tests load the
        # same identifier and reuse the same tag name
        catalog.load_table(identifier).manage_snapshots().set_current_snapshot(snapshot_id=current_snapshot_id).commit()
        catalog.load_table(identifier).manage_snapshots().remove_tag(tag_name=tag_name).commit()

    tbl = catalog.load_table(identifier)
    restored_snapshot = tbl.current_snapshot()
    assert restored_snapshot and restored_snapshot.snapshot_id == current_snapshot_id
    assert tbl.metadata.refs.get(tag_name, None) is None
132+
133+
@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_set_current_snapshot_chained_with_create_tag(catalog: Catalog) -> None:
    """Create a tag and roll back to it within a single chained commit."""
    identifier = "default.test_table_snapshot_operations"
    tbl = catalog.load_table(identifier)
    assert len(tbl.history()) > 2

    current_snapshot_id = tbl.history()[-1].snapshot_id
    older_snapshot_id = tbl.history()[-2].snapshot_id
    assert older_snapshot_id != current_snapshot_id

    # create a tag and use it to set current snapshot
    tag_name = "my-tag"
    (
        tbl.manage_snapshots()
        .create_tag(snapshot_id=older_snapshot_id, tag_name=tag_name)
        .set_current_snapshot(ref_name=tag_name)
        .commit()
    )

    try:
        tbl = catalog.load_table(identifier)
        updated_snapshot = tbl.current_snapshot()
        assert updated_snapshot
        assert updated_snapshot.snapshot_id == older_snapshot_id
    finally:
        # restore table and drop the tag even on failure: other tests load the
        # same identifier and reuse the same tag name
        catalog.load_table(identifier).manage_snapshots().set_current_snapshot(snapshot_id=current_snapshot_id).commit()
        catalog.load_table(identifier).manage_snapshots().remove_tag(tag_name=tag_name).commit()

    tbl = catalog.load_table(identifier)
    restored_snapshot = tbl.current_snapshot()
    assert restored_snapshot and restored_snapshot.snapshot_id == current_snapshot_id
    assert tbl.metadata.refs.get(tag_name, None) is None
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
from unittest.mock import MagicMock
18+
from uuid import uuid4
19+
20+
import pytest
21+
22+
from pyiceberg.table import CommitTableResponse, Table
23+
from pyiceberg.table.update import SetSnapshotRefUpdate, TableUpdate
24+
25+
26+
def _mock_commit_response(table: Table) -> CommitTableResponse:
    """Build a canned commit response that echoes the table's own metadata.

    The metadata location is a fixed placeholder and the UUID is freshly
    generated on every call.
    """
    response = CommitTableResponse(
        uuid=uuid4(),
        metadata_location="s3://bucket/tbl",
        metadata=table.metadata,
    )
    return response
32+
33+
34+
def _get_updates(mock_catalog: MagicMock) -> tuple[TableUpdate, ...]:
35+
args, _ = mock_catalog.commit_table.call_args
36+
return args[2]
37+
38+
39+
def test_set_current_snapshot_basic(table_v2: Table) -> None:
    """Setting the current snapshot by ID commits once with a single main-branch ref update."""
    target_id = 3051729675574597004

    catalog = MagicMock()
    catalog.commit_table.return_value = _mock_commit_response(table_v2)
    table_v2.catalog = catalog

    manager = table_v2.manage_snapshots()
    manager.set_current_snapshot(snapshot_id=target_id).commit()

    catalog.commit_table.assert_called_once()

    ref_updates = [upd for upd in _get_updates(catalog) if isinstance(upd, SetSnapshotRefUpdate)]
    assert len(ref_updates) == 1

    only_update = ref_updates[0]
    assert only_update.snapshot_id == target_id
    assert only_update.ref_name == "main"
    assert only_update.type == "branch"
58+
59+
def test_set_current_snapshot_unknown_id(table_v2: Table) -> None:
    """An ID that matches no snapshot raises ValueError and nothing is committed."""
    catalog = MagicMock()
    table_v2.catalog = catalog
    missing_id = 1234567890000

    expected_message = "Cannot set current snapshot to unknown snapshot id"
    with pytest.raises(ValueError, match=expected_message):
        table_v2.manage_snapshots().set_current_snapshot(snapshot_id=missing_id).commit()

    catalog.commit_table.assert_not_called()
67+
68+
69+
def test_set_current_snapshot_to_current(table_v2: Table) -> None:
    """Re-selecting the snapshot that is already current still commits exactly once."""
    existing = table_v2.current_snapshot()
    assert existing is not None

    catalog = MagicMock()
    catalog.commit_table.return_value = _mock_commit_response(table_v2)
    table_v2.catalog = catalog

    manager = table_v2.manage_snapshots()
    manager.set_current_snapshot(snapshot_id=existing.snapshot_id).commit()

    catalog.commit_table.assert_called_once()
79+
80+
81+
def test_set_current_snapshot_chained_with_tag(table_v2: Table) -> None:
    """Chaining create_tag after set_current_snapshot yields one commit carrying both ref updates."""
    target_id = 3051729675574597004

    catalog = MagicMock()
    catalog.commit_table.return_value = _mock_commit_response(table_v2)
    table_v2.catalog = catalog

    manager = table_v2.manage_snapshots()
    manager.set_current_snapshot(snapshot_id=target_id).create_tag(target_id, "my-tag").commit()

    catalog.commit_table.assert_called_once()

    ref_updates = [upd for upd in _get_updates(catalog) if isinstance(upd, SetSnapshotRefUpdate)]
    assert len(ref_updates) == 2

    touched_refs = {upd.ref_name for upd in ref_updates}
    assert touched_refs == {"main", "my-tag"}
95+
96+
97+
def test_set_current_snapshot_with_extensive_snapshots(table_v2_with_extensive_snapshots: Table) -> None:
    """Selecting a snapshot from the middle of a long snapshot list produces one ref update for it."""
    table = table_v2_with_extensive_snapshots
    all_snapshots = table.metadata.snapshots
    assert len(all_snapshots) > 100

    chosen_id = all_snapshots[50].snapshot_id

    catalog = MagicMock()
    catalog.commit_table.return_value = _mock_commit_response(table)
    table.catalog = catalog

    table.manage_snapshots().set_current_snapshot(snapshot_id=chosen_id).commit()

    catalog.commit_table.assert_called_once()

    ref_updates = [upd for upd in _get_updates(catalog) if isinstance(upd, SetSnapshotRefUpdate)]
    assert len(ref_updates) == 1
    assert ref_updates[0].snapshot_id == chosen_id
115+
116+
117+
def test_set_current_snapshot_by_ref_name(table_v2: Table) -> None:
    """Resolving the target through the "main" ref emits a main-branch update for the current snapshot."""
    existing = table_v2.current_snapshot()
    assert existing is not None

    catalog = MagicMock()
    catalog.commit_table.return_value = _mock_commit_response(table_v2)
    table_v2.catalog = catalog

    table_v2.manage_snapshots().set_current_snapshot(ref_name="main").commit()

    ref_updates = [upd for upd in _get_updates(catalog) if isinstance(upd, SetSnapshotRefUpdate)]
    assert len(ref_updates) == 1

    only_update = ref_updates[0]
    assert only_update.snapshot_id == existing.snapshot_id
    assert only_update.ref_name == "main"
132+
133+
134+
def test_set_current_snapshot_unknown_ref(table_v2: Table) -> None:
    """A ref name absent from the table raises ValueError and nothing is committed."""
    catalog = MagicMock()
    table_v2.catalog = catalog

    expected_message = "Cannot find matching snapshot ID for ref: nonexistent"
    with pytest.raises(ValueError, match=expected_message):
        table_v2.manage_snapshots().set_current_snapshot(ref_name="nonexistent").commit()

    catalog.commit_table.assert_not_called()
141+
142+
143+
def test_set_current_snapshot_requires_one_argument(table_v2: Table) -> None:
    """Passing neither argument, or both, raises ValueError and nothing is committed."""
    catalog = MagicMock()
    table_v2.catalog = catalog

    expected_message = "Either snapshot_id or ref_name must be provided, not both"
    # First the "neither" case, then the "both" case.
    invalid_calls = ({}, {"snapshot_id": 123, "ref_name": "main"})
    for bad_kwargs in invalid_calls:
        with pytest.raises(ValueError, match=expected_message):
            table_v2.manage_snapshots().set_current_snapshot(**bad_kwargs).commit()

    catalog.commit_table.assert_not_called()
153+
154+
155+
def test_set_current_snapshot_chained_with_create_tag(table_v2: Table) -> None:
    """A tag created earlier in the same chain can be used as the ref for set_current_snapshot."""
    target_id = 3051729675574597004

    catalog = MagicMock()
    catalog.commit_table.return_value = _mock_commit_response(table_v2)
    table_v2.catalog = catalog

    # create a tag and immediately use it to set current snapshot
    manager = table_v2.manage_snapshots()
    manager = manager.create_tag(snapshot_id=target_id, tag_name="new-tag")
    manager = manager.set_current_snapshot(ref_name="new-tag")
    manager.commit()

    catalog.commit_table.assert_called_once()

    ref_updates = [upd for upd in _get_updates(catalog) if isinstance(upd, SetSnapshotRefUpdate)]

    # should have the tag and the main branch update
    assert len(ref_updates) == 2
    touched_refs = {upd.ref_name for upd in ref_updates}
    assert touched_refs == {"new-tag", "main"}

    # The main branch should point to the same snapshot as the tag
    main_update = next(upd for upd in ref_updates if upd.ref_name == "main")
    assert main_update.snapshot_id == target_id

0 commit comments

Comments
 (0)