5
5
BaseDatabaseSchemaEditor , logger , _is_relevant_relation , _related_non_m2m_objects ,
6
6
)
7
7
from django .db .backends .ddl_references import (
8
- Statement ,
8
+ Columns , IndexName , Statement as DjStatement , Table ,
9
9
)
10
10
from django .db .models import Index
11
11
from django .db .models .fields import AutoField , BigAutoField
12
12
from django .db .transaction import TransactionManagementError
13
13
from django .utils .encoding import force_text
14
14
15
15
16
+ class Statement (DjStatement ):
17
+ def __hash__ (self ):
18
+ return hash ((self .template , str (self .parts ['name' ])))
19
+
20
+ def __eq__ (self , other ):
21
+ return self .template == other .template and str (self .parts ['name' ]) == str (other .parts ['name' ])
22
+
23
+
16
24
class DatabaseSchemaEditor (BaseDatabaseSchemaEditor ):
17
25
18
26
_sql_check_constraint = " CONSTRAINT %(name)s CHECK (%(check)s)"
@@ -123,6 +131,105 @@ def _alter_column_type_sql(self, model, old_field, new_field, new_type):
123
131
new_type = self ._set_field_new_type_null_status (old_field , new_type )
124
132
return super ()._alter_column_type_sql (model , old_field , new_field , new_type )
125
133
134
+ def alter_unique_together (self , model , old_unique_together , new_unique_together ):
135
+ """
136
+ Deal with a model changing its unique_together. The input
137
+ unique_togethers must be doubly-nested, not the single-nested
138
+ ["foo", "bar"] format.
139
+ """
140
+ olds = {tuple (fields ) for fields in old_unique_together }
141
+ news = {tuple (fields ) for fields in new_unique_together }
142
+ # Deleted uniques
143
+ for fields in olds .difference (news ):
144
+ self ._delete_composed_index (model , fields , {'unique' : True }, self .sql_delete_index )
145
+ # Created uniques
146
+ for fields in news .difference (olds ):
147
+ columns = [model ._meta .get_field (field ).column for field in fields ]
148
+ condition = ' AND ' .join (["[%s] IS NOT NULL" % col for col in columns ])
149
+ sql = self ._create_unique_sql (model , columns , condition = condition )
150
+ self .execute (sql )
151
+
152
+ def _model_indexes_sql (self , model ):
153
+ """
154
+ Return a list of all index SQL statements (field indexes,
155
+ index_together, Meta.indexes) for the specified model.
156
+ """
157
+ if not model ._meta .managed or model ._meta .proxy or model ._meta .swapped :
158
+ return []
159
+ output = []
160
+ for field in model ._meta .local_fields :
161
+ output .extend (self ._field_indexes_sql (model , field ))
162
+
163
+ for field_names in model ._meta .index_together :
164
+ fields = [model ._meta .get_field (field ) for field in field_names ]
165
+ output .append (self ._create_index_sql (model , fields , suffix = "_idx" ))
166
+
167
+ for field_names in model ._meta .unique_together :
168
+ columns = [model ._meta .get_field (field ).column for field in field_names ]
169
+ condition = ' AND ' .join (["[%s] IS NOT NULL" % col for col in columns ])
170
+ sql = self ._create_unique_sql (model , columns , condition = condition )
171
+ output .append (sql )
172
+
173
+ for index in model ._meta .indexes :
174
+ output .append (index .create_sql (model , self ))
175
+ return output
176
+
177
+ def _alter_many_to_many (self , model , old_field , new_field , strict ):
178
+ """Alter M2Ms to repoint their to= endpoints."""
179
+
180
+ for idx in self ._constraint_names (old_field .remote_field .through , index = True , unique = True ):
181
+ self .execute (self .sql_delete_index % {'name' : idx , 'table' : old_field .remote_field .through ._meta .db_table })
182
+
183
+ return super ()._alter_many_to_many (model , old_field , new_field , strict )
184
+
185
+ def _db_table_constraint_names (self , db_table , column_names = None , unique = None ,
186
+ primary_key = None , index = None , foreign_key = None ,
187
+ check = None , type_ = None , exclude = None ):
188
+ """Return all constraint names matching the columns and conditions."""
189
+ if column_names is not None :
190
+ column_names = [
191
+ self .connection .introspection .identifier_converter (name )
192
+ for name in column_names
193
+ ]
194
+ with self .connection .cursor () as cursor :
195
+ constraints = self .connection .introspection .get_constraints (cursor , db_table )
196
+ result = []
197
+ for name , infodict in constraints .items ():
198
+ if column_names is None or column_names == infodict ['columns' ]:
199
+ if unique is not None and infodict ['unique' ] != unique :
200
+ continue
201
+ if primary_key is not None and infodict ['primary_key' ] != primary_key :
202
+ continue
203
+ if index is not None and infodict ['index' ] != index :
204
+ continue
205
+ if check is not None and infodict ['check' ] != check :
206
+ continue
207
+ if foreign_key is not None and not infodict ['foreign_key' ]:
208
+ continue
209
+ if type_ is not None and infodict ['type' ] != type_ :
210
+ continue
211
+ if not exclude or name not in exclude :
212
+ result .append (name )
213
+ return result
214
+
215
+ def _db_table_delete_constraint_sql (self , template , db_table , name ):
216
+ return Statement (
217
+ template ,
218
+ table = Table (db_table , self .quote_name ),
219
+ name = self .quote_name (name ),
220
+ )
221
+
222
+ def alter_db_table (self , model , old_db_table , new_db_table ):
223
+ index_names = self ._db_table_constraint_names (old_db_table , index = True )
224
+ for index_name in index_names :
225
+ self .execute (self ._db_table_delete_constraint_sql (self .sql_delete_index , old_db_table , index_name ))
226
+
227
+ index_names = self ._db_table_constraint_names (new_db_table , index = True )
228
+ for index_name in index_names :
229
+ self .execute (self ._db_table_delete_constraint_sql (self .sql_delete_index , new_db_table , index_name ))
230
+
231
+ return super ().alter_db_table (model , old_db_table , new_db_table )
232
+
126
233
def _alter_field (self , model , old_field , new_field , old_type , new_type ,
127
234
old_db_params , new_db_params , strict = False ):
128
235
"""Actually perform a "physical" (non-ManyToMany) field update."""
@@ -224,11 +331,15 @@ def _alter_field(self, model, old_field, new_field, old_type, new_type,
224
331
self .execute (self ._delete_constraint_sql (self .sql_delete_check , model , constraint_name ))
225
332
# Have they renamed the column?
226
333
if old_field .column != new_field .column :
334
+ # remove old indices
335
+ self ._delete_indexes (model , old_field , new_field )
336
+
227
337
self .execute (self ._rename_field_sql (model ._meta .db_table , old_field , new_field , new_type ))
228
338
# Rename all references to the renamed column.
229
339
for sql in self .deferred_sql :
230
- if isinstance (sql , Statement ):
340
+ if isinstance (sql , DjStatement ):
231
341
sql .rename_column_references (model ._meta .db_table , old_field .column , new_field .column )
342
+
232
343
# Next, start accumulating actions to do
233
344
actions = []
234
345
null_actions = []
@@ -286,6 +397,7 @@ def _alter_field(self, model, old_field, new_field, old_type, new_type,
286
397
actions = [(", " .join (sql ), sum (params , []))]
287
398
# Apply those actions
288
399
for sql , params in actions :
400
+ self ._delete_indexes (model , old_field , new_field )
289
401
self .execute (
290
402
self .sql_alter_column % {
291
403
"table" : self .quote_name (model ._meta .db_table ),
@@ -438,6 +550,7 @@ def _alter_field(self, model, old_field, new_field, old_type, new_type,
438
550
"changes" : changes_sql ,
439
551
}
440
552
self .execute (sql , params )
553
+
441
554
# Reset connection if required
442
555
if self .connection .features .connection_persists_old_columns :
443
556
self .connection .close ()
@@ -446,11 +559,15 @@ def _delete_indexes(self, model, old_field, new_field):
446
559
index_columns = []
447
560
if old_field .db_index and new_field .db_index :
448
561
index_columns .append ([old_field .column ])
449
- else :
450
- for fields in model ._meta .index_together :
451
- columns = [model ._meta .get_field (field ).column for field in fields ]
452
- if old_field .column in columns :
453
- index_columns .append (columns )
562
+ for fields in model ._meta .index_together :
563
+ columns = [model ._meta .get_field (field ).column for field in fields ]
564
+ if old_field .column in columns :
565
+ index_columns .append (columns )
566
+
567
+ for fields in model ._meta .unique_together :
568
+ columns = [model ._meta .get_field (field ).column for field in fields ]
569
+ if old_field .column in columns :
570
+ index_columns .append (columns )
454
571
if index_columns :
455
572
for columns in index_columns :
456
573
index_names = self ._constraint_names (model , columns , index = True )
@@ -461,11 +578,6 @@ def _delete_unique_constraints(self, model, old_field, new_field, strict=False):
461
578
unique_columns = []
462
579
if old_field .unique and new_field .unique :
463
580
unique_columns .append ([old_field .column ])
464
- else :
465
- for fields in model ._meta .unique_together :
466
- columns = [model ._meta .get_field (field ).column for field in fields ]
467
- if old_field .column in columns :
468
- unique_columns .append (columns )
469
581
if unique_columns :
470
582
for columns in unique_columns :
471
583
constraint_names = self ._constraint_names (model , columns , unique = True )
@@ -544,6 +656,61 @@ def add_field(self, model, field):
544
656
if self .connection .features .connection_persists_old_columns :
545
657
self .connection .close ()
546
658
659
+ def _create_unique_sql (self , model , columns , name = None , condition = None ):
660
+ def create_unique_name (* args , ** kwargs ):
661
+ return self .quote_name (self ._create_index_name (* args , ** kwargs ))
662
+
663
+ table = Table (model ._meta .db_table , self .quote_name )
664
+ if name is None :
665
+ name = IndexName (model ._meta .db_table , columns , '_uniq' , create_unique_name )
666
+ else :
667
+ name = self .quote_name (name )
668
+ columns = Columns (table , columns , self .quote_name )
669
+ if condition :
670
+ return Statement (
671
+ self .sql_create_unique_index ,
672
+ table = table ,
673
+ name = name ,
674
+ columns = columns ,
675
+ condition = ' WHERE ' + condition ,
676
+ ) if self .connection .features .supports_partial_indexes else None
677
+ else :
678
+ return Statement (
679
+ self .sql_create_unique ,
680
+ table = table ,
681
+ name = name ,
682
+ columns = columns ,
683
+ )
684
+
685
+ def _create_index_sql (self , model , fields , * , name = None , suffix = '' , using = '' ,
686
+ db_tablespace = None , col_suffixes = (), sql = None , opclasses = (),
687
+ condition = None ):
688
+ """
689
+ Return the SQL statement to create the index for one or several fields.
690
+ `sql` can be specified if the syntax differs from the standard (GIS
691
+ indexes, ...).
692
+ """
693
+ tablespace_sql = self ._get_index_tablespace_sql (model , fields , db_tablespace = db_tablespace )
694
+ columns = [field .column for field in fields ]
695
+ sql_create_index = sql or self .sql_create_index
696
+ table = model ._meta .db_table
697
+
698
+ def create_index_name (* args , ** kwargs ):
699
+ nonlocal name
700
+ if name is None :
701
+ name = self ._create_index_name (* args , ** kwargs )
702
+ return self .quote_name (name )
703
+
704
+ return Statement (
705
+ sql_create_index ,
706
+ table = Table (table , self .quote_name ),
707
+ name = IndexName (table , columns , suffix , create_index_name ),
708
+ using = using ,
709
+ columns = self ._index_columns (table , columns , col_suffixes , opclasses ),
710
+ extra = tablespace_sql ,
711
+ condition = (' WHERE ' + condition ) if condition else '' ,
712
+ )
713
+
547
714
def create_model (self , model ):
548
715
"""
549
716
Takes a model and creates a table for it in the database.
@@ -605,7 +772,9 @@ def create_model(self, model):
605
772
# created afterwards, like geometry fields with some backends)
606
773
for fields in model ._meta .unique_together :
607
774
columns = [model ._meta .get_field (field ).column for field in fields ]
608
- self .deferred_sql .append (self ._create_unique_sql (model , columns ))
775
+ condition = ' AND ' .join (["[%s] IS NOT NULL" % col for col in columns ])
776
+ self .deferred_sql .append (self ._create_unique_sql (model , columns , condition = condition ))
777
+
609
778
# Make the table
610
779
sql = self .sql_create_table % {
611
780
"table" : self .quote_name (model ._meta .db_table ),
@@ -620,6 +789,7 @@ def create_model(self, model):
620
789
621
790
# Add any field index and index_together's (deferred as SQLite3 _remake_table needs it)
622
791
self .deferred_sql .extend (self ._model_indexes_sql (model ))
792
+ self .deferred_sql = list (set (self .deferred_sql ))
623
793
624
794
# Make M2M tables
625
795
for field in model ._meta .local_many_to_many :
0 commit comments