Skip to content

Add support for nullable fields in unique_together #24

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Dec 11, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[flake8]
exclude = .git,__pycache__,
exclude = .git,__pycache__,migrations
# W504 is mutually exclusive with W503
ignore = W504
max-line-length = 119
8 changes: 2 additions & 6 deletions sql_server/pyodbc/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,17 +480,13 @@ def check_constraints(self, table_names=None):
table_names)

def disable_constraint_checking(self):
# Azure SQL Database doesn't support sp_msforeachtable
# cursor.execute('EXEC sp_msforeachtable "ALTER TABLE ? NOCHECK CONSTRAINT ALL"')
if not self.needs_rollback:
self._execute_foreach('ALTER TABLE %s NOCHECK CONSTRAINT ALL')
self.cursor().execute('EXEC sp_msforeachtable "ALTER TABLE ? NOCHECK CONSTRAINT ALL"')
return not self.needs_rollback

def enable_constraint_checking(self):
# Azure SQL Database doesn't support sp_msforeachtable
# cursor.execute('EXEC sp_msforeachtable "ALTER TABLE ? WITH CHECK CHECK CONSTRAINT ALL"')
if not self.needs_rollback:
self.check_constraints()
self.cursor().execute('EXEC sp_msforeachtable "ALTER TABLE ? WITH CHECK CHECK CONSTRAINT ALL"')


class CursorWrapper(object):
Expand Down
5 changes: 4 additions & 1 deletion sql_server/pyodbc/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ class DatabaseFeatures(BaseDatabaseFeatures):
supports_ignore_conflicts = False
supports_index_on_text_field = False
supports_paramstyle_pyformat = False
supports_partially_nullable_unique_constraints = False
supports_regex_backreferencing = False
supports_sequence_reset = False
supports_subqueries_in_group_by = False
Expand All @@ -41,6 +40,10 @@ def has_bulk_insert(self):
def supports_nullable_unique_constraints(self):
return self.connection.sql_server_version > 2005

@cached_property
def supports_partially_nullable_unique_constraints(self):
return self.connection.sql_server_version > 2005

@cached_property
def supports_partial_indexes(self):
return self.connection.sql_server_version > 2005
Expand Down
196 changes: 183 additions & 13 deletions sql_server/pyodbc/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,22 @@
BaseDatabaseSchemaEditor, logger, _is_relevant_relation, _related_non_m2m_objects,
)
from django.db.backends.ddl_references import (
Statement,
Columns, IndexName, Statement as DjStatement, Table,
)
from django.db.models import Index
from django.db.models.fields import AutoField, BigAutoField
from django.db.transaction import TransactionManagementError
from django.utils.encoding import force_text


class Statement(DjStatement):
def __hash__(self):
return hash((self.template, str(self.parts['name'])))

def __eq__(self, other):
return self.template == other.template and str(self.parts['name']) == str(other.parts['name'])


class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):

_sql_check_constraint = " CONSTRAINT %(name)s CHECK (%(check)s)"
Expand Down Expand Up @@ -123,6 +131,105 @@ def _alter_column_type_sql(self, model, old_field, new_field, new_type):
new_type = self._set_field_new_type_null_status(old_field, new_type)
return super()._alter_column_type_sql(model, old_field, new_field, new_type)

def alter_unique_together(self, model, old_unique_together, new_unique_together):
"""
Deal with a model changing its unique_together. The input
unique_togethers must be doubly-nested, not the single-nested
["foo", "bar"] format.
"""
olds = {tuple(fields) for fields in old_unique_together}
news = {tuple(fields) for fields in new_unique_together}
# Deleted uniques
for fields in olds.difference(news):
self._delete_composed_index(model, fields, {'unique': True}, self.sql_delete_index)
# Created uniques
for fields in news.difference(olds):
columns = [model._meta.get_field(field).column for field in fields]
condition = ' AND '.join(["[%s] IS NOT NULL" % col for col in columns])
sql = self._create_unique_sql(model, columns, condition=condition)
self.execute(sql)

def _model_indexes_sql(self, model):
"""
Return a list of all index SQL statements (field indexes,
index_together, Meta.indexes) for the specified model.
"""
if not model._meta.managed or model._meta.proxy or model._meta.swapped:
return []
output = []
for field in model._meta.local_fields:
output.extend(self._field_indexes_sql(model, field))

for field_names in model._meta.index_together:
fields = [model._meta.get_field(field) for field in field_names]
output.append(self._create_index_sql(model, fields, suffix="_idx"))

for field_names in model._meta.unique_together:
columns = [model._meta.get_field(field).column for field in field_names]
condition = ' AND '.join(["[%s] IS NOT NULL" % col for col in columns])
sql = self._create_unique_sql(model, columns, condition=condition)
output.append(sql)

