1
1
from collections import OrderedDict
2
2
from itertools import chain
3
- from typing import Dict , Iterable , List , Optional , Tuple , Union
3
+ from typing import Any , Dict , Iterable , List , Optional , Tuple , Union
4
4
5
5
from django .core .exceptions import SuspiciousOperation
6
6
from django .db import connections , models , router
7
7
from django .db .models import Expression , Q
8
8
from django .db .models .fields import NOT_PROVIDED
9
9
10
+ from .expressions import ExcludedCol
10
11
from .sql import PostgresInsertQuery , PostgresQuery
11
12
from .types import ConflictAction
12
13
@@ -27,6 +28,7 @@ def __init__(self, model=None, query=None, using=None, hints=None):
27
28
self .conflict_action = None
28
29
self .conflict_update_condition = None
29
30
self .index_predicate = None
31
+ self .update_values = None
30
32
31
33
def annotate (self , ** annotations ):
32
34
"""Custom version of the standard annotate function that allows using
@@ -84,6 +86,7 @@ def on_conflict(
84
86
action : ConflictAction ,
85
87
index_predicate : Optional [Union [Expression , Q , str ]] = None ,
86
88
update_condition : Optional [Union [Expression , Q , str ]] = None ,
89
+ update_values : Optional [Dict [str , Union [Any , Expression ]]] = None ,
87
90
):
88
91
"""Sets the action to take when conflicts arise when attempting to
89
92
insert/create a new row.
@@ -101,12 +104,18 @@ def on_conflict(
101
104
102
105
update_condition:
103
106
Only update if this SQL expression evaluates to true.
107
+
108
+ update_values:
109
+ Optionally, values/expressions to use when rows
110
+ conflict. If not specified, all columns specified
111
+ in the rows are updated with the values you specified.
104
112
"""
105
113
106
114
self .conflict_target = fields
107
115
self .conflict_action = action
108
116
self .conflict_update_condition = update_condition
109
117
self .index_predicate = index_predicate
118
+ self .update_values = update_values
110
119
111
120
return self
112
121
@@ -260,6 +269,7 @@ def upsert(
260
269
index_predicate : Optional [Union [Expression , Q , str ]] = None ,
261
270
using : Optional [str ] = None ,
262
271
update_condition : Optional [Union [Expression , Q , str ]] = None ,
272
+ update_values : Optional [Dict [str , Union [Any , Expression ]]] = None ,
263
273
) -> int :
264
274
"""Creates a new record or updates the existing one with the specified
265
275
data.
@@ -282,6 +292,11 @@ def upsert(
282
292
update_condition:
283
293
Only update if this SQL expression evaluates to true.
284
294
295
+ update_values:
296
+ Optionally, values/expressions to use when rows
297
+ conflict. If not specified, all columns specified
298
+ in the rows are updated with the values you specified.
299
+
285
300
Returns:
286
301
The primary key of the row that was created/updated.
287
302
"""
@@ -291,6 +306,7 @@ def upsert(
291
306
ConflictAction .UPDATE ,
292
307
index_predicate = index_predicate ,
293
308
update_condition = update_condition ,
309
+ update_values = update_values ,
294
310
)
295
311
return self .insert (** fields , using = using )
296
312
@@ -301,6 +317,7 @@ def upsert_and_get(
301
317
index_predicate : Optional [Union [Expression , Q , str ]] = None ,
302
318
using : Optional [str ] = None ,
303
319
update_condition : Optional [Union [Expression , Q , str ]] = None ,
320
+ update_values : Optional [Dict [str , Union [Any , Expression ]]] = None ,
304
321
):
305
322
"""Creates a new record or updates the existing one with the specified
306
323
data and then gets the row.
@@ -323,6 +340,11 @@ def upsert_and_get(
323
340
update_condition:
324
341
Only update if this SQL expression evaluates to true.
325
342
343
+ update_values:
344
+ Optionally, values/expressions to use when rows
345
+ conflict. If not specified, all columns specified
346
+ in the rows are updated with the values you specified.
347
+
326
348
Returns:
327
349
The model instance representing the row
328
350
that was created/updated.
@@ -333,6 +355,7 @@ def upsert_and_get(
333
355
ConflictAction .UPDATE ,
334
356
index_predicate = index_predicate ,
335
357
update_condition = update_condition ,
358
+ update_values = update_values ,
336
359
)
337
360
return self .insert_and_get (** fields , using = using )
338
361
@@ -344,6 +367,7 @@ def bulk_upsert(
344
367
return_model : bool = False ,
345
368
using : Optional [str ] = None ,
346
369
update_condition : Optional [Union [Expression , Q , str ]] = None ,
370
+ update_values : Optional [Dict [str , Union [Any , Expression ]]] = None ,
347
371
):
348
372
"""Creates a set of new records or updates the existing ones with the
349
373
specified data.
@@ -370,6 +394,11 @@ def bulk_upsert(
370
394
update_condition:
371
395
Only update if this SQL expression evaluates to true.
372
396
397
+ update_values:
398
+ Optionally, values/expressions to use when rows
399
+ conflict. If not specified, all columns specified
400
+ in the rows are updated with the values you specified.
401
+
373
402
Returns:
374
403
A list of either the dicts of the rows upserted, including the pk or
375
404
the models of the rows upserted
@@ -386,7 +415,9 @@ def is_empty(r):
386
415
ConflictAction .UPDATE ,
387
416
index_predicate = index_predicate ,
388
417
update_condition = update_condition ,
418
+ update_values = update_values ,
389
419
)
420
+
390
421
return self .bulk_insert (rows , return_model , using = using )
391
422
392
423
def _create_model_instance (
@@ -474,15 +505,19 @@ def _build_insert_compiler(
474
505
)
475
506
476
507
# get the fields to be used during update/insert
477
- insert_fields , update_fields = self ._get_upsert_fields (first_row )
508
+ insert_fields , update_values = self ._get_upsert_fields (first_row )
509
+
510
+ # allow the user to override what should happen on update
511
+ if self .update_values is not None :
512
+ update_values = self .update_values
478
513
479
514
# build a normal insert query
480
515
query = PostgresInsertQuery (self .model )
481
516
query .conflict_action = self .conflict_action
482
517
query .conflict_target = self .conflict_target
483
518
query .conflict_update_condition = self .conflict_update_condition
484
519
query .index_predicate = self .index_predicate
485
- query .values (objs , insert_fields , update_fields )
520
+ query .values (objs , insert_fields , update_values )
486
521
487
522
compiler = query .get_compiler (using )
488
523
return compiler
@@ -547,13 +582,13 @@ def _get_upsert_fields(self, kwargs):
547
582
548
583
model_instance = self .model (** kwargs )
549
584
insert_fields = []
550
- update_fields = []
585
+ update_values = {}
551
586
552
587
for field in model_instance ._meta .local_concrete_fields :
553
588
has_default = field .default != NOT_PROVIDED
554
589
if field .name in kwargs or field .column in kwargs :
555
590
insert_fields .append (field )
556
- update_fields . append (field )
591
+ update_values [ field . name ] = ExcludedCol (field . column )
557
592
continue
558
593
elif has_default :
559
594
insert_fields .append (field )
@@ -564,13 +599,13 @@ def _get_upsert_fields(self, kwargs):
564
599
# instead of a concrete field, we have to handle that
565
600
if field .primary_key is True and "pk" in kwargs :
566
601
insert_fields .append (field )
567
- update_fields . append (field )
602
+ update_values [ field . name ] = ExcludedCol (field . column )
568
603
continue
569
604
570
605
if self ._is_magical_field (model_instance , field , is_insert = True ):
571
606
insert_fields .append (field )
572
607
573
608
if self ._is_magical_field (model_instance , field , is_insert = False ):
574
- update_fields . append (field )
609
+ update_values [ field . name ] = ExcludedCol (field . column )
575
610
576
- return insert_fields , update_fields
611
+ return insert_fields , update_values
0 commit comments