Skip to content

Commit 2e60754

Browse files
authored
Add support for nullable fields in unique_together (ESSolutions#24)
1 parent 8bf0154 commit 2e60754

File tree

7 files changed

+254
-25
lines changed

7 files changed

+254
-25
lines changed

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[flake8]
2-
exclude = .git,__pycache__,
2+
exclude = .git,__pycache__,migrations
33
# W504 is mutually exclusive with W503
44
ignore = W504
55
max-line-length = 119

sql_server/pyodbc/base.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -480,17 +480,13 @@ def check_constraints(self, table_names=None):
480480
table_names)
481481

482482
def disable_constraint_checking(self):
483-
# Azure SQL Database doesn't support sp_msforeachtable
484-
# cursor.execute('EXEC sp_msforeachtable "ALTER TABLE ? NOCHECK CONSTRAINT ALL"')
485483
if not self.needs_rollback:
486-
self._execute_foreach('ALTER TABLE %s NOCHECK CONSTRAINT ALL')
484+
self.cursor().execute('EXEC sp_msforeachtable "ALTER TABLE ? NOCHECK CONSTRAINT ALL"')
487485
return not self.needs_rollback
488486

489487
def enable_constraint_checking(self):
490-
# Azure SQL Database doesn't support sp_msforeachtable
491-
# cursor.execute('EXEC sp_msforeachtable "ALTER TABLE ? WITH CHECK CHECK CONSTRAINT ALL"')
492488
if not self.needs_rollback:
493-
self.check_constraints()
489+
self.cursor().execute('EXEC sp_msforeachtable "ALTER TABLE ? WITH CHECK CHECK CONSTRAINT ALL"')
494490

495491

496492
class CursorWrapper(object):

sql_server/pyodbc/features.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ class DatabaseFeatures(BaseDatabaseFeatures):
2323
supports_ignore_conflicts = False
2424
supports_index_on_text_field = False
2525
supports_paramstyle_pyformat = False
26-
supports_partially_nullable_unique_constraints = False
2726
supports_regex_backreferencing = False
2827
supports_sequence_reset = False
2928
supports_subqueries_in_group_by = False
@@ -41,6 +40,10 @@ def has_bulk_insert(self):
4140
def supports_nullable_unique_constraints(self):
4241
return self.connection.sql_server_version > 2005
4342

43+
@cached_property
44+
def supports_partially_nullable_unique_constraints(self):
45+
return self.connection.sql_server_version > 2005
46+
4447
@cached_property
4548
def supports_partial_indexes(self):
4649
return self.connection.sql_server_version > 2005

sql_server/pyodbc/schema.py

Lines changed: 183 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,22 @@
55
BaseDatabaseSchemaEditor, logger, _is_relevant_relation, _related_non_m2m_objects,
66
)
77
from django.db.backends.ddl_references import (
8-
Statement,
8+
Columns, IndexName, Statement as DjStatement, Table,
99
)
1010
from django.db.models import Index
1111
from django.db.models.fields import AutoField, BigAutoField
1212
from django.db.transaction import TransactionManagementError
1313
from django.utils.encoding import force_text
1414

1515

16+
class Statement(DjStatement):
17+
def __hash__(self):
18+
return hash((self.template, str(self.parts['name'])))
19+
20+
def __eq__(self, other):
21+
return self.template == other.template and str(self.parts['name']) == str(other.parts['name'])
22+
23+
1624
class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
1725

1826
_sql_check_constraint = " CONSTRAINT %(name)s CHECK (%(check)s)"
@@ -123,6 +131,105 @@ def _alter_column_type_sql(self, model, old_field, new_field, new_type):
123131
new_type = self._set_field_new_type_null_status(old_field, new_type)
124132
return super()._alter_column_type_sql(model, old_field, new_field, new_type)
125133

