Skip to content

Commit

Permalink
customize broken weight file test
Browse files Browse the repository at this point in the history
  • Loading branch information
serengil committed Aug 31, 2024
1 parent e834660 commit 005280f
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 36 deletions.
46 changes: 46 additions & 0 deletions tests/test_commons.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# built-in dependencies
import os
import pytest

# project dependencies
from deepface.commons import folder_utils, weight_utils, package_utils
from deepface.commons.logger import Logger

logger = Logger()

tf_version = package_utils.get_tf_major_version()

if tf_version == 1:
from keras.models import Sequential
from keras.layers import (
Dropout,
Dense,
)
else:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (
Dropout,
Dense,
)


def test_loading_broken_weights():
home = folder_utils.get_deepface_home()
weight_file = os.path.join(home, ".deepface/weights/vgg_face_weights.h5")

# construct a dummy model
model = Sequential()

# Add layers to the model
model.add(
Dense(units=64, activation="relu", input_shape=(100,))
) # Input layer with 100 features
model.add(Dropout(0.5)) # Dropout layer to prevent overfitting
model.add(Dense(units=32, activation="relu")) # Hidden layer
model.add(Dense(units=10, activation="softmax")) # Output layer with 10 classes

# vgg's weights cannot be loaded to this model
with pytest.raises(ValueError, match="Exception while loading pre-trained weights from"):
model = weight_utils.load_model_weights(model=model, weight_file=weight_file)

logger.info("✅ test loading broken weight file is done")
36 changes: 0 additions & 36 deletions tests/test_verify.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
# built-in dependencies
import os

# 3rd party dependencies
import pytest
import cv2

# project dependencies
from deepface import DeepFace
from deepface.commons import folder_utils, package_utils
from deepface.commons.logger import Logger

logger = Logger()
Expand Down Expand Up @@ -192,35 +188,3 @@ def test_verify_for_nested_embeddings():
_ = DeepFace.verify(img1_path=img1_embeddings, img2_path=img2_path)

logger.info("✅ test verify for nested embeddings is done")


def test_verify_for_broken_weights():
home = folder_utils.get_deepface_home()

# we are not performing anything with model deepid

weights_file = os.path.join(home, ".deepface/weights/deepid_keras_weights.h5")
backup_file = os.path.join(home, ".deepface/weights/deepid_keras_weights_backup.h5")

restore = False
# backup original weight file
if os.path.exists(weights_file) is True:
os.rename(weights_file, backup_file)
restore = True

# Create a dummy vgg_face_weights.h5 file
with open(weights_file, "w", encoding="UTF-8") as f:
f.write("dummy content")

with pytest.raises(ValueError, match="Exception while loading pre-trained weights from"):
_ = DeepFace.verify(
img1_path="dataset/img1.jpg",
img2_path="dataset/img2.jpg",
model_name="DeepId",
)

if restore:
os.remove(weights_file)
os.rename(backup_file, weights_file)

logger.info("✅ test verify for broken weight file is done")

0 comments on commit 005280f

Please sign in to comment.