Skip to content

Commit

Permalink
Merge pull request #1327 from serengil/feat-task-3108-exception-handl…
Browse files Browse the repository at this point in the history
…ing-for-loading-weights

load weights is done in a common function
  • Loading branch information
serengil authored Aug 31, 2024
2 parents a3088ac + 8b86f36 commit 46fe4a8
Show file tree
Hide file tree
Showing 12 changed files with 61 additions and 13 deletions.
29 changes: 28 additions & 1 deletion deepface/commons/weight_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,15 @@
import gdown

# project dependencies
from deepface.commons import folder_utils
from deepface.commons import folder_utils, package_utils
from deepface.commons.logger import Logger

tf_version = package_utils.get_tf_major_version()
if tf_version == 1:
from keras.models import Sequential
else:
from tensorflow.keras.models import Sequential

logger = Logger()


Expand Down Expand Up @@ -63,3 +69,24 @@ def download_weights_if_necessary(
logger.info(f"{target_file}.bz2 unzipped")

return target_file


def load_model_weights(model: Sequential, weight_file: str) -> Sequential:
"""
Load pre-trained weights for a given model
Args:
model (keras.models.Sequential): pre-built model
weight_file (str): exact path of pre-trained weights
Returns:
model (keras.models.Sequential): pre-built model with
updated weights
"""
try:
model.load_weights(weight_file)
except Exception as err:
raise ValueError(
f"Exception while loading pre-trained weights from {weight_file}."
"Possible reason is broken file during downloading weights."
"You may consider to delete it manually."
) from err
return model
5 changes: 4 additions & 1 deletion deepface/models/demography/Age.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,10 @@ def load_model(
weight_file = weight_utils.download_weights_if_necessary(
file_name="age_model_weights.h5", source_url=url
)
age_model.load_weights(weight_file)

age_model = weight_utils.load_model_weights(
model=age_model, weight_file=weight_file
)

return age_model

Expand Down
4 changes: 3 additions & 1 deletion deepface/models/demography/Emotion.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ def load_model(
file_name="facial_expression_model_weights.h5", source_url=url
)

model.load_weights(weight_file)
model = weight_utils.load_model_weights(
model=model, weight_file=weight_file
)

return model
4 changes: 3 additions & 1 deletion deepface/models/demography/Gender.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ def load_model(
file_name="gender_model_weights.h5", source_url=url
)

gender_model.load_weights(weight_file)
gender_model = weight_utils.load_model_weights(
model=gender_model, weight_file=weight_file
)

return gender_model
4 changes: 3 additions & 1 deletion deepface/models/demography/Race.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ def load_model(
file_name="race_model_single_batch.h5", source_url=url
)

race_model.load_weights(weight_file)
race_model = weight_utils.load_model_weights(
model=race_model, weight_file=weight_file
)

return race_model
2 changes: 1 addition & 1 deletion deepface/models/facial_recognition/ArcFace.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def load_model(
file_name="arcface_weights.h5", source_url=url
)

model.load_weights(weight_file)
model = weight_utils.load_model_weights(model=model, weight_file=weight_file)
# ---------------------------------------

return model
Expand Down
4 changes: 3 additions & 1 deletion deepface/models/facial_recognition/DeepID.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ def load_model(
file_name="deepid_keras_weights.h5", source_url=url
)

model.load_weights(weight_file)
model = weight_utils.load_model_weights(
model=model, weight_file=weight_file
)

return model
8 changes: 6 additions & 2 deletions deepface/models/facial_recognition/Facenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -1668,7 +1668,9 @@ def load_facenet128d_model(
weight_file = weight_utils.download_weights_if_necessary(
file_name="facenet_weights.h5", source_url=url
)
model.load_weights(weight_file)
model = weight_utils.load_model_weights(
model=model, weight_file=weight_file
)

return model

Expand All @@ -1687,6 +1689,8 @@ def load_facenet512d_model(
weight_file = weight_utils.download_weights_if_necessary(
file_name="facenet512_weights.h5", source_url=url
)
model.load_weights(weight_file)
model = weight_utils.load_model_weights(
model=model, weight_file=weight_file
)

return model
2 changes: 1 addition & 1 deletion deepface/models/facial_recognition/FbDeepFace.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def load_model(
file_name="VGGFace2_DeepFace_weights_val-0.9034.h5", source_url=url, compress_type="zip"
)

base_model.load_weights(weight_file)
base_model = weight_utils.load_model_weights(model=base_model, weight_file=weight_file)

# drop F8 and D0. F7 is the representation layer.
deepface_model = Model(inputs=base_model.layers[0].input, outputs=base_model.layers[-3].output)
Expand Down
4 changes: 3 additions & 1 deletion deepface/models/facial_recognition/GhostFaceNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ def load_model():
file_name="ghostfacenet_v1.h5", source_url=PRETRAINED_WEIGHTS
)

model.load_weights(weight_file)
model = weight_utils.load_model_weights(
model=model, weight_file=weight_file
)

return model

Expand Down
4 changes: 3 additions & 1 deletion deepface/models/facial_recognition/OpenFace.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,9 @@ def load_model(
file_name="openface_weights.h5", source_url=url
)

model.load_weights(weight_file)
model = weight_utils.load_model_weights(
model=model, weight_file=weight_file
)

# -----------------------------------

Expand Down
4 changes: 3 additions & 1 deletion deepface/models/facial_recognition/VGGFace.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,9 @@ def load_model(
file_name="vgg_face_weights.h5", source_url=url
)

model.load_weights(weight_file)
model = weight_utils.load_model_weights(
model=model, weight_file=weight_file
)

# 2622d dimensional model
# vgg_face_descriptor = Model(inputs=model.layers[0].input, outputs=model.layers[-2].output)
Expand Down

0 comments on commit 46fe4a8

Please sign in to comment.