Skip to content

Commit ec1f538

Browse files
committed
Add ability to set update values, fixes #56
1 parent 4adc787 commit ec1f538

File tree

6 files changed

+201
-28
lines changed

6 files changed

+201
-28
lines changed

docs/source/conflict_handling.rst

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,42 @@ Alternatively, with Django 3.1 or newer, :class:`~django:django.db.models.Q` obj
197197
Q(name__gt=ExcludedCol('priority'))
198198
199199
200+
Update values
201+
"""""""""""""
202+
203+
Optionally, the fields to update can be overriden. The default is to update the same fields that were specified in the rows to insert.
204+
205+
Refer to the insert values using the :class:`psqlextra.expressions.ExcludedCol` expression which translates to PostgreSQL's ``EXCLUDED.<column>`` expression. All expressions and features that can be used with Django's :meth:`~django:django.db.models.query.QuerySet.update` can be used here.
206+
207+
.. warning::
208+
209+
Specifying an empty ``update_values`` (``{}``) will transform the query into :attr:`~psqlextra.types.ConflictAction.NOTHING`. Only ``None`` makes the default behaviour kick in of updating all fields that were specified.
210+
211+
.. code-block:: python
212+
213+
from django.db.models import F
214+
215+
from psqlextra.expressions import ExcludedCol
216+
217+
(
218+
MyModel
219+
.objects
220+
.on_conflict(
221+
['name'],
222+
ConflictAction.UPDATE,
223+
update_values=dict(
224+
name=ExcludedCol('name'),
225+
count=F('count') + 1,
226+
),
227+
)
228+
.insert(
229+
name='henk',
230+
count=0,
231+
)
232+
)
233+
234+
235+
200236
ConflictAction.NOTHING
201237
**********************
202238

