Use a join for upsert deduplication (#1685)
This changes the deduplication logic to use a join to deduplicate the rows.
While the original design wasn't wrong, it is more efficient to push the work
down into PyArrow, which gives better multi-threading and avoids the GIL.
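
At its core, the new path is just a PyArrow join: source and target are joined on the key columns, the colliding non-key columns get `-lhs`/`-rhs` suffixes, and only rows where a non-key value differs are kept. A minimal standalone sketch of that pattern on toy tables (illustration only, not the exact PyIceberg code):

```python
import pyarrow as pa
import pyarrow.compute as pc

source = pa.table({"idx": [1, 2, 3], "number": [10, 99, 30]})
target = pa.table({"idx": [1, 2], "number": [10, 20]})

# Inner-join on the key column; suffix the colliding non-key columns so both sides survive.
joined = source.join(target, keys=["idx"], join_type="inner", left_suffix="-lhs", right_suffix="-rhs")

# Keep only the rows where a non-key column actually changed.
changed = joined.filter(pc.field("number-lhs") != pc.field("number-rhs"))
print(changed.to_pylist())  # [{'idx': 2, 'number-lhs': 99, 'number-rhs': 20}]
```

The comparison then runs inside Arrow's multi-threaded kernels instead of a per-row Python loop, which is where the speedup below comes from.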

I did a small benchmark:

```python
import time
import pyarrow as pa

from pyiceberg.catalog import Catalog
from pyiceberg.exceptions import NoSuchTableError
from pyiceberg.schema import Schema
from pyiceberg.types import NestedField, StringType, IntegerType


def _drop_table(catalog: Catalog, identifier: str) -> None:
    try:
        catalog.drop_table(identifier)
    except NoSuchTableError:
        pass


def test_vo(session_catalog: Catalog):
    catalog = session_catalog
    identifier = "default.test_upsert_benchmark"
    _drop_table(catalog, identifier)

    schema = Schema(
        NestedField(1, "idx", IntegerType(), required=True),
        NestedField(2, "number", IntegerType(), required=True),
        # Mark idx as the identifier field, also known as the primary key
        identifier_field_ids=[1],
    )

    tbl = catalog.create_table(identifier, schema=schema)

    arrow_schema = pa.schema(
        [
            pa.field("idx", pa.int32(), nullable=False),
            pa.field("number", pa.int32(), nullable=False),
        ]
    )

    # Write some data
    df = pa.Table.from_pylist(
        [
            {"idx": idx, "number": idx}
            for idx in range(1, 100000)
        ],
        schema=arrow_schema,
    )
    tbl.append(df)

    df_upsert = pa.Table.from_pylist(
        # Overlap
        [
            {"idx": idx, "number": idx}
            for idx in range(80000, 90000)
        ]
        # Update
        + [
            {"idx": idx, "number": idx + 1}
            for idx in range(90000, 100000)
        ]
        # Insert
        + [
            {"idx": idx, "number": idx}
            for idx in range(100000, 110000)
        ],
        schema=arrow_schema,
    )

    start = time.time()

    tbl.upsert(df_upsert)

    stop = time.time()

    print(f"Took {stop-start} seconds")
```

And the result was:

```
Took 2.0412521362304688 seconds on the fd-join branch
Took 3.5236432552337646 seconds on latest main
```
Fokko authored Feb 21, 2025
1 parent 68a08b1 commit b95e792
Showing 3 changed files with 67 additions and 39 deletions.
7 changes: 7 additions & 0 deletions pyiceberg/table/__init__.py
@@ -1170,6 +1170,13 @@ def upsert(
        if upsert_util.has_duplicate_rows(df, join_cols):
            raise ValueError("Duplicate rows found in source dataset based on the key columns. No upsert executed")

        from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible

        downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
        _check_pyarrow_schema_compatible(
            self.schema(), provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
        )

        # get list of rows that exist so we don't have to load the entire target table
        matched_predicate = upsert_util.create_match_filter(df, join_cols)
        matched_iceberg_table = self.scan(row_filter=matched_predicate, case_sensitive=case_sensitive).to_arrow()
57 changes: 18 additions & 39 deletions pyiceberg/table/upsert_util.py
@@ -59,51 +59,30 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols
"""
Return a table with rows that need to be updated in the target table based on the join columns.
When a row is matched, an additional scan is done to evaluate the non-key columns to detect if an actual change has occurred.
Only matched rows that have an actual change to a non-key column value will be returned in the final output.
The table is joined on the identifier columns, and then checked if there are any updated rows.
Those are selected and everything is renamed correctly.
"""
all_columns = set(source_table.column_names)
join_cols_set = set(join_cols)
non_key_cols = all_columns - join_cols_set

non_key_cols = list(all_columns - join_cols_set)
if has_duplicate_rows(target_table, join_cols):
raise ValueError("Target table has duplicate rows, aborting upsert")

if len(target_table) == 0:
# When the target table is empty, there is nothing to update :)
return source_table.schema.empty_table()

match_expr = functools.reduce(operator.and_, [pc.field(col).isin(target_table.column(col).to_pylist()) for col in join_cols])

matching_source_rows = source_table.filter(match_expr)

rows_to_update = []

for index in range(matching_source_rows.num_rows):
source_row = matching_source_rows.slice(index, 1)

target_filter = functools.reduce(operator.and_, [pc.field(col) == source_row.column(col)[0].as_py() for col in join_cols])

matching_target_row = target_table.filter(target_filter)

if matching_target_row.num_rows > 0:
needs_update = False

for non_key_col in non_key_cols:
source_value = source_row.column(non_key_col)[0].as_py()
target_value = matching_target_row.column(non_key_col)[0].as_py()

if source_value != target_value:
needs_update = True
break

if needs_update:
rows_to_update.append(source_row)

if rows_to_update:
rows_to_update_table = pa.concat_tables(rows_to_update)
else:
rows_to_update_table = source_table.schema.empty_table()

common_columns = set(source_table.column_names).intersection(set(target_table.column_names))
rows_to_update_table = rows_to_update_table.select(list(common_columns))

return rows_to_update_table
diff_expr = functools.reduce(operator.or_, [pc.field(f"{col}-lhs") != pc.field(f"{col}-rhs") for col in non_key_cols])

return (
source_table
# We already know that the schema is compatible, this is to fix large_ types
.cast(target_table.schema)
.join(target_table, keys=list(join_cols_set), join_type="inner", left_suffix="-lhs", right_suffix="-rhs")
.filter(diff_expr)
.drop_columns([f"{col}-rhs" for col in non_key_cols])
.rename_columns({f"{col}-lhs" if col not in join_cols else col: col for col in source_table.column_names})
# Finally cast to the original schema since it doesn't carry nullability:
# https://github.com/apache/arrow/issues/45557
).cast(target_table.schema)
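
For reference, a hypothetical call to the rewritten helper on toy tables; the signature and the rename/cast behaviour are taken from the diff above, and the dict form of `rename_columns` assumes a recent PyArrow release:

```python
import pyarrow as pa

from pyiceberg.table import upsert_util

source = pa.table({"idx": pa.array([1, 2, 3], pa.int32()), "number": pa.array([10, 99, 30], pa.int32())})
target = pa.table({"idx": pa.array([1, 2], pa.int32()), "number": pa.array([10, 20], pa.int32())})

# idx=1 is unchanged and idx=3 has no match in the target, so only idx=2 should come back.
rows = upsert_util.get_rows_to_update(source, target, join_cols=["idx"])
print(rows.to_pylist())  # expected: [{'idx': 2, 'number': 99}]
```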
42 changes: 42 additions & 0 deletions tests/table/test_upsert.py
@@ -427,6 +427,48 @@ def test_create_match_filter_single_condition() -> None:
)


def test_upsert_with_duplicate_rows_in_table(catalog: Catalog) -> None:
identifier = "default.test_upsert_with_duplicate_rows_in_table"

_drop_table(catalog, identifier)
schema = Schema(
NestedField(1, "city", StringType(), required=True),
NestedField(2, "inhabitants", IntegerType(), required=True),
# Mark City as the identifier field, also known as the primary-key
identifier_field_ids=[1],
)

tbl = catalog.create_table(identifier, schema=schema)

arrow_schema = pa.schema(
[
pa.field("city", pa.string(), nullable=False),
pa.field("inhabitants", pa.int32(), nullable=False),
]
)

# Write some data
df = pa.Table.from_pylist(
[
{"city": "Drachten", "inhabitants": 45019},
{"city": "Drachten", "inhabitants": 45019},
],
schema=arrow_schema,
)
tbl.append(df)

df = pa.Table.from_pylist(
[
# Will be updated, the inhabitants has been updated
{"city": "Drachten", "inhabitants": 45505},
],
schema=arrow_schema,
)

with pytest.raises(ValueError, match="Target table has duplicate rows, aborting upsert"):
_ = tbl.upsert(df)


def test_upsert_without_identifier_fields(catalog: Catalog) -> None:
identifier = "default.test_upsert_without_identifier_fields"
_drop_table(catalog, identifier)
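
The duplicate-rows test above exercises the `has_duplicate_rows` guard that `get_rows_to_update` now calls before joining. A hedged sketch of how such a check can be expressed with PyArrow's `group_by` (not necessarily the exact helper that pyiceberg ships):

```python
import pyarrow as pa
import pyarrow.compute as pc


def has_duplicate_keys(table: pa.Table, key_cols: list[str]) -> bool:
    # Count rows per key combination; any count above 1 means duplicate keys.
    counts = table.select(key_cols).group_by(key_cols).aggregate([([], "count_all")])
    return bool(pc.any(pc.greater(counts["count_all"], 1)).as_py())


drachten = pa.table({"city": ["Drachten", "Drachten"], "inhabitants": [45019, 45019]})
print(has_duplicate_keys(drachten, ["city"]))  # True
```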
