Skip to content

Commit

Permalink
Modify bulk update default value (#341)
Browse files Browse the repository at this point in the history
* Modify bulk update default value

* Patch bulk_update_with_default and add tests

* Remove unused imports

* Change case from default to None
  • Loading branch information
dauinsight authored Jan 31, 2024
1 parent 0060aa8 commit 4df6db3
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 4 deletions.
7 changes: 4 additions & 3 deletions mssql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from django.core import validators
from django.db import NotSupportedError, connections, transaction
from django.db.models import BooleanField, CheckConstraint, Value
from django.db.models.expressions import Case, Exists, Expression, OrderBy, When, Window
from django.db.models.expressions import Case, Exists, OrderBy, When, Window
from django.db.models.fields import BinaryField, Field
from django.db.models.functions import Cast, NthValue, MD5, SHA1, SHA224, SHA256, SHA384, SHA512
from django.db.models.functions.datetime import Now
Expand Down Expand Up @@ -294,7 +294,7 @@ def _get_check_sql(self, model, schema_editor):
return sql % tuple(schema_editor.quote_value(p) for p in params)


def bulk_update_with_default(self, objs, fields, batch_size=None, default=0):
def bulk_update_with_default(self, objs, fields, batch_size=None, default=None):
"""
Update the given fields in each of the given objects in the database.
Expand Down Expand Up @@ -343,7 +343,8 @@ def bulk_update_with_default(self, objs, fields, batch_size=None, default=0):
attr = Value(attr, output_field=field)
when_statements.append(When(pk=obj.pk, then=attr))
if connection.vendor == 'microsoft' and value_none_counter == len(when_statements):
case_statement = Case(*when_statements, output_field=field, default=Value(default))
# We don't need a case statement if we are setting everything to None
case_statement = Value(None)
else:
case_statement = Case(*when_statements, output_field=field)
if requires_casting:
Expand Down
22 changes: 22 additions & 0 deletions testapp/migrations/0025_modelwithnullablefieldsofdifferenttypes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Generated by Django 5.0.1 on 2024-01-29 14:18

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('testapp', '0024_publisher_book'),
]

operations = [
migrations.CreateModel(
name='ModelWithNullableFieldsOfDifferentTypes',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('int_value', models.IntegerField(null=True)),
('name', models.CharField(max_length=100, null=True)),
('date', models.DateTimeField(null=True)),
],
),
]
6 changes: 6 additions & 0 deletions testapp/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ class UUIDModel(models.Model):
def __str__(self):
return self.pk

class ModelWithNullableFieldsOfDifferentTypes(models.Model):
# Issue https://github.com/microsoft/mssql-django/issues/340
# Ensures the integrity of bulk updates with different types
int_value = models.IntegerField(null=True)
name = models.CharField(max_length=100, null=True)
date = models.DateTimeField(null=True)

class TestUniqueNullableModel(models.Model):
# Issue https://github.com/ESSolutions/django-mssql-backend/issues/38:
Expand Down
22 changes: 21 additions & 1 deletion testapp/tests/test_expressions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the BSD license.

import datetime
from unittest import skipUnless

from django import VERSION
Expand All @@ -9,7 +10,8 @@
from django.test import TestCase, skipUnlessDBFeature

from django.db.models.aggregates import Count
from ..models import Author, Comment, Post, Editor
from ..models import Author, Comment, Post, Editor, ModelWithNullableFieldsOfDifferentTypes


DJANGO3 = VERSION[0] >= 3

Expand Down Expand Up @@ -103,3 +105,21 @@ def test_order_by_nulls_first(self):
self.assertEqual(len(results), 2)
self.assertIsNone(results[0].alt_editor)
self.assertIsNotNone(results[1].alt_editor)

class TestBulkUpdate(TestCase):
def test_bulk_update_different_column_types(self):
data = (
(1, 'a', datetime.datetime(year=2024, month=1, day=1)),
(2, 'b', datetime.datetime(year=2023, month=12, day=31))
)
objs = ModelWithNullableFieldsOfDifferentTypes.objects.bulk_create(ModelWithNullableFieldsOfDifferentTypes(int_value=row_data[0],
name=row_data[1],
date=row_data[2]) for row_data in data)
for obj in objs:
obj.int_value = None
obj.name = None
obj.date = None
ModelWithNullableFieldsOfDifferentTypes.objects.bulk_update(objs, ["int_value", "name", "date"])
self.assertCountEqual(ModelWithNullableFieldsOfDifferentTypes.objects.filter(int_value__isnull=True), objs)
self.assertCountEqual(ModelWithNullableFieldsOfDifferentTypes.objects.filter(name__isnull=True), objs)
self.assertCountEqual(ModelWithNullableFieldsOfDifferentTypes.objects.filter(date__isnull=True), objs)

0 comments on commit 4df6db3

Please sign in to comment.