Upsert: Reuse existing expression to detect rows to be inserted (#1662)
Also a slight refactor of the tests to bring them more in line with the rest.
Fokko authored Feb 14, 2025
1 parent ee11bb0 commit 8014b6c
Showing 3 changed files with 66 additions and 74 deletions.
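At a high level, the change drops the bespoke `get_rows_to_insert` helper (which collected the target key values into Python lists and built a PyArrow `isin` filter) and instead reuses `upsert_util.create_match_filter` to describe the matched keys as an Iceberg expression, binds it against the table schema, converts it with `expression_to_pyarrow`, and filters the source dataframe with its negation. A minimal sketch of that flow, assuming a loaded `Table` named `tbl`, a PyArrow source `df`, and key columns `join_cols` (illustrative names, not part of the patch):

```python
import pyarrow as pa

from pyiceberg.expressions.visitors import bind  # import path assumed; the patch calls `bind` directly
from pyiceberg.io.pyarrow import expression_to_pyarrow
from pyiceberg.table import Table, upsert_util


def rows_to_insert(tbl: Table, df: pa.Table, join_cols: list[str], case_sensitive: bool = True) -> pa.Table:
    # Target rows that share key values with the source (same predicate the upsert already builds).
    matched_predicate = upsert_util.create_match_filter(df, join_cols)
    matched = tbl.scan(row_filter=matched_predicate, case_sensitive=case_sensitive).to_arrow()

    # Reuse the same helper on the matched rows, bind the resulting Iceberg
    # expression to the table schema, and translate it to a PyArrow expression.
    expr_match = upsert_util.create_match_filter(matched, join_cols)
    expr_bound = bind(tbl.schema(), expr_match, case_sensitive=case_sensitive)
    expr_arrow = expression_to_pyarrow(expr_bound)

    # Source rows whose keys are not present in the target are the inserts.
    return df.filter(~expr_arrow)
```

Compared to the removed helper, this avoids materializing the target key columns with `to_pylist()` and keeps the key comparison in a single code path.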
17 changes: 13 additions & 4 deletions pyiceberg/table/__init__.py
@@ -62,7 +62,7 @@
manifest_evaluator,
)
from pyiceberg.io import FileIO, load_file_io
from pyiceberg.io.pyarrow import ArrowScan, schema_to_pyarrow
from pyiceberg.io.pyarrow import ArrowScan, expression_to_pyarrow, schema_to_pyarrow
from pyiceberg.manifest import (
POSITIONAL_DELETE_SCHEMA,
DataFile,
@@ -1101,7 +1101,12 @@ def name_mapping(self) -> Optional[NameMapping]:
return self.metadata.name_mapping()

def upsert(
self, df: pa.Table, join_cols: list[str], when_matched_update_all: bool = True, when_not_matched_insert_all: bool = True
self,
df: pa.Table,
join_cols: list[str],
when_matched_update_all: bool = True,
when_not_matched_insert_all: bool = True,
case_sensitive: bool = True,
) -> UpsertResult:
"""Shorthand API for performing an upsert to an iceberg table.
@@ -1111,6 +1116,7 @@ def upsert(
join_cols: The columns to join on. These are essentially analogous to primary keys
when_matched_update_all: Bool indicating to update rows that are matched but require an update due to a value in a non-key column changing
when_not_matched_insert_all: Bool indicating new rows to be inserted that do not match any existing rows in the table
case_sensitive: Bool indicating if the match should be case-sensitive
Example Use Cases:
Case 1: Both Parameters = True (Full Upsert)
@@ -1144,7 +1150,7 @@

# get list of rows that exist so we don't have to load the entire target table
matched_predicate = upsert_util.create_match_filter(df, join_cols)
matched_iceberg_table = self.scan(row_filter=matched_predicate).to_arrow()
matched_iceberg_table = self.scan(row_filter=matched_predicate, case_sensitive=case_sensitive).to_arrow()

update_row_cnt = 0
insert_row_cnt = 0
@@ -1164,7 +1170,10 @@
tx.overwrite(rows_to_update, overwrite_filter=overwrite_mask_predicate)

if when_not_matched_insert_all:
rows_to_insert = upsert_util.get_rows_to_insert(df, matched_iceberg_table, join_cols)
expr_match = upsert_util.create_match_filter(matched_iceberg_table, join_cols)
expr_match_bound = bind(self.schema(), expr_match, case_sensitive=case_sensitive)
expr_match_arrow = expression_to_pyarrow(expr_match_bound)
rows_to_insert = df.filter(~expr_match_arrow)

insert_row_cnt = len(rows_to_insert)

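For reference, a hypothetical call site exercising the widened `upsert` signature; the catalog name, table identifier, and column names below are assumptions for illustration, not part of this change:

```python
import pyarrow as pa

from pyiceberg.catalog import load_catalog

# Assumes a configured catalog named "default" and an existing table "default.orders".
catalog = load_catalog("default")
table = catalog.load_table("default.orders")

source = pa.table({"order_id": [1, 2], "order_type": ["A", "B"]})

res = table.upsert(
    df=source,
    join_cols=["order_id"],
    when_matched_update_all=True,
    when_not_matched_insert_all=True,
    case_sensitive=True,  # new in this change: controls how the join/filter columns are matched
)
print(res.rows_updated, res.rows_inserted)
```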
24 changes: 0 additions & 24 deletions pyiceberg/table/upsert_util.py
@@ -92,27 +92,3 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols
rows_to_update_table = rows_to_update_table.select(list(common_columns))

return rows_to_update_table


def get_rows_to_insert(source_table: pa.Table, target_table: pa.Table, join_cols: list[str]) -> pa.Table:
source_filter_expr = pc.scalar(True)

for col in join_cols:
target_values = target_table.column(col).to_pylist()
expr = pc.field(col).isin(target_values)

if source_filter_expr is None:
source_filter_expr = expr
else:
source_filter_expr = source_filter_expr & expr

non_matching_expr = ~source_filter_expr

source_columns = set(source_table.column_names)
target_columns = set(target_table.column_names)

common_columns = source_columns.intersection(target_columns)

non_matching_rows = source_table.filter(non_matching_expr).select(common_columns)

return non_matching_rows
99 changes: 53 additions & 46 deletions tests/table/test_upsert.py
@@ -14,14 +14,30 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from pathlib import PosixPath

import pytest
from datafusion import SessionContext
from pyarrow import Table as pa_table

from pyiceberg.catalog import Catalog
from pyiceberg.exceptions import NoSuchTableError
from pyiceberg.table import UpsertResult
from tests.catalog.test_base import InMemoryCatalog, Table

_TEST_NAMESPACE = "test_ns"

@pytest.fixture
def catalog(tmp_path: PosixPath) -> InMemoryCatalog:
catalog = InMemoryCatalog("test.in_memory.catalog", warehouse=tmp_path.absolute().as_posix())
catalog.create_namespace("default")
return catalog


def _drop_table(catalog: Catalog, identifier: str) -> None:
try:
catalog.drop_table(identifier)
except NoSuchTableError:
pass


def show_iceberg_table(table: Table, ctx: SessionContext) -> None:
@@ -72,7 +88,7 @@ def gen_source_dataset(start_row: int, end_row: int, composite_key: bool, add_du


def gen_target_iceberg_table(
start_row: int, end_row: int, composite_key: bool, ctx: SessionContext, catalog: InMemoryCatalog, namespace: str
start_row: int, end_row: int, composite_key: bool, ctx: SessionContext, catalog: InMemoryCatalog, identifier: str
) -> Table:
additional_columns = ", t.order_id + 1000 as order_line_id" if composite_key else ""

@@ -83,7 +99,7 @@ def gen_target_iceberg_table(
from t
""").to_arrow_table()

table = catalog.create_table(f"{_TEST_NAMESPACE}.target", df.schema)
table = catalog.create_table(identifier, df.schema)

table.append(df)

@@ -95,13 +111,6 @@ def assert_upsert_result(res: UpsertResult, expected_updated: int, expected_inse
assert res.rows_inserted == expected_inserted, f"rows inserted should be {expected_inserted}, but got {res.rows_inserted}"


@pytest.fixture(scope="session")
def catalog_conn() -> InMemoryCatalog:
catalog = InMemoryCatalog("test")
catalog.create_namespace(namespace=_TEST_NAMESPACE)
yield catalog


@pytest.mark.parametrize(
"join_cols, src_start_row, src_end_row, target_start_row, target_end_row, when_matched_update_all, when_not_matched_insert_all, expected_updated, expected_inserted",
[
@@ -112,7 +121,7 @@ def catalog_conn() -> InMemoryCatalog:
],
)
def test_merge_rows(
catalog_conn: InMemoryCatalog,
catalog: Catalog,
join_cols: list[str],
src_start_row: int,
src_end_row: int,
@@ -123,12 +132,13 @@ def test_merge_rows(
expected_updated: int,
expected_inserted: int,
) -> None:
ctx = SessionContext()
identifier = "default.test_merge_rows"
_drop_table(catalog, identifier)

catalog = catalog_conn
ctx = SessionContext()

source_df = gen_source_dataset(src_start_row, src_end_row, False, False, ctx)
ice_table = gen_target_iceberg_table(target_start_row, target_end_row, False, ctx, catalog, _TEST_NAMESPACE)
ice_table = gen_target_iceberg_table(target_start_row, target_end_row, False, ctx, catalog, identifier)
res = ice_table.upsert(
df=source_df,
join_cols=join_cols,
@@ -138,13 +148,13 @@ def test_merge_rows(

assert_upsert_result(res, expected_updated, expected_inserted)

catalog.drop_table(f"{_TEST_NAMESPACE}.target")


def test_merge_scenario_skip_upd_row(catalog_conn: InMemoryCatalog) -> None:
def test_merge_scenario_skip_upd_row(catalog: Catalog) -> None:
"""
tests a single insert and update; skips a row that does not need to be updated
"""
identifier = "default.test_merge_scenario_skip_upd_row"
_drop_table(catalog, identifier)

ctx = SessionContext()

@@ -154,8 +164,7 @@ def test_merge_scenario_skip_upd_row(catalog_conn: InMemoryCatalog) -> None:
select 2 as order_id, date '2021-01-01' as order_date, 'A' as order_type
""").to_arrow_table()

catalog = catalog_conn
table = catalog.create_table(f"{_TEST_NAMESPACE}.target", df.schema)
table = catalog.create_table(identifier, df.schema)

table.append(df)

@@ -174,24 +183,24 @@ def test_merge_scenario_skip_upd_row(catalog_conn: InMemoryCatalog) -> None:

assert_upsert_result(res, expected_updated, expected_inserted)

catalog.drop_table(f"{_TEST_NAMESPACE}.target")


def test_merge_scenario_date_as_key(catalog_conn: InMemoryCatalog) -> None:
def test_merge_scenario_date_as_key(catalog: Catalog) -> None:
"""
tests a single insert and update; primary key is a date column
"""

ctx = SessionContext()

identifier = "default.test_merge_scenario_date_as_key"
_drop_table(catalog, identifier)

df = ctx.sql("""
select date '2021-01-01' as order_date, 'A' as order_type
union all
select date '2021-01-02' as order_date, 'A' as order_type
""").to_arrow_table()

catalog = catalog_conn
table = catalog.create_table(f"{_TEST_NAMESPACE}.target", df.schema)
table = catalog.create_table(identifier, df.schema)

table.append(df)

@@ -210,14 +219,15 @@ def test_merge_scenario_date_as_key(catalog_conn: InMemoryCatalog) -> None:

assert_upsert_result(res, expected_updated, expected_inserted)

catalog.drop_table(f"{_TEST_NAMESPACE}.target")


def test_merge_scenario_string_as_key(catalog_conn: InMemoryCatalog) -> None:
def test_merge_scenario_string_as_key(catalog: Catalog) -> None:
"""
tests a single insert and update; primary key is a string column
"""

identifier = "default.test_merge_scenario_string_as_key"
_drop_table(catalog, identifier)

ctx = SessionContext()

df = ctx.sql("""
Expand All @@ -226,8 +236,7 @@ def test_merge_scenario_string_as_key(catalog_conn: InMemoryCatalog) -> None:
select 'def' as order_id, 'A' as order_type
""").to_arrow_table()

catalog = catalog_conn
table = catalog.create_table(f"{_TEST_NAMESPACE}.target", df.schema)
table = catalog.create_table(identifier, df.schema)

table.append(df)

Expand All @@ -246,18 +255,18 @@ def test_merge_scenario_string_as_key(catalog_conn: InMemoryCatalog) -> None:

assert_upsert_result(res, expected_updated, expected_inserted)

catalog.drop_table(f"{_TEST_NAMESPACE}.target")


def test_merge_scenario_composite_key(catalog_conn: InMemoryCatalog) -> None:
def test_merge_scenario_composite_key(catalog: Catalog) -> None:
"""
tests merging 200 rows with a composite key
"""

identifier = "default.test_merge_scenario_composite_key"
_drop_table(catalog, identifier)

ctx = SessionContext()

catalog = catalog_conn
table = gen_target_iceberg_table(1, 200, True, ctx, catalog, _TEST_NAMESPACE)
table = gen_target_iceberg_table(1, 200, True, ctx, catalog, identifier)
source_df = gen_source_dataset(101, 300, True, False, ctx)

res = table.upsert(df=source_df, join_cols=["order_id", "order_line_id"])
Expand All @@ -267,43 +276,41 @@ def test_merge_scenario_composite_key(catalog_conn: InMemoryCatalog) -> None:

assert_upsert_result(res, expected_updated, expected_inserted)

catalog.drop_table(f"{_TEST_NAMESPACE}.target")


def test_merge_source_dups(catalog_conn: InMemoryCatalog) -> None:
def test_merge_source_dups(catalog: Catalog) -> None:
"""
tests duplicate rows in source
"""

identifier = "default.test_merge_source_dups"
_drop_table(catalog, identifier)

ctx = SessionContext()

catalog = catalog_conn
table = gen_target_iceberg_table(1, 10, False, ctx, catalog, _TEST_NAMESPACE)
table = gen_target_iceberg_table(1, 10, False, ctx, catalog, identifier)
source_df = gen_source_dataset(5, 15, False, True, ctx)

with pytest.raises(Exception, match="Duplicate rows found in source dataset based on the key columns. No upsert executed"):
table.upsert(df=source_df, join_cols=["order_id"])

catalog.drop_table(f"{_TEST_NAMESPACE}.target")


def test_key_cols_misaligned(catalog_conn: InMemoryCatalog) -> None:
def test_key_cols_misaligned(catalog: Catalog) -> None:
"""
tests join columns missing from one of the tables
"""

identifier = "default.test_key_cols_misaligned"
_drop_table(catalog, identifier)

ctx = SessionContext()

df = ctx.sql("select 1 as order_id, date '2021-01-01' as order_date, 'A' as order_type").to_arrow_table()

catalog = catalog_conn
table = catalog.create_table(f"{_TEST_NAMESPACE}.target", df.schema)
table = catalog.create_table(identifier, df.schema)

table.append(df)

df_src = ctx.sql("select 1 as item_id, date '2021-05-01' as order_date, 'B' as order_type").to_arrow_table()

with pytest.raises(Exception, match=r"""Field ".*" does not exist in schema"""):
table.upsert(df=df_src, join_cols=["order_id"])

catalog.drop_table(f"{_TEST_NAMESPACE}.target")
