Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit b2b0c85

Browse files
authored
Support providing an index predicate for upserts. (#13822)
This is useful to upsert against a table which has a unique partial index while avoiding conflicts.
1 parent 742f9f9 commit b2b0c85

File tree

3 files changed

+25
-7
lines changed

3 files changed

+25
-7
lines changed

changelog.d/13822.misc

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Support providing an index predicate clause when doing upserts.

synapse/storage/background_updates.py

+1
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,7 @@ def register_background_index_update(
533533
index_name: name of index to add
534534
table: table to add index to
535535
columns: columns/expressions to include in index
536+
where_clause: A WHERE clause to specify a partial unique index.
536537
unique: true to make a UNIQUE index
537538
psql_only: true to only create this index on psql databases (useful
538539
for virtual sqlite tables)

synapse/storage/database.py

+23-7
Original file line numberDiff line numberDiff line change
@@ -1191,6 +1191,7 @@ def simple_upsert_txn(
11911191
keyvalues: Dict[str, Any],
11921192
values: Dict[str, Any],
11931193
insertion_values: Optional[Dict[str, Any]] = None,
1194+
where_clause: Optional[str] = None,
11941195
lock: bool = True,
11951196
) -> bool:
11961197
"""
@@ -1203,6 +1204,7 @@ def simple_upsert_txn(
12031204
keyvalues: The unique key tables and their new values
12041205
values: The nonunique columns and their new values
12051206
insertion_values: additional key/values to use only when inserting
1207+
where_clause: An index predicate to apply to the upsert.
12061208
lock: True to lock the table when doing the upsert. Unused when performing
12071209
a native upsert.
12081210
Returns:
@@ -1213,7 +1215,12 @@ def simple_upsert_txn(
12131215

12141216
if table not in self._unsafe_to_upsert_tables:
12151217
return self.simple_upsert_txn_native_upsert(
1216-
txn, table, keyvalues, values, insertion_values=insertion_values
1218+
txn,
1219+
table,
1220+
keyvalues,
1221+
values,
1222+
insertion_values=insertion_values,
1223+
where_clause=where_clause,
12171224
)
12181225
else:
12191226
return self.simple_upsert_txn_emulated(
@@ -1222,6 +1229,7 @@ def simple_upsert_txn(
12221229
keyvalues,
12231230
values,
12241231
insertion_values=insertion_values,
1232+
where_clause=where_clause,
12251233
lock=lock,
12261234
)
12271235

@@ -1232,6 +1240,7 @@ def simple_upsert_txn_emulated(
12321240
keyvalues: Dict[str, Any],
12331241
values: Dict[str, Any],
12341242
insertion_values: Optional[Dict[str, Any]] = None,
1243+
where_clause: Optional[str] = None,
12351244
lock: bool = True,
12361245
) -> bool:
12371246
"""
@@ -1240,6 +1249,7 @@ def simple_upsert_txn_emulated(
12401249
keyvalues: The unique key tables and their new values
12411250
values: The nonunique columns and their new values
12421251
insertion_values: additional key/values to use only when inserting
1252+
where_clause: An index predicate to apply to the upsert.
12431253
lock: True to lock the table when doing the upsert.
12441254
Returns:
12451255
Returns True if a row was inserted or updated (i.e. if `values` is
@@ -1259,14 +1269,17 @@ def _getwhere(key: str) -> str:
12591269
else:
12601270
return "%s = ?" % (key,)
12611271

1272+
# Generate a where clause of each keyvalue and optionally the provided
1273+
# index predicate.
1274+
where = [_getwhere(k) for k in keyvalues]
1275+
if where_clause:
1276+
where.append(where_clause)
1277+
12621278
if not values:
12631279
# If `values` is empty, then all of the values we care about are in
12641280
# the unique key, so there is nothing to UPDATE. We can just do a
12651281
# SELECT instead to see if it exists.
1266-
sql = "SELECT 1 FROM %s WHERE %s" % (
1267-
table,
1268-
" AND ".join(_getwhere(k) for k in keyvalues),
1269-
)
1282+
sql = "SELECT 1 FROM %s WHERE %s" % (table, " AND ".join(where))
12701283
sqlargs = list(keyvalues.values())
12711284
txn.execute(sql, sqlargs)
12721285
if txn.fetchall():
@@ -1277,7 +1290,7 @@ def _getwhere(key: str) -> str:
12771290
sql = "UPDATE %s SET %s WHERE %s" % (
12781291
table,
12791292
", ".join("%s = ?" % (k,) for k in values),
1280-
" AND ".join(_getwhere(k) for k in keyvalues),
1293+
" AND ".join(where),
12811294
)
12821295
sqlargs = list(values.values()) + list(keyvalues.values())
12831296

@@ -1307,6 +1320,7 @@ def simple_upsert_txn_native_upsert(
13071320
keyvalues: Dict[str, Any],
13081321
values: Dict[str, Any],
13091322
insertion_values: Optional[Dict[str, Any]] = None,
1323+
where_clause: Optional[str] = None,
13101324
) -> bool:
13111325
"""
13121326
Use the native UPSERT functionality in PostgreSQL.
@@ -1316,6 +1330,7 @@ def simple_upsert_txn_native_upsert(
13161330
keyvalues: The unique key tables and their new values
13171331
values: The nonunique columns and their new values
13181332
insertion_values: additional key/values to use only when inserting
1333+
where_clause: An index predicate to apply to the upsert.
13191334
13201335
Returns:
13211336
Returns True if a row was inserted or updated (i.e. if `values` is
@@ -1331,11 +1346,12 @@ def simple_upsert_txn_native_upsert(
13311346
allvalues.update(values)
13321347
latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
13331348

1334-
sql = ("INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s") % (
1349+
sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) %s DO %s" % (
13351350
table,
13361351
", ".join(k for k in allvalues),
13371352
", ".join("?" for _ in allvalues),
13381353
", ".join(k for k in keyvalues),
1354+
f"WHERE {where_clause}" if where_clause else "",
13391355
latter,
13401356
)
13411357
txn.execute(sql, list(allvalues.values()))

0 commit comments

Comments
 (0)