@@ -477,24 +477,22 @@ def _create_collection(self, model):
477477 # Unencrypted path
478478 db .create_collection (db_table )
479479
480- def _get_encrypted_fields (self , model , key_alt_name = None , path_prefix = None ):
480+ def _get_encrypted_fields (self , model , key_alt_name_prefix = None , path_prefix = None ):
481481 """
482- Recursively collect encryption schema data for only encrypted fields in a model.
483- Returns None if no encrypted fields are found anywhere in the model hierarchy.
482+ Return the encrypted fields map for the given model. The "prefix"
483+ arguments are used when this method is called recursively on embedded
484+ models.
484485 """
485486 connection = self .connection
486487 client = connection .connection
487- fields = model ._meta .fields
488- key_alt_name = key_alt_name or model ._meta .db_table
488+ key_alt_name_prefix = key_alt_name_prefix or model ._meta .db_table
489489 path_prefix = path_prefix or ""
490-
491- options = client ._options
492- auto_encryption_opts = options .auto_encryption_opts
493-
494- key_vault_db , key_vault_coll = auto_encryption_opts ._key_vault_namespace .split ("." , 1 )
495- key_vault_collection = client [key_vault_db ][key_vault_coll ]
496-
497- # Create partial unique index on keyAltNames
490+ auto_encryption_opts = client ._options .auto_encryption_opts
491+ key_vault_db , key_vault_collection = auto_encryption_opts ._key_vault_namespace .split ("." , 1 )
492+ key_vault_collection = client [key_vault_db ][key_vault_collection ]
493+ # Create partial unique index on keyAltNames.
494+ # TODO: find a better place for this. It only needs to run once for an
495+ # application's lifetime.
498496 key_vault_collection .create_index (
499497 "keyAltNames" , unique = True , partialFilterExpression = {"keyAltNames" : {"$exists" : True }}
500498 )
@@ -508,46 +506,42 @@ def _get_encrypted_fields(self, model, key_alt_name=None, path_prefix=None):
508506 kms_provider = router .kms_provider (model )
509507 # Providing master_key raises an error for the local provider.
510508 master_key = kms_providers [kms_provider ] if kms_provider != "local" else None
511- client_encryption = self .connection .client_encryption
512-
509+ # Generate the encrypted fields map.
513510 field_list = []
514-
515- for field in fields :
516- new_key_alt_name = f"{ key_alt_name } .{ field .column } "
511+ for field in model ._meta .fields :
512+ key_alt_name = f"{ key_alt_name_prefix } .{ field .column } "
517513 path = f"{ path_prefix } .{ field .column } " if path_prefix else field .column
518-
514+ # Check non-encrypted EmbeddedModelFields for encrypted fields.
519515 if isinstance (field , EmbeddedModelField ) and not getattr (field , "encrypted" , False ):
520516 embedded_result = self ._get_encrypted_fields (
521517 field .embedded_model ,
522- key_alt_name = new_key_alt_name ,
518+ key_alt_name_prefix = key_alt_name ,
523519 path_prefix = path ,
524520 )
521+ # An EmbeddedModelField may not have any encrypted fields.
525522 if embedded_result :
526523 field_list .extend (embedded_result ["fields" ])
527524 continue
528-
525+ # Populate data for encrypted field.
529526 if getattr (field , "encrypted" , False ):
530- bson_type = field .db_type (connection )
531- data_key = key_vault_collection .find_one ({"keyAltNames" : new_key_alt_name })
527+ data_key = key_vault_collection .find_one ({"keyAltNames" : key_alt_name })
532528 if data_key :
533529 data_key = data_key ["_id" ]
534530 else :
535- data_key = client_encryption .create_data_key (
531+ data_key = connection . client_encryption .create_data_key (
536532 kms_provider = kms_provider ,
537- key_alt_names = [new_key_alt_name ],
533+ key_alt_names = [key_alt_name ],
538534 master_key = master_key ,
539535 )
540536 field_dict = {
541- "bsonType" : bson_type ,
537+ "bsonType" : field . db_type ( connection ) ,
542538 "path" : path ,
543539 "keyId" : data_key ,
544540 }
545- queries = getattr (field , "queries" , None )
546- if queries :
541+ if queries := getattr (field , "queries" , None ):
547542 field_dict ["queries" ] = queries
548543 field_list .append (field_dict )
549-
550- return {"fields" : field_list } if field_list else None
544+ return {"fields" : field_list }
551545
552546
553547# GISSchemaEditor extends some SchemaEditor methods.
0 commit comments