Skip to content

Commit

Permalink
face recognition models as global
Browse files Browse the repository at this point in the history
  • Loading branch information
serengil committed Jun 23, 2021
1 parent 9f31012 commit e0809bf
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 11 deletions.
21 changes: 13 additions & 8 deletions deepface/DeepFace.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ def build_model(model_name):
built deepface model
"""

global model_obj, model_label

models = {
'VGG-Face': VGGFace.loadModel,
'OpenFace': OpenFace.loadModel,
Expand All @@ -43,21 +45,24 @@ def build_model(model_name):
'DeepID': DeepID.loadModel,
'Dlib': DlibWrapper.loadModel,
'ArcFace': ArcFace.loadModel,

'Emotion': Emotion.loadModel,
'Age': Age.loadModel,
'Gender': Gender.loadModel,
'Race': Race.loadModel
}

model = models.get(model_name)
if not "model_obj" in globals() or model_label != model_name:

if model:
model = model()
#print('Using {} model backend'.format(model_name))
return model
else:
raise ValueError('Invalid model_name passed - {}'.format(model_name))
model_obj = models.get(model_name)

if model_obj:
model_obj = model_obj()
model_label = model_name
#print('Using {} model backend'.format(model_name))
else:
raise ValueError('Invalid model_name passed - {}'.format(model_name))

return model_obj

def verify(img1_path, img2_path = '', model_name = 'VGG-Face', distance_metric = 'cosine', model = None, enforce_detection = True, detector_backend = 'opencv', align = True):

Expand Down
7 changes: 4 additions & 3 deletions tests/unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,16 +186,17 @@
passed_tests = 0; test_cases = 0

for model in models:
prebuilt_model = DeepFace.build_model(model)
print(model," is built")
#prebuilt_model = DeepFace.build_model(model)
#print(model," is built")
for metric in metrics:
for instance in dataset:
img1 = instance[0]
img2 = instance[1]
result = instance[2]

resp_obj = DeepFace.verify(img1, img2
, model_name = model, model = prebuilt_model
, model_name = model
#, model = prebuilt_model
, distance_metric = metric)

prediction = resp_obj["verified"]
Expand Down

0 comments on commit e0809bf

Please sign in to comment.