134+
def alter_unique_together(self, model, old_unique_together, new_unique_together):
135+
"""
136+
Deal with a model changing its unique_together. The input
137+
unique_togethers must be doubly-nested, not the single-nested
138+
["foo", "bar"] format.
139+
"""
140+
olds = {tuple(fields) for fields in old_unique_together}
141+
news = {tuple(fields) for fields in new_unique_together}
142+
# Deleted uniques
143+
for fields in olds.difference(news):
144+
self._delete_composed_index(model, fields, {'unique': True}, self.sql_delete_index)
145+
# Created uniques
146+
for fields in news.difference(olds):
147+
columns = [model._meta.get_field(field).column for field in fields]
148+
condition = ' AND '.join(["[%s] IS NOT NULL" % col for col in columns])
149+
sql = self._create_unique_sql(model, columns, condition=condition)
150+
self.execute(sql)
151+
152+
def _model_indexes_sql(self, model):
153+
"""
154+
Return a list of all index SQL statements (field indexes,
155+
index_together, Meta.indexes) for the specified model.
156+
"""
157+
if not model._meta.managed or model._meta.proxy or model._meta.swapped:
158+
return []
159+
output = []
160+
for field in model._meta.local_fields:
161+
output.extend(self._field_indexes_sql(model, field))
162+
163+
for field_names in model._meta.index_together:
164+
fields = [model._meta.get_field(field) for field in field_names]
165+
output.append(self._create_index_sql(model, fields, suffix="_idx"))
166+
167+
for field_names in model._meta.unique_together:
168+
columns = [model._meta.get_field(field).column for field in field_names]
169+
condition = ' AND '.join(["[%s] IS NOT NULL" % col for col in columns])
170+
sql = self._create_unique_sql(model, columns, condition=condition)
171+
output.append(sql)
172+
173+
for index in model._meta.indexes:
174+
output.append(index.create_sql(model, self))
175+
return output
176+
177+
def _alter_many_to_many(self, model, old_field, new_field, strict):
178+
"""Alter M2Ms to repoint their to= endpoints."""
179+
180+
for idx in self._constraint_names(old_field.remote_field.through, index=True, unique=True):
181+
self.execute(self.sql_delete_index % {'name': idx, 'table': old_field.remote_field.through._meta.db_table})
182+
183+
return super()._alter_many_to_many(model, old_field, new_field, strict)
184+
185+
def _db_table_constraint_names(self, db_table, column_names=None, unique=None,
186+
primary_key=None, index=None, foreign_key=None,
187+
check=None, type_=None, exclude=None):
188+
"""Return all constraint names matching the columns and conditions."""
189+
if column_names is not None:
190+
column_names = [
191+
self.connection.introspection.identifier_converter(name)
192+
for name in column_names
193+
]
194+
with self.connection.cursor() as cursor:
195+
constraints = self.connection.introspection.get_constraints(cursor, db_table)
196+
result = []
197+
for name, infodict in constraints.items():
198+
if column_names is None or column_names == infodict['columns']:
199+
if unique is not None and infodict['unique'] != unique:
200+
continue
201+
if primary_key is not None and infodict['primary_key'] != primary_key:
202+
continue
203+
if index is not None and infodict['index'] != index:
204+
continue
205+
if check is not None and infodict['check'] != check:
206+
continue
207+
if foreign_key is not None and not infodict['foreign_key']:
208+
continue
209+
if type_ is not None and infodict['type'] != type_:
210+
continue
211+
if not exclude or name not in exclude:
212+
result.append(name)
213+
return result
214+
215+
def _db_table_delete_constraint_sql(self, template, db_table, name):
216+
return Statement(
217+
template,
218+
table=Table(db_table, self.quote_name),
219+
name=self.quote_name(name),
220+
)
221+
222+
def alter_db_table(self, model, old_db_table, new_db_table):
223+
index_names = self._db_table_constraint_names(old_db_table, index=True)
224+
for index_name in index_names:
225+
self.execute(self._db_table_delete_constraint_sql(self.sql_delete_index, old_db_table, index_name))
226+
227+
index_names = self._db_table_constraint_names(new_db_table, index=True)
228+
for index_name in index_names:
229+
self.execute(self._db_table_delete_constraint_sql(self.sql_delete_index, new_db_table, index_name))
230+
231+
return super().alter_db_table(model, old_db_table, new_db_table)
232+
126233
def _alter_field(self, model, old_field, new_field, old_type, new_type,
127234
old_db_params, new_db_params, strict=False):
128235
"""Actually perform a "physical" (non-ManyToMany) field update."""
@@ -224,11 +331,15 @@ def _alter_field(self, model, old_field, new_field, old_type, new_type,
224331
self.execute(self._delete_constraint_sql(self.sql_delete_check, model, constraint_name))
225332
# Have they renamed the column?
226333
if old_field.column != new_field.column:
334+
# remove old indices
335+
self._delete_indexes(model, old_field, new_field)
336+
227337
self.execute(self._rename_field_sql(model._meta.db_table, old_field, new_field, new_type))
228338
# Rename all references to the renamed column.
229339
for sql in self.deferred_sql:
230-
if isinstance(sql, Statement):
340+
if isinstance(sql, DjStatement):
231341
sql.rename_column_references(model._meta.db_table, old_field.column, new_field.column)
342+
232343
# Next, start accumulating actions to do
233344
actions = []
234345
null_actions = []
@@ -286,6 +397,7 @@ def _alter_field(self, model, old_field, new_field, old_type, new_type,
286397
actions = [(", ".join(sql), sum(params, []))]
287398
# Apply those actions
288399
for sql, params in actions:
400+
self._delete_indexes(model, old_field, new_field)
289401
self.execute(
290402
self.sql_alter_column % {
291403
"table": self.quote_name(model._meta.db_table),
@@ -438,6 +550,7 @@ def _alter_field(self, model, old_field, new_field, old_type, new_type,
438550
"changes": changes_sql,
439551
}
440552
self.execute(sql, params)
553+
441554
# Reset connection if required
442555
if self.connection.features.connection_persists_old_columns:
443556
self.connection.close()
@@ -446,11 +559,15 @@ def _delete_indexes(self, model, old_field, new_field):
446559
index_columns = []
447560
if old_field.db_index and new_field.db_index:
448561
index_columns.append([old_field.column])
449-
else:
450-
for fields in model._meta.index_together:
451-
columns = [model._meta.get_field(field).column for field in fields]
452-
if old_field.column in columns:
453-
index_columns.append(columns)
562+
for fields in model._meta.index_together:
563+
columns = [model._meta.get_field(field).column for field in fields]
564+
if old_field.column in columns:
565+
index_columns.append(columns)
566+
567+
for fields in model._meta.unique_together:
568+
columns = [model._meta.get_field(field).column for field in fields]
569+
if old_field.column in columns:
570+
index_columns.append(columns)
454571
if index_columns:
455572
for columns in index_columns:
456573
index_names = self._constraint_names(model, columns, index=True)
@@ -461,11 +578,6 @@ def _delete_unique_constraints(self, model, old_field, new_field, strict=False):
461578
unique_columns = []
462579
if old_field.unique and new_field.unique:
463580
unique_columns.append([old_field.column])
464-
else:
465-
for fields in model._meta.unique_together:
466-
columns = [model._meta.get_field(field).column for field in fields]
467-
if old_field.column in columns:
468-
unique_columns.append(columns)
469581
if unique_columns:
470582
for columns in unique_columns:
471583
constraint_names = self._constraint_names(model, columns, unique=True)
@@ -544,6 +656,61 @@ def add_field(self, model, field):
544656
if self.connection.features.connection_persists_old_columns:
545657
self.connection.close()
546658

659+
def _create_unique_sql(self, model, columns, name=None, condition=None):
660+
def create_unique_name(*args, **kwargs):
661+
return self.quote_name(self._create_index_name(*args, **kwargs))
662+
663+
table = Table(model._meta.db_table, self.quote_name)
664+
if name is None:
665+
name = IndexName(model._meta.db_table, columns, '_uniq', create_unique_name)
666+
else:
667+
name = self.quote_name(name)
668+
columns = Columns(table, columns, self.quote_name)
669+
if condition:
670+
return Statement(
671+
self.sql_create_unique_index,
672+
table=table,
673+
name=name,
674+
columns=columns,
675+
condition=' WHERE ' + condition,
676+
) if self.connection.features.supports_partial_indexes else None
677+
else:
678+
return Statement(
679+
self.sql_create_unique,
680+
table=table,
681+
name=name,
682+
columns=columns,
683+
)
684+
685+
def _create_index_sql(self, model, fields, *, name=None, suffix='', using='',
686+
db_tablespace=None, col_suffixes=(), sql=None, opclasses=(),
687+
condition=None):
688+
"""
689+
Return the SQL statement to create the index for one or several fields.
690+
`sql` can be specified if the syntax differs from the standard (GIS
691+
indexes, ...).
692+
"""
693+
tablespace_sql = self._get_index_tablespace_sql(model, fields, db_tablespace=db_tablespace)
694+
columns = [field.column for field in fields]
695+
sql_create_index = sql or self.sql_create_index
696+
table = model._meta.db_table
697+
698+
def create_index_name(*args, **kwargs):
699+
nonlocal name
700+
if name is None:
701+
name = self._create_index_name(*args, **kwargs)
702+
return self.quote_name(name)
703+
704+
return Statement(
705+
sql_create_index,
706+
table=Table(table, self.quote_name),
707+
name=IndexName(table, columns, suffix, create_index_name),
708+
using=using,
709+
columns=self._index_columns(table, columns, col_suffixes, opclasses),
710+
extra=tablespace_sql,
711+
condition=(' WHERE ' + condition) if condition else '',
712+
)
713+
547714
def create_model(self, model):
548715
"""
549716
Takes a model and creates a table for it in the database.
@@ -605,7 +772,9 @@ def create_model(self, model):
605772
# created afterwards, like geometry fields with some backends)
606773
for fields in model._meta.unique_together:
607774
columns = [model._meta.get_field(field).column for field in fields]
608-
self.deferred_sql.append(self._create_unique_sql(model, columns))
775+
condition = ' AND '.join(["[%s] IS NOT NULL" % col for col in columns])
776+
self.deferred_sql.append(self._create_unique_sql(model, columns, condition=condition))
777+
609778
# Make the table
610779
sql = self.sql_create_table % {
611780
"table": self.quote_name(model._meta.db_table),
@@ -620,6 +789,7 @@ def create_model(self, model):
620789

621790
# Add any field index and index_together's (deferred as SQLite3 _remake_table needs it)
622791
self.deferred_sql.extend(self._model_indexes_sql(model))
792+
self.deferred_sql = list(set(self.deferred_sql))
623793

624794
# Make M2M tables
625795
for field in model._meta.local_many_to_many:

testapp/migrations/0001_initial.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,41 @@ class Migration(migrations.Migration):
1414
]
1515

1616
operations = [
17+
migrations.CreateModel(
18+
name='Author',
19+
fields=[
20+
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
21+
('name', models.CharField(max_length=100)),
22+
],
23+
),
24+
migrations.CreateModel(
25+
name='Editor',
26+
fields=[
27+
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
28+
('name', models.CharField(max_length=100)),
29+
],
30+
),
1731
migrations.CreateModel(
1832
name='Post',
1933
fields=[
2034
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
2135
('title', models.CharField(max_length=255, verbose_name='title')),
2236
],
2337
),
38+
migrations.AddField(
39+
model_name='post',
40+
name='alt_editor',
41+
field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to='testapp.Editor'),
42+
),
43+
migrations.AddField(
44+
model_name='post',
45+
name='author',
46+
field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='testapp.Author'),
47+
),
48+
migrations.AlterUniqueTogether(
49+
name='post',
50+
unique_together={('author', 'title', 'alt_editor')},
51+
),
2452
migrations.CreateModel(
2553
name='Comment',
2654
fields=[

testapp/models.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,24 @@
44
from django.utils import timezone
55

66

7+
class Author(models.Model):
8+
name = models.CharField(max_length=100)
9+
10+
11+
class Editor(models.Model):
12+
name = models.CharField(max_length=100)
13+
14+
715
class Post(models.Model):
816
title = models.CharField('title', max_length=255)
17+
author = models.ForeignKey(Author, models.CASCADE)
18+
# Optional secondary author
19+
alt_editor = models.ForeignKey(Editor, models.SET_NULL, blank=True, null=True)
20+
21+
class Meta:
22+
unique_together = (
23+
('author', 'title', 'alt_editor'),
24+
)
925

1026
def __str__(self):
1127
return self.title

0 commit comments

Comments
 (0)