psqlextra/compiler.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -233,13 +233,6 @@ def _rewrite_insert_on_conflict(
233233
"""Rewrites a normal SQL INSERT query to add the 'ON CONFLICT'
234234
clause."""
235235

236-
update_columns = ", ".join(
237-
[
238-
"{0} = EXCLUDED.{0}".format(self.qn(field.column))
239-
for field in self.query.update_fields
240-
]
241-
)
242-
243236
# build the conflict target, the columns to watch
244237
# for conflicts
245238
conflict_target = self._build_conflict_target()
@@ -253,10 +246,21 @@ def _rewrite_insert_on_conflict(
253246
rewritten_sql += f" WHERE {expr_sql}"
254247
params += tuple(expr_params)
255248

249+
# Fallback in case the user didn't specify any update values. We can still
250+
# make the query work if we switch to ConflictAction.NOTHING
251+
if (
252+
conflict_action == ConflictAction.UPDATE.value
253+
and not self.query.update_values
254+
):
255+
conflict_action = ConflictAction.NOTHING
256+
256257
rewritten_sql += f" DO {conflict_action}"
257258

258-
if conflict_action == "UPDATE":
259-
rewritten_sql += f" SET {update_columns}"
259+
if conflict_action == ConflictAction.UPDATE.value:
260+
set_sql, sql_params = self._build_set_statement()
261+
262+
rewritten_sql += f" SET {set_sql}"
263+
params += sql_params
260264

261265
if update_condition:
262266
expr_sql, expr_params = self._compile_expression(
@@ -269,6 +273,23 @@ def _rewrite_insert_on_conflict(
269273

270274
return (rewritten_sql, params)
271275

276+
def _build_set_statement(self) -> Tuple[str, tuple]:
277+
"""Builds the SET statement for the ON CONFLICT DO UPDATE clause.
278+
279+
This uses the update compiler to provide full compatibility with
280+
the standard Django's `update(...)`.
281+
"""
282+
283+
# Local import to work around the circular dependency between
284+
# the compiler and the queries.
285+
from .sql import PostgresUpdateQuery
286+
287+
query = self.query.chain(PostgresUpdateQuery)
288+
query.add_update_values(self.query.update_values)
289+
290+
sql, params = query.get_compiler(self.connection.alias).as_sql()
291+
return sql.split("SET")[1].split(" WHERE")[0], tuple(params)
292+
272293
def _build_conflict_target(self):
273294
"""Builds the `conflict_target` for the ON CONFLICT clause."""
274295

psqlextra/query.py

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from collections import OrderedDict
22
from itertools import chain
3-
from typing import Dict, Iterable, List, Optional, Tuple, Union
3+
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
44

55
from django.core.exceptions import SuspiciousOperation
66
from django.db import connections, models, router
77
from django.db.models import Expression, Q
88
from django.db.models.fields import NOT_PROVIDED
99

10+
from .expressions import ExcludedCol
1011
from .sql import PostgresInsertQuery, PostgresQuery
1112
from .types import ConflictAction
1213

@@ -27,6 +28,7 @@ def __init__(self, model=None, query=None, using=None, hints=None):
2728
self.conflict_action = None
2829
self.conflict_update_condition = None
2930
self.index_predicate = None
31+
self.update_values = None
3032

3133
def annotate(self, **annotations):
3234
"""Custom version of the standard annotate function that allows using
@@ -84,6 +86,7 @@ def on_conflict(
8486
action: ConflictAction,
8587
index_predicate: Optional[Union[Expression, Q, str]] = None,
8688
update_condition: Optional[Union[Expression, Q, str]] = None,
89+
update_values: Optional[Dict[str, Union[Any, Expression]]] = None,
8790
):
8891
"""Sets the action to take when conflicts arise when attempting to
8992
insert/create a new row.
@@ -101,12 +104,18 @@ def on_conflict(
101104
102105
update_condition:
103106
Only update if this SQL expression evaluates to true.
107+
108+
update_values:
109+
Optionally, values/expressions to use when rows
110+
conflict. If not specified, all columns specified
111+
in the rows are updated with the values you specified.
104112
"""
105113

106114
self.conflict_target = fields
107115
self.conflict_action = action
108116
self.conflict_update_condition = update_condition
109117
self.index_predicate = index_predicate
118+
self.update_values = update_values
110119

111120
return self
112121

@@ -260,6 +269,7 @@ def upsert(
260269
index_predicate: Optional[Union[Expression, Q, str]] = None,
261270
using: Optional[str] = None,
262271
update_condition: Optional[Union[Expression, Q, str]] = None,
272+
update_values: Optional[Dict[str, Union[Any, Expression]]] = None,
263273
) -> int:
264274
"""Creates a new record or updates the existing one with the specified
265275
data.
@@ -282,6 +292,11 @@ def upsert(
282292
update_condition:
283293
Only update if this SQL expression evaluates to true.
284294
295+
update_values:
296+
Optionally, values/expressions to use when rows
297+
conflict. If not specified, all columns specified
298+
in the rows are updated with the values you specified.
299+
285300
Returns:
286301
The primary key of the row that was created/updated.
287302
"""
@@ -291,6 +306,7 @@ def upsert(
291306
ConflictAction.UPDATE,
292307
index_predicate=index_predicate,
293308
update_condition=update_condition,
309+
update_values=update_values,
294310
)
295311
return self.insert(**fields, using=using)
296312

@@ -301,6 +317,7 @@ def upsert_and_get(
301317
index_predicate: Optional[Union[Expression, Q, str]] = None,
302318
using: Optional[str] = None,
303319
update_condition: Optional[Union[Expression, Q, str]] = None,
320+
update_values: Optional[Dict[str, Union[Any, Expression]]] = None,
304321
):
305322
"""Creates a new record or updates the existing one with the specified
306323
data and then gets the row.
@@ -323,6 +340,11 @@ def upsert_and_get(
323340
update_condition:
324341
Only update if this SQL expression evaluates to true.
325342
343+
update_values:
344+
Optionally, values/expressions to use when rows
345+
conflict. If not specified, all columns specified
346+
in the rows are updated with the values you specified.
347+
326348
Returns:
327349
The model instance representing the row
328350
that was created/updated.
@@ -333,6 +355,7 @@ def upsert_and_get(
333355
ConflictAction.UPDATE,
334356
index_predicate=index_predicate,
335357
update_condition=update_condition,
358+
update_values=update_values,
336359
)
337360
return self.insert_and_get(**fields, using=using)
338361

@@ -344,6 +367,7 @@ def bulk_upsert(
344367
return_model: bool = False,
345368
using: Optional[str] = None,
346369
update_condition: Optional[Union[Expression, Q, str]] = None,
370+
update_values: Optional[Dict[str, Union[Any, Expression]]] = None,
347371
):
348372
"""Creates a set of new records or updates the existing ones with the
349373
specified data.
@@ -370,6 +394,11 @@ def bulk_upsert(
370394
update_condition:
371395
Only update if this SQL expression evaluates to true.
372396
397+
update_values:
398+
Optionally, values/expressions to use when rows
399+
conflict. If not specified, all columns specified
400+
in the rows are updated with the values you specified.
401+
373402
Returns:
374403
A list of either the dicts of the rows upserted, including the pk or
375404
the models of the rows upserted
@@ -386,7 +415,9 @@ def is_empty(r):
386415
ConflictAction.UPDATE,
387416
index_predicate=index_predicate,
388417
update_condition=update_condition,
418+
update_values=update_values,
389419
)
420+
390421
return self.bulk_insert(rows, return_model, using=using)
391422

392423
def _create_model_instance(
@@ -474,15 +505,19 @@ def _build_insert_compiler(
474505
)
475506

476507
# get the fields to be used during update/insert
477-
insert_fields, update_fields = self._get_upsert_fields(first_row)
508+
insert_fields, update_values = self._get_upsert_fields(first_row)
509+
510+
# allow the user to override what should happen on update
511+
if self.update_values is not None:
512+
update_values = self.update_values
478513

479514
# build a normal insert query
480515
query = PostgresInsertQuery(self.model)
481516
query.conflict_action = self.conflict_action
482517
query.conflict_target = self.conflict_target
483518
query.conflict_update_condition = self.conflict_update_condition
484519
query.index_predicate = self.index_predicate
485-
query.values(objs, insert_fields, update_fields)
520+
query.values(objs, insert_fields, update_values)
486521

487522
compiler = query.get_compiler(using)
488523
return compiler
@@ -547,13 +582,13 @@ def _get_upsert_fields(self, kwargs):
547582

548583
model_instance = self.model(**kwargs)
549584
insert_fields = []
550-
update_fields = []
585+
update_values = {}
551586

552587
for field in model_instance._meta.local_concrete_fields:
553588
has_default = field.default != NOT_PROVIDED
554589
if field.name in kwargs or field.column in kwargs:
555590
insert_fields.append(field)
556-
update_fields.append(field)
591+
update_values[field.name] = ExcludedCol(field.column)
557592
continue
558593
elif has_default:
559594
insert_fields.append(field)
@@ -564,13 +599,13 @@ def _get_upsert_fields(self, kwargs):
564599
# instead of a concrete field, we have to handle that
565600
if field.primary_key is True and "pk" in kwargs:
566601
insert_fields.append(field)
567-
update_fields.append(field)
602+
update_values[field.name] = ExcludedCol(field.column)
568603
continue
569604

570605
if self._is_magical_field(model_instance, field, is_insert=True):
571606
insert_fields.append(field)
572607

573608
if self._is_magical_field(model_instance, field, is_insert=False):
574-
update_fields.append(field)
609+
update_values[field.name] = ExcludedCol(field.column)
575610

576-
return insert_fields, update_fields
611+
return insert_fields, update_values

psqlextra/sql.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from collections import OrderedDict
2-
from typing import List, Optional, Tuple
2+
from typing import Any, Dict, List, Optional, Tuple, Union
33

44
import django
55

66
from django.core.exceptions import SuspiciousOperation
77
from django.db import connections, models
8-
from django.db.models import sql
8+
from django.db.models import Expression, sql
99
from django.db.models.constants import LOOKUP_SEP
1010

1111
from .compiler import PostgresInsertOnConflictCompiler
@@ -148,10 +148,14 @@ def __init__(self, *args, **kwargs):
148148
self.conflict_action = ConflictAction.UPDATE
149149
self.conflict_update_condition = None
150150
self.index_predicate = None
151-
152-
self.update_fields = []
153-
154-
def values(self, objs: List, insert_fields: List, update_fields: List = []):
151+
self.update_values = {}
152+
153+
def values(
154+
self,
155+
objs: List,
156+
insert_fields: List,
157+
update_values: Dict[str, Union[Any, Expression]] = [],
158+
):
155159
"""Sets the values to be used in this query.
156160
157161
Insert fields are fields that are definitely
@@ -170,12 +174,13 @@ def values(self, objs: List, insert_fields: List, update_fields: List = []):
170174
insert_fields:
171175
The fields to use in the INSERT statement
172176
173-
update_fields:
174-
The fields to only use in the UPDATE statement.
177+
update_values:
178+
Expressions/values to use when a conflict
179+
occurs and an UPDATE is performed.
175180
"""
176181

177182
self.insert_values(insert_fields, objs, raw=False)
178-
self.update_fields = update_fields
183+
self.update_values = update_values
179184

180185
def get_compiler(self, using=None, connection=None):
181186
if using:

psqlextra/types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ class ConflictAction(Enum):
2828
def all(cls) -> List["ConflictAction"]:
2929
return [choice for choice in cls]
3030

31+
def __str__(self) -> str:
32+
return self.value
33+
3134

3235
class PostgresPartitioningMethod(StrEnum):
3336
"""Methods of partitioning supported by PostgreSQL 11.x native support for

0 commit comments

Comments
 (0)