for index in model._meta.indexes:
output.append(index.create_sql(model, self))
return output

def _alter_many_to_many(self, model, old_field, new_field, strict):
"""Alter M2Ms to repoint their to= endpoints."""

for idx in self._constraint_names(old_field.remote_field.through, index=True, unique=True):
self.execute(self.sql_delete_index % {'name': idx, 'table': old_field.remote_field.through._meta.db_table})

return super()._alter_many_to_many(model, old_field, new_field, strict)

def _db_table_constraint_names(self, db_table, column_names=None, unique=None,
primary_key=None, index=None, foreign_key=None,
check=None, type_=None, exclude=None):
"""Return all constraint names matching the columns and conditions."""
if column_names is not None:
column_names = [
self.connection.introspection.identifier_converter(name)
for name in column_names
]
with self.connection.cursor() as cursor:
constraints = self.connection.introspection.get_constraints(cursor, db_table)
result = []
for name, infodict in constraints.items():
if column_names is None or column_names == infodict['columns']:
if unique is not None and infodict['unique'] != unique:
continue
if primary_key is not None and infodict['primary_key'] != primary_key:
continue
if index is not None and infodict['index'] != index:
continue
if check is not None and infodict['check'] != check:
continue
if foreign_key is not None and not infodict['foreign_key']:
continue
if type_ is not None and infodict['type'] != type_:
continue
if not exclude or name not in exclude:
result.append(name)
return result

def _db_table_delete_constraint_sql(self, template, db_table, name):
return Statement(
template,
table=Table(db_table, self.quote_name),
name=self.quote_name(name),
)

def alter_db_table(self, model, old_db_table, new_db_table):
index_names = self._db_table_constraint_names(old_db_table, index=True)
for index_name in index_names:
self.execute(self._db_table_delete_constraint_sql(self.sql_delete_index, old_db_table, index_name))

index_names = self._db_table_constraint_names(new_db_table, index=True)
for index_name in index_names:
self.execute(self._db_table_delete_constraint_sql(self.sql_delete_index, new_db_table, index_name))

return super().alter_db_table(model, old_db_table, new_db_table)

