Skip to content

Commit e27be64

Browse files
committed
cosmetic edits to _get_encrypted_fields()
1 parent 7f0d08d commit e27be64

File tree

1 file changed

+24
-30
lines changed

1 file changed

+24
-30
lines changed

django_mongodb_backend/schema.py

Lines changed: 24 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -477,24 +477,22 @@ def _create_collection(self, model):
477477
# Unencrypted path
478478
db.create_collection(db_table)
479479

480-
def _get_encrypted_fields(self, model, key_alt_name=None, path_prefix=None):
480+
def _get_encrypted_fields(self, model, key_alt_name_prefix=None, path_prefix=None):
481481
"""
482-
Recursively collect encryption schema data for only encrypted fields in a model.
483-
Returns None if no encrypted fields are found anywhere in the model hierarchy.
482+
Return the encrypted fields map for the given model. The "prefix"
483+
arguments are used when this method is called recursively on embedded
484+
models.
484485
"""
485486
connection = self.connection
486487
client = connection.connection
487-
fields = model._meta.fields
488-
key_alt_name = key_alt_name or model._meta.db_table
488+
key_alt_name_prefix = key_alt_name_prefix or model._meta.db_table
489489
path_prefix = path_prefix or ""
490-
491-
options = client._options
492-
auto_encryption_opts = options.auto_encryption_opts
493-
494-
key_vault_db, key_vault_coll = auto_encryption_opts._key_vault_namespace.split(".", 1)
495-
key_vault_collection = client[key_vault_db][key_vault_coll]
496-
497-
# Create partial unique index on keyAltNames
490+
auto_encryption_opts = client._options.auto_encryption_opts
491+
key_vault_db, key_vault_collection = auto_encryption_opts._key_vault_namespace.split(".", 1)
492+
key_vault_collection = client[key_vault_db][key_vault_collection]
493+
# Create partial unique index on keyAltNames.
494+
# TODO: find a better place for this. It only needs to run once for an
495+
# application's lifetime.
498496
key_vault_collection.create_index(
499497
"keyAltNames", unique=True, partialFilterExpression={"keyAltNames": {"$exists": True}}
500498
)
@@ -508,46 +506,42 @@ def _get_encrypted_fields(self, model, key_alt_name=None, path_prefix=None):
508506
kms_provider = router.kms_provider(model)
509507
# Providing master_key raises an error for the local provider.
510508
master_key = kms_providers[kms_provider] if kms_provider != "local" else None
511-
client_encryption = self.connection.client_encryption
512-
509+
# Generate the encrypted fields map.
513510
field_list = []
514-
515-
for field in fields:
516-
new_key_alt_name = f"{key_alt_name}.{field.column}"
511+
for field in model._meta.fields:
512+
key_alt_name = f"{key_alt_name_prefix}.{field.column}"
517513
path = f"{path_prefix}.{field.column}" if path_prefix else field.column
518-
514+
# Check non-encrypted EmbeddedModelFields for encrypted fields.
519515
if isinstance(field, EmbeddedModelField) and not getattr(field, "encrypted", False):
520516
embedded_result = self._get_encrypted_fields(
521517
field.embedded_model,
522-
key_alt_name=new_key_alt_name,
518+
key_alt_name_prefix=key_alt_name,
523519
path_prefix=path,
524520
)
521+
# An EmbeddedModelField may not have any encrypted fields.
525522
if embedded_result:
526523
field_list.extend(embedded_result["fields"])
527524
continue
528-
525+
# Populate data for encrypted field.
529526
if getattr(field, "encrypted", False):
530-
bson_type = field.db_type(connection)
531-
data_key = key_vault_collection.find_one({"keyAltNames": new_key_alt_name})
527+
data_key = key_vault_collection.find_one({"keyAltNames": key_alt_name})
532528
if data_key:
533529
data_key = data_key["_id"]
534530
else:
535-
data_key = client_encryption.create_data_key(
531+
data_key = connection.client_encryption.create_data_key(
536532
kms_provider=kms_provider,
537-
key_alt_names=[new_key_alt_name],
533+
key_alt_names=[key_alt_name],
538534
master_key=master_key,
539535
)
540536
field_dict = {
541-
"bsonType": bson_type,
537+
"bsonType": field.db_type(connection),
542538
"path": path,
543539
"keyId": data_key,
544540
}
545-
queries = getattr(field, "queries", None)
546-
if queries:
541+
if queries := getattr(field, "queries", None):
547542
field_dict["queries"] = queries
548543
field_list.append(field_dict)
549-
550-
return {"fields": field_list} if field_list else None
544+
return {"fields": field_list}
551545

552546

553547
# GISSchemaEditor extends some SchemaEditor methods.

0 commit comments

Comments
 (0)