Commit 50c7f21

feat(api): implement upsert() using MERGE INTO

1 parent e4e582a  commit 50c7f21

File tree: 2 files changed, +177 -1 lines changed

ibis/backends/sql/__init__.py
ibis/backends/tests/test_client.py

ibis/backends/sql/__init__.py

Lines changed: 112 additions & 1 deletion
@@ -423,7 +423,7 @@ def insert(
         Parameters
         ----------
         name
-            The name of the table to which data needs will be inserted
+            The name of the table to which data will be inserted
         obj
             The source data or expression to insert
         database
@@ -526,6 +526,117 @@ def _build_insert_template(
             ),
         ).sql(self.dialect)
 
+    def upsert(
+        self,
+        name: str,
+        /,
+        obj: pd.DataFrame | ir.Table | list | dict,
+        on: str,
+        *,
+        database: str | None = None,
+    ) -> None:
+        """Upsert data into a table.
+
+        ::: {.callout-note}
+        ## Ibis does not use the word `schema` to refer to database hierarchy.
+
+        A collection of `table` is referred to as a `database`.
+        A collection of `database` is referred to as a `catalog`.
+
+        These terms are mapped onto the corresponding features in each
+        backend (where available), regardless of whether the backend itself
+        uses the same terminology.
+        :::
+
+        Parameters
+        ----------
+        name
+            The name of the table to which data will be upserted
+        obj
+            The source data or expression to upsert
+        on
+            Column name to join on
+        database
+            Name of the attached database that the table is located in.
+
+            For backends that support multi-level table hierarchies, you can
+            pass in a dotted string path like `"catalog.database"` or a tuple of
+            strings like `("catalog", "database")`.
+        """
+        table_loc = self._to_sqlglot_table(database)
+        catalog, db = self._to_catalog_db_tuple(table_loc)
+
+        if not isinstance(obj, ir.Table):
+            obj = ibis.memtable(obj)
+
+        self._run_pre_execute_hooks(obj)
+
+        query = self._build_upsert_from_table(
+            target=name, source=obj, on=on, db=db, catalog=catalog
+        )
+
+        with self._safe_raw_sql(query):
+            pass
+
+    def _build_upsert_from_table(
+        self,
+        *,
+        target: str,
+        source,
+        on: str,
+        db: str | None = None,
+        catalog: str | None = None,
+    ):
+        compiler = self.compiler
+        quoted = compiler.quoted
+        # Compare the columns between the target table and the object to be inserted
+        # If source is a subset of target, use source columns for insert list
+        # Otherwise, assume auto-generated column names and use positional ordering.
+        target_cols = self.get_schema(target, catalog=catalog, database=db).keys()
+
+        columns = (
+            source_cols
+            if (source_cols := source.schema().keys()) <= target_cols
+            else target_cols
+        )
+
+        source_alias = util.gen_name("source")
+        target_alias = util.gen_name("target")
+        query = sge.merge(
+            sge.When(
+                matched=True,
+                then=sge.Update(
+                    expressions=[
+                        sg.column(col, quoted=quoted).eq(
+                            sg.column(col, source_alias, quoted=quoted)
+                        )
+                        for col in columns
+                    ]
+                ),
+            ),
+            sge.When(
+                matched=False,
+                then=sge.Insert(
+                    this=sge.Tuple(expressions=columns),
+                    expression=sge.Tuple(
+                        expressions=[
+                            sg.column(col, source_alias, quoted=quoted)
+                            for col in columns
+                        ]
+                    ),
+                ),
+            ),
+            into=sg.table(target, db=db, catalog=catalog, quoted=quoted).as_(
+                target_alias
+            ),
+            using=f"({self.compile(source)}) AS {source_alias}",
+            on=sg.column(on, table=target_alias, quoted=quoted).eq(
+                sg.column(on, table=source_alias, quoted=quoted)
+            ),
+            dialect=compiler.dialect,
+        )
+        return query
+
     def truncate_table(self, name: str, /, *, database: str | None = None) -> None:
         """Delete all rows from a table.

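For orientation, here is a minimal standalone sketch of the statement shape that `_build_upsert_from_table` assembles, using the same sqlglot builder calls as the diff above. The table name, column list, aliases, source subquery, and dialect are illustrative placeholders, not values taken from this commit.

# Illustrative sketch only: mirrors the sge.merge(...) call above with
# hard-coded names in place of the schema lookup and util.gen_name() aliases.
import sqlglot as sg
import sqlglot.expressions as sge

cols = ["first_name", "last_name", "salary"]   # hypothetical target columns
target_alias, source_alias = "tgt", "src"      # stand-ins for generated aliases

query = sge.merge(
    # WHEN MATCHED THEN UPDATE SET col = src.col, ...
    sge.When(
        matched=True,
        then=sge.Update(
            expressions=[sg.column(c).eq(sg.column(c, source_alias)) for c in cols]
        ),
    ),
    # WHEN NOT MATCHED THEN INSERT (col, ...) VALUES (src.col, ...)
    sge.When(
        matched=False,
        then=sge.Insert(
            this=sge.Tuple(expressions=[sg.column(c) for c in cols]),
            expression=sge.Tuple(
                expressions=[sg.column(c, source_alias) for c in cols]
            ),
        ),
    ),
    into=sg.table("employees").as_(target_alias),
    using=f"(SELECT * FROM staged_employees) AS {source_alias}",
    on=sg.column("first_name", table=target_alias).eq(
        sg.column("first_name", table=source_alias)
    ),
    dialect="duckdb",  # example dialect; the backend passes compiler.dialect
)

# Prints something along the lines of:
#   MERGE INTO employees AS tgt USING (SELECT * FROM staged_employees) AS src
#   ON tgt.first_name = src.first_name
#   WHEN MATCHED THEN UPDATE SET first_name = src.first_name, ...
#   WHEN NOT MATCHED THEN INSERT (first_name, ...) VALUES (src.first_name, ...)
print(query.sql("duckdb"))
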
ibis/backends/tests/test_client.py

Lines changed: 65 additions & 0 deletions
@@ -519,6 +519,34 @@ def employee_data_2_temp_table(
     con.drop_table(temp_table_name, force=True)
 
 
+@pytest.fixture
+def test_employee_data_3():
+    import pandas as pd
+
+    df3 = pd.DataFrame(
+        {
+            "first_name": ["B", "Y", "Z"],
+            "last_name": ["A", "B", "C"],
+            "department_name": ["XX", "YY", "ZZ"],
+            "salary": [400.0, 500.0, 600.0],
+        }
+    )
+
+    return df3
+
+
+@pytest.fixture
+def employee_data_3_temp_table(
+    backend, con, test_employee_schema, test_employee_data_3
+):
+    temp_table_name = gen_name("temp_employee_data_3")
+    _create_temp_table_with_schema(
+        backend, con, temp_table_name, test_employee_schema, data=test_employee_data_3
+    )
+    yield temp_table_name
+    con.drop_table(temp_table_name, force=True)
+
+
 @pytest.mark.notimpl(["polars"], reason="`insert` method not implemented")
 def test_insert_no_overwrite_from_dataframe(
     backend, con, test_employee_data_2, employee_empty_temp_table
@@ -626,6 +654,43 @@ def _emp(a, b, c, d):
     assert len(con.table(employee_data_1_temp_table).execute()) == 3
 
 
+@pytest.mark.notimpl(["polars"], reason="`upsert` method not implemented")
+def test_upsert_from_dataframe(
+    backend, con, employee_data_1_temp_table, test_employee_data_3
+):
+    temporary = con.table(employee_data_1_temp_table)
+    df1 = temporary.execute().set_index("first_name")
+
+    con.upsert(employee_data_1_temp_table, obj=test_employee_data_3, on="first_name")
+    result = temporary.execute()
+    df2 = test_employee_data_3.set_index("first_name")
+    expected = pd.concat([df1[~df1.index.isin(df2.index)], df2]).reset_index()
+    assert len(result) == len(expected)
+    backend.assert_frame_equal(
+        result.sort_values("first_name").reset_index(drop=True),
+        expected.sort_values("first_name").reset_index(drop=True),
+    )
+
+
+@pytest.mark.notimpl(["polars"], reason="`upsert` method not implemented")
+def test_upsert_from_expr(
+    backend, con, employee_data_1_temp_table, employee_data_3_temp_table
+):
+    temporary = con.table(employee_data_1_temp_table)
+    from_table = con.table(employee_data_3_temp_table)
+    df1 = temporary.execute().set_index("first_name")
+
+    con.upsert(employee_data_1_temp_table, obj=from_table, on="first_name")
+    result = temporary.execute()
+    df2 = from_table.execute().set_index("first_name")
+    expected = pd.concat([df1[~df1.index.isin(df2.index)], df2]).reset_index()
+    assert len(result) == len(expected)
+    backend.assert_frame_equal(
+        result.sort_values("first_name").reset_index(drop=True),
+        expected.sort_values("first_name").reset_index(drop=True),
+    )
+
+
 @pytest.mark.notimpl(
     ["polars"], raises=AttributeError, reason="`insert` method not implemented"
 )

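Both new tests construct `expected` the same way, so the row-level contract they check is: a source row whose `on` key matches an existing row replaces it, a source row with a new key is appended, and untouched target rows pass through. A pandas-only sketch of that contract with made-up data (not the employee fixtures above):

# Illustrative data only: "B" exists in the target and is updated,
# "Z" is new and is inserted, "A" is left unchanged.
import pandas as pd

target = pd.DataFrame({"first_name": ["A", "B"], "salary": [100.0, 200.0]})
source = pd.DataFrame({"first_name": ["B", "Z"], "salary": [400.0, 600.0]})

t = target.set_index("first_name")
s = source.set_index("first_name")

# Same construction as `expected` in the tests above: keep target rows whose
# key is absent from the source, then take every source row.
expected = pd.concat([t[~t.index.isin(s.index)], s]).reset_index()
#   first_name  salary
#   A            100.0   (unchanged)
#   B            400.0   (updated)
#   Z            600.0   (inserted)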