Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Support providing an index predicate for upserts. (#13822)
Browse files Browse the repository at this point in the history
This is useful to upsert against a table which has a unique
partial index while avoiding conflicts.
  • Loading branch information
clokep authored Sep 15, 2022
1 parent 742f9f9 commit b2b0c85
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 7 deletions.
1 change: 1 addition & 0 deletions changelog.d/13822.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Support providing an index predicate clause when doing upserts.
1 change: 1 addition & 0 deletions synapse/storage/background_updates.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,7 @@ def register_background_index_update(
index_name: name of index to add
table: table to add index to
columns: columns/expressions to include in index
where_clause: A WHERE clause to specify a partial unique index.
unique: true to make a UNIQUE index
psql_only: true to only create this index on psql databases (useful
for virtual sqlite tables)
Expand Down
30 changes: 23 additions & 7 deletions synapse/storage/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -1191,6 +1191,7 @@ def simple_upsert_txn(
keyvalues: Dict[str, Any],
values: Dict[str, Any],
insertion_values: Optional[Dict[str, Any]] = None,
where_clause: Optional[str] = None,
lock: bool = True,
) -> bool:
"""
Expand All @@ -1203,6 +1204,7 @@ def simple_upsert_txn(
keyvalues: The unique key tables and their new values
values: The nonunique columns and their new values
insertion_values: additional key/values to use only when inserting
where_clause: An index predicate to apply to the upsert.
lock: True to lock the table when doing the upsert. Unused when performing
a native upsert.
Returns:
Expand All @@ -1213,7 +1215,12 @@ def simple_upsert_txn(

if table not in self._unsafe_to_upsert_tables:
return self.simple_upsert_txn_native_upsert(
txn, table, keyvalues, values, insertion_values=insertion_values
txn,
table,
keyvalues,
values,
insertion_values=insertion_values,
where_clause=where_clause,
)
else:
return self.simple_upsert_txn_emulated(
Expand All @@ -1222,6 +1229,7 @@ def simple_upsert_txn(
keyvalues,
values,
insertion_values=insertion_values,
where_clause=where_clause,
lock=lock,
)

Expand All @@ -1232,6 +1240,7 @@ def simple_upsert_txn_emulated(
keyvalues: Dict[str, Any],
values: Dict[str, Any],
insertion_values: Optional[Dict[str, Any]] = None,
where_clause: Optional[str] = None,
lock: bool = True,
) -> bool:
"""
Expand All @@ -1240,6 +1249,7 @@ def simple_upsert_txn_emulated(
keyvalues: The unique key tables and their new values
values: The nonunique columns and their new values
insertion_values: additional key/values to use only when inserting
where_clause: An index predicate to apply to the upsert.
lock: True to lock the table when doing the upsert.
Returns:
Returns True if a row was inserted or updated (i.e. if `values` is
Expand All @@ -1259,14 +1269,17 @@ def _getwhere(key: str) -> str:
else:
return "%s = ?" % (key,)

# Generate a where clause of each keyvalue and optionally the provided
# index predicate.
where = [_getwhere(k) for k in keyvalues]
if where_clause:
where.append(where_clause)

if not values:
# If `values` is empty, then all of the values we care about are in
# the unique key, so there is nothing to UPDATE. We can just do a
# SELECT instead to see if it exists.
sql = "SELECT 1 FROM %s WHERE %s" % (
table,
" AND ".join(_getwhere(k) for k in keyvalues),
)
sql = "SELECT 1 FROM %s WHERE %s" % (table, " AND ".join(where))
sqlargs = list(keyvalues.values())
txn.execute(sql, sqlargs)
if txn.fetchall():
Expand All @@ -1277,7 +1290,7 @@ def _getwhere(key: str) -> str:
sql = "UPDATE %s SET %s WHERE %s" % (
table,
", ".join("%s = ?" % (k,) for k in values),
" AND ".join(_getwhere(k) for k in keyvalues),
" AND ".join(where),
)
sqlargs = list(values.values()) + list(keyvalues.values())

Expand Down Expand Up @@ -1307,6 +1320,7 @@ def simple_upsert_txn_native_upsert(
keyvalues: Dict[str, Any],
values: Dict[str, Any],
insertion_values: Optional[Dict[str, Any]] = None,
where_clause: Optional[str] = None,
) -> bool:
"""
Use the native UPSERT functionality in PostgreSQL.
Expand All @@ -1316,6 +1330,7 @@ def simple_upsert_txn_native_upsert(
keyvalues: The unique key tables and their new values
values: The nonunique columns and their new values
insertion_values: additional key/values to use only when inserting
where_clause: An index predicate to apply to the upsert.
Returns:
Returns True if a row was inserted or updated (i.e. if `values` is
Expand All @@ -1331,11 +1346,12 @@ def simple_upsert_txn_native_upsert(
allvalues.update(values)
latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)

sql = ("INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s") % (
sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) %s DO %s" % (
table,
", ".join(k for k in allvalues),
", ".join("?" for _ in allvalues),
", ".join(k for k in keyvalues),
f"WHERE {where_clause}" if where_clause else "",
latter,
)
txn.execute(sql, list(allvalues.values()))
Expand Down

0 comments on commit b2b0c85

Please sign in to comment.