def _alter_field(self, model, old_field, new_field, old_type, new_type,
old_db_params, new_db_params, strict=False):
"""Actually perform a "physical" (non-ManyToMany) field update."""
Expand Down Expand Up @@ -224,11 +331,15 @@ def _alter_field(self, model, old_field, new_field, old_type, new_type,
self.execute(self._delete_constraint_sql(self.sql_delete_check, model, constraint_name))
# Have they renamed the column?
if old_field.column != new_field.column:
# remove old indices
self._delete_indexes(model, old_field, new_field)

self.execute(self._rename_field_sql(model._meta.db_table, old_field, new_field, new_type))
# Rename all references to the renamed column.
for sql in self.deferred_sql:
if isinstance(sql, Statement):
if isinstance(sql, DjStatement):
sql.rename_column_references(model._meta.db_table, old_field.column, new_field.column)

# Next, start accumulating actions to do
actions = []
null_actions = []
Expand Down Expand Up @@ -286,6 +397,7 @@ def _alter_field(self, model, old_field, new_field, old_type, new_type,
actions = [(", ".join(sql), sum(params, []))]
# Apply those actions
for sql, params in actions:
self._delete_indexes(model, old_field, new_field)
self.execute(
self.sql_alter_column % {
"table": self.quote_name(model._meta.db_table),
Expand Down Expand Up @@ -438,6 +550,7 @@ def _alter_field(self, model, old_field, new_field, old_type, new_type,
"changes": changes_sql,
}
self.execute(sql, params)

# Reset connection if required
if self.connection.features.connection_persists_old_columns:
self.connection.close()
Expand All @@ -446,11 +559,15 @@ def _delete_indexes(self, model, old_field, new_field):
index_columns = []
if old_field.db_index and new_field.db_index:
index_columns.append([old_field.column])
else:
for fields in model._meta.index_together:
columns = [model._meta.get_field(field).column for field in fields]
if old_field.column in columns:
index_columns.append(columns)
for fields in model._meta.index_together:
columns = [model._meta.get_field(field).column for field in fields]
if old_field.column in columns:
index_columns.append(columns)

for fields in model._meta.unique_together:
columns = [model._meta.get_field(field).column for field in fields]
if old_field.column in columns:
index_columns.append(columns)
if index_columns:
for columns in index_columns:
index_names = self._constraint_names(model, columns, index=True)
Expand All @@ -461,11 +578,6 @@ def _delete_unique_constraints(self, model, old_field, new_field, strict=False):
unique_columns = []
if old_field.unique and new_field.unique:
unique_columns.append([old_field.column])
else:
for fields in model._meta.unique_together:
columns = [model._meta.get_field(field).column for field in fields]
if old_field.column in columns:
unique_columns.append(columns)
if unique_columns:
for columns in unique_columns:
constraint_names = self._constraint_names(model, columns, unique=True)
Expand Down Expand Up @@ -544,6 +656,61 @@ def add_field(self, model, field):
if self.connection.features.connection_persists_old_columns:
self.connection.close()

def _create_unique_sql(self, model, columns, name=None, condition=None):
def create_unique_name(*args, **kwargs):
return self.quote_name(self._create_index_name(*args, **kwargs))

table = Table(model._meta.db_table, self.quote_name)
if name is None:
name = IndexName(model._meta.db_table, columns, '_uniq', create_unique_name)
else:
name = self.quote_name(name)
columns = Columns(table, columns, self.quote_name)
if condition:
return Statement(
self.sql_create_unique_index,
table=table,
name=name,
columns=columns,
condition=' WHERE ' + condition,
) if self.connection.features.supports_partial_indexes else None
else:
return Statement(
self.sql_create_unique,
table=table,
name=name,
columns=columns,
)

def _create_index_sql(self, model, fields, *, name=None, suffix='', using='',
db_tablespace=None, col_suffixes=(), sql=None, opclasses=(),
condition=None):
"""
Return the SQL statement to create the index for one or several fields.
`sql` can be specified if the syntax differs from the standard (GIS
indexes, ...).
"""
tablespace_sql = self._get_index_tablespace_sql(model, fields, db_tablespace=db_tablespace)
columns = [field.column for field in fields]
sql_create_index = sql or self.sql_create_index
table = model._meta.db_table

def create_index_name(*args, **kwargs):
nonlocal name
if name is None:
name = self._create_index_name(*args, **kwargs)
return self.quote_name(name)

return Statement(
sql_create_index,
table=Table(table, self.quote_name),
name=IndexName(table, columns, suffix, create_index_name),
using=using,
columns=self._index_columns(table, columns, col_suffixes, opclasses),
extra=tablespace_sql,
condition=(' WHERE ' + condition) if condition else '',
)

def create_model(self, model):
"""
Takes a model and creates a table for it in the database.
Expand Down Expand Up @@ -605,7 +772,9 @@ def create_model(self, model):
# created afterwards, like geometry fields with some backends)
for fields in model._meta.unique_together:
columns = [model._meta.get_field(field).column for field in fields]
self.deferred_sql.append(self._create_unique_sql(model, columns))
condition = ' AND '.join(["[%s] IS NOT NULL" % col for col in columns])
self.deferred_sql.append(self._create_unique_sql(model, columns, condition=condition))

# Make the table
sql = self.sql_create_table % {
"table": self.quote_name(model._meta.db_table),
Expand All @@ -620,6 +789,7 @@ def create_model(self, model):

# Add any field index and index_together's (deferred as SQLite3 _remake_table needs it)
self.deferred_sql.extend(self._model_indexes_sql(model))
self.deferred_sql = list(set(self.deferred_sql))

# Make M2M tables
for field in model._meta.local_many_to_many:
Expand Down
28 changes: 28 additions & 0 deletions testapp/migrations/0001_initial.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,41 @@ class Migration(migrations.Migration):
]

operations = [
migrations.CreateModel(
name='Author',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('name', models.CharField(max_length=100)),
],
),
migrations.CreateModel(
name='Editor',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('name', models.CharField(max_length=100)),
],
),
migrations.CreateModel(
name='Post',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('title', models.CharField(max_length=255, verbose_name='title')),
],
),
migrations.AddField(
model_name='post',
name='alt_editor',
field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to='testapp.Editor'),
),
migrations.AddField(
model_name='post',
name='author',
field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='testapp.Author'),
),
migrations.AlterUniqueTogether(
name='post',
unique_together={('author', 'title', 'alt_editor')},
),
migrations.CreateModel(
name='Comment',
fields=[
Expand Down
16 changes: 16 additions & 0 deletions testapp/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,24 @@
from django.utils import timezone


class Author(models.Model):
name = models.CharField(max_length=100)


class Editor(models.Model):
name = models.CharField(max_length=100)


class Post(models.Model):
title = models.CharField('title', max_length=255)
author = models.ForeignKey(Author, models.CASCADE)
# Optional secondary author
alt_editor = models.ForeignKey(Editor, models.SET_NULL, blank=True, null=True)

class Meta:
unique_together = (
('author', 'title', 'alt_editor'),
)

def __str__(self):
return self.title
Expand Down
Loading