@@ -1191,6 +1191,7 @@ def simple_upsert_txn(
1191
1191
keyvalues : Dict [str , Any ],
1192
1192
values : Dict [str , Any ],
1193
1193
insertion_values : Optional [Dict [str , Any ]] = None ,
1194
+ where_clause : Optional [str ] = None ,
1194
1195
lock : bool = True ,
1195
1196
) -> bool :
1196
1197
"""
@@ -1203,6 +1204,7 @@ def simple_upsert_txn(
1203
1204
keyvalues: The unique key tables and their new values
1204
1205
values: The nonunique columns and their new values
1205
1206
insertion_values: additional key/values to use only when inserting
1207
+ where_clause: An index predicate to apply to the upsert.
1206
1208
lock: True to lock the table when doing the upsert. Unused when performing
1207
1209
a native upsert.
1208
1210
Returns:
@@ -1213,7 +1215,12 @@ def simple_upsert_txn(
1213
1215
1214
1216
if table not in self ._unsafe_to_upsert_tables :
1215
1217
return self .simple_upsert_txn_native_upsert (
1216
- txn , table , keyvalues , values , insertion_values = insertion_values
1218
+ txn ,
1219
+ table ,
1220
+ keyvalues ,
1221
+ values ,
1222
+ insertion_values = insertion_values ,
1223
+ where_clause = where_clause ,
1217
1224
)
1218
1225
else :
1219
1226
return self .simple_upsert_txn_emulated (
@@ -1222,6 +1229,7 @@ def simple_upsert_txn(
1222
1229
keyvalues ,
1223
1230
values ,
1224
1231
insertion_values = insertion_values ,
1232
+ where_clause = where_clause ,
1225
1233
lock = lock ,
1226
1234
)
1227
1235
@@ -1232,6 +1240,7 @@ def simple_upsert_txn_emulated(
1232
1240
keyvalues : Dict [str , Any ],
1233
1241
values : Dict [str , Any ],
1234
1242
insertion_values : Optional [Dict [str , Any ]] = None ,
1243
+ where_clause : Optional [str ] = None ,
1235
1244
lock : bool = True ,
1236
1245
) -> bool :
1237
1246
"""
@@ -1240,6 +1249,7 @@ def simple_upsert_txn_emulated(
1240
1249
keyvalues: The unique key tables and their new values
1241
1250
values: The nonunique columns and their new values
1242
1251
insertion_values: additional key/values to use only when inserting
1252
+ where_clause: An index predicate to apply to the upsert.
1243
1253
lock: True to lock the table when doing the upsert.
1244
1254
Returns:
1245
1255
Returns True if a row was inserted or updated (i.e. if `values` is
@@ -1259,14 +1269,17 @@ def _getwhere(key: str) -> str:
1259
1269
else :
1260
1270
return "%s = ?" % (key ,)
1261
1271
1272
+ # Generate a where clause of each keyvalue and optionally the provided
1273
+ # index predicate.
1274
+ where = [_getwhere (k ) for k in keyvalues ]
1275
+ if where_clause :
1276
+ where .append (where_clause )
1277
+
1262
1278
if not values :
1263
1279
# If `values` is empty, then all of the values we care about are in
1264
1280
# the unique key, so there is nothing to UPDATE. We can just do a
1265
1281
# SELECT instead to see if it exists.
1266
- sql = "SELECT 1 FROM %s WHERE %s" % (
1267
- table ,
1268
- " AND " .join (_getwhere (k ) for k in keyvalues ),
1269
- )
1282
+ sql = "SELECT 1 FROM %s WHERE %s" % (table , " AND " .join (where ))
1270
1283
sqlargs = list (keyvalues .values ())
1271
1284
txn .execute (sql , sqlargs )
1272
1285
if txn .fetchall ():
@@ -1277,7 +1290,7 @@ def _getwhere(key: str) -> str:
1277
1290
sql = "UPDATE %s SET %s WHERE %s" % (
1278
1291
table ,
1279
1292
", " .join ("%s = ?" % (k ,) for k in values ),
1280
- " AND " .join (_getwhere ( k ) for k in keyvalues ),
1293
+ " AND " .join (where ),
1281
1294
)
1282
1295
sqlargs = list (values .values ()) + list (keyvalues .values ())
1283
1296
@@ -1307,6 +1320,7 @@ def simple_upsert_txn_native_upsert(
1307
1320
keyvalues : Dict [str , Any ],
1308
1321
values : Dict [str , Any ],
1309
1322
insertion_values : Optional [Dict [str , Any ]] = None ,
1323
+ where_clause : Optional [str ] = None ,
1310
1324
) -> bool :
1311
1325
"""
1312
1326
Use the native UPSERT functionality in PostgreSQL.
@@ -1316,6 +1330,7 @@ def simple_upsert_txn_native_upsert(
1316
1330
keyvalues: The unique key tables and their new values
1317
1331
values: The nonunique columns and their new values
1318
1332
insertion_values: additional key/values to use only when inserting
1333
+ where_clause: An index predicate to apply to the upsert.
1319
1334
1320
1335
Returns:
1321
1336
Returns True if a row was inserted or updated (i.e. if `values` is
@@ -1331,11 +1346,12 @@ def simple_upsert_txn_native_upsert(
1331
1346
allvalues .update (values )
1332
1347
latter = "UPDATE SET " + ", " .join (k + "=EXCLUDED." + k for k in values )
1333
1348
1334
- sql = ( "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" ) % (
1349
+ sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) %s DO %s" % (
1335
1350
table ,
1336
1351
", " .join (k for k in allvalues ),
1337
1352
", " .join ("?" for _ in allvalues ),
1338
1353
", " .join (k for k in keyvalues ),
1354
+ f"WHERE { where_clause } " if where_clause else "" ,
1339
1355
latter ,
1340
1356
)
1341
1357
txn .execute (sql , list (allvalues .values ()))
0 commit comments