Centralize cv2 image reading and handle bad filenames
torzdf committed Jun 2, 2019
1 parent a6a97f7 commit a329452
Showing 7 changed files with 60 additions and 32 deletions.
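The change applies the same pattern at every call site; a minimal before/after sketch with a hypothetical filename (the helper itself is added to lib/utils.py below):

import cv2
from lib.utils import cv2_read_img

filename = "faces/face_001.png"  # hypothetical path, not part of the commit

# Before: direct reads were scattered across modules, and cv2.imread silently
# returns None for missing files or filenames with special characters.
image = cv2.imread(filename)  # pylint: disable=no-member

# After: every call site goes through the shared helper, which logs the failure
# and either raises (raise_error=True) or returns None (raise_error=False).
image = cv2_read_img(filename, raise_error=True)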
7 changes: 3 additions & 4 deletions lib/face_filter.py
@@ -3,11 +3,10 @@

import logging

import cv2

from lib.faces_detect import DetectedFace
from lib.logger import get_loglevel
from lib.vgg_face import VGGFace
from lib.utils import cv2_read_img
from plugins.extract.pipeline import Extractor

logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@@ -41,10 +40,10 @@ def load_images(reference_file_paths, nreference_file_paths):
""" Load the images """
retval = dict()
for fpath in reference_file_paths:
retval[fpath] = {"image": cv2.imread(fpath), # pylint: disable=no-member
retval[fpath] = {"image": cv2_read_img(fpath, raise_error=True),
"type": "filter"}
for fpath in nreference_file_paths:
retval[fpath] = {"image": cv2.imread(fpath), # pylint: disable=no-member
retval[fpath] = {"image": cv2_read_img(fpath, raise_error=True),
"type": "nfilter"}
logger.debug("Loaded filter images: %s", {k: v["type"] for k, v in retval.items()})
return retval
7 changes: 2 additions & 5 deletions lib/training_data.py
@@ -14,6 +14,7 @@
from lib.multithreading import FixedProducerDispatcher
from lib.queue_manager import queue_manager
from lib.umeyama import umeyama
from lib.utils import cv2_read_img

logger = logging.getLogger(__name__) # pylint: disable=invalid-name

@@ -145,11 +146,7 @@ def process_face(self, filename, side, is_timelapse):
""" Load an image and perform transformation and warping """
logger.trace("Process face: (filename: '%s', side: '%s', is_timelapse: %s)",
filename, side, is_timelapse)
try:
image = cv2.imread(filename) # pylint: disable=no-member
except TypeError:
raise Exception("Error while reading image", filename)

image = cv2_read_img(filename, raise_error=True)
if self.mask_class or self.training_opts["warp_to_landmarks"]:
src_pts = self.get_landmarks(filename, image, side)
if self.mask_class:
37 changes: 36 additions & 1 deletion lib/utils.py
@@ -65,9 +65,44 @@ def get_image_paths(directory):
return dir_contents


def cv2_read_img(filename, raise_error=False):
""" Read an image with cv2 and check that an image was actually loaded.
Logs an error if the image returned is None. or an error has occured.
Pass raise_error=True if error should be raised """
logger.trace("Requested image: '%s'", filename)
success = True
image = None
try:
image = cv2.imread(filename) # pylint: disable=no-member
if image is None:
raise ValueError
except TypeError:
success = False
msg = "Error while reading image (TypeError): '{}'".format(filename)
logger.error(msg)
if raise_error:
raise Exception(msg)
except ValueError:
success = False
msg = ("Error while reading image. This is most likely caused by special characters in "
"the filename: '{}'".format(filename))
logger.error(msg)
if raise_error:
raise Exception(msg)
except Exception as err: # pylint: disable=broad-except
success = False
msg = "Failed to load image '{}'. Original Error: {}".format(filename, str(err))
logger.error(msg)
if raise_error:
raise Exception(msg)
logger.trace("Loaded image: '%s'. Success: %s", filename, success)
return image


def hash_image_file(filename):
""" Return an image file's sha1 hash """
img = cv2.imread(filename) # pylint: disable=no-member
img = cv2_read_img(filename, raise_error=True)
img_hash = sha1(img).hexdigest()
logger.trace("filename: '%s', hash: %s", filename, img_hash)
return img_hash
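For reference, a sketch of how a caller can rely on the new error handling; the path is hypothetical, and the error text in the comment is taken from the messages defined above:

from lib.utils import cv2_read_img

try:
    # With raise_error=True an unreadable file raises instead of silently
    # returning None, so a successful call always yields a valid ndarray.
    image = cv2_read_img("faces/face_001.png", raise_error=True)
except Exception as err:
    # e.g. "Error while reading image. This is most likely caused by special
    # characters in the filename: 'faces/face_001.png'"
    print(err)
else:
    print(image.shape)

With raise_error=False the helper instead returns None after logging, which is the mode load_disk_frames in scripts/fsmedia.py below uses to skip unreadable frames.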
11 changes: 4 additions & 7 deletions scripts/fsmedia.py
@@ -15,7 +15,7 @@
from lib.aligner import Extract as AlignerExtract
from lib.alignments import Alignments as AlignmentsBase
from lib.face_filter import FaceFilter as FilterFunc
from lib.utils import (camel_case_split, get_folder, get_image_paths,
from lib.utils import (camel_case_split, cv2_read_img, get_folder, get_image_paths,
set_system_verbosity, _video_extensions)

logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@@ -184,11 +184,8 @@ def load_disk_frames(self):
""" Load frames from disk """
logger.debug("Input is separate Frames. Loading images")
for filename in self.input_images:
logger.trace("Loading image: '%s'", filename)
try:
image = cv2.imread(filename) # pylint: disable=no-member
except Exception as err: # pylint: disable=broad-except
logger.error("Failed to load image '%s'. Original Error: %s", filename, err)
image = cv2_read_img(filename, raise_error=False)
if image is None:
continue
yield filename, image

@@ -221,7 +218,7 @@ def load_one_image(self, filename):
logger.trace("Extracted frame_no %s from filename '%s'", frame_no, filename)
retval = self.load_one_video_frame(int(frame_no))
else:
retval = cv2.imread(filename) # pylint: disable=no-member
retval = cv2_read_img(filename, raise_error=True)
return retval

def load_one_video_frame(self, frame_no):
4 changes: 2 additions & 2 deletions scripts/train.py
@@ -15,7 +15,7 @@
from lib.keypress import KBHit
from lib.multithreading import MultiThread
from lib.queue_manager import queue_manager
from lib.utils import (get_folder, get_image_paths, set_system_verbosity)
from lib.utils import cv2_read_img, get_folder, get_image_paths, set_system_verbosity
from plugins.plugin_loader import PluginLoader

logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@@ -170,7 +170,7 @@ def load_model(self):
@property
def image_size(self):
""" Get the training set image size for storing in model data """
image = cv2.imread(self.images["a"][0]) # pylint: disable=no-member
image = cv2_read_img(self.images["a"][0], raise_error=True)
size = image.shape[0]
logger.debug("Training image size: %s", size)
return size
5 changes: 3 additions & 2 deletions tools/lib_alignments/media.py
@@ -10,7 +10,8 @@

from lib.alignments import Alignments
from lib.faces_detect import DetectedFace
from lib.utils import _image_extensions, _video_extensions, hash_image_file, hash_encode_image
from lib.utils import (_image_extensions, _video_extensions, cv2_read_img, hash_image_file,
hash_encode_image)

logger = logging.getLogger(__name__) # pylint: disable=invalid-name

@@ -160,7 +161,7 @@ def load_image(self, filename):
else:
src = os.path.join(self.folder, filename)
logger.trace("Loading image: '%s'", src)
image = cv2.imread(src) # pylint: disable=no-member
image = cv2_read_img(src, raise_error=True)
return image

def load_video_frame(self, filename):
21 changes: 10 additions & 11 deletions tools/sort.py
@@ -18,6 +18,7 @@
from lib.faces_detect import DetectedFace
from lib.multithreading import SpawnProcess
from lib.queue_manager import queue_manager, QueueEmpty
from lib.utils import cv2_read_img
from lib.vgg_face import VGGFace
from plugins.plugin_loader import PluginLoader

@@ -131,7 +132,7 @@ def alignment_dict(image):
@staticmethod
def get_landmarks(filename):
""" Extract the face from a frame (If not alignments file found) """
image = cv2.imread(filename)
image = cv2_read_img(filename, raise_error=True)
queue_manager.get_queue("in").put(Sort.alignment_dict(image))
face = queue_manager.get_queue("out").get()
landmarks = face["landmarks"][0]
@@ -184,7 +185,7 @@ def sort_face(self):
logger.info("Sorting by face similarity...")

images = np.array(self.find_images(input_dir))
preds = np.array([self.vgg_face.predict(cv2.imread(img))
preds = np.array([self.vgg_face.predict(cv2_read_img(img, raise_error=True))
for img in tqdm(images, desc="loading", file=sys.stdout)])
logger.info("Sorting. Depending on ths size of your dataset, this may take a few "
"minutes...")
@@ -287,7 +288,7 @@ def sort_hist(self):
logger.info("Sorting by histogram similarity...")

img_list = [
[img, cv2.calcHist([cv2.imread(img)], [0], None, [256], [0, 256])]
[img, cv2.calcHist([cv2_read_img(img, raise_error=True)], [0], None, [256], [0, 256])]
for img in
tqdm(self.find_images(input_dir), desc="Loading", file=sys.stdout)
]
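The comparison step that consumes these histograms falls outside the hunk shown here; a minimal sketch of the idea, assuming cv2.compareHist with the Bhattacharyya distance (an assumption, not taken from this diff; the channel_histogram helper and filenames are illustrative):

import cv2
from lib.utils import cv2_read_img

def channel_histogram(filename):
    # Same calcHist call as above: 256-bin histogram of channel 0.
    image = cv2_read_img(filename, raise_error=True)
    return cv2.calcHist([image], [0], None, [256], [0, 256])

# A lower Bhattacharyya distance means the two histograms are more similar.
score = cv2.compareHist(channel_histogram("face_a.png"),
                        channel_histogram("face_b.png"),
                        cv2.HISTCMP_BHATTACHARYYA)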
@@ -317,7 +318,7 @@ def sort_hist_dissim(self):

img_list = [
[img,
cv2.calcHist([cv2.imread(img)], [0], None, [256], [0, 256]), 0]
cv2.calcHist([cv2_read_img(img, raise_error=True)], [0], None, [256], [0, 256]), 0]
for img in
tqdm(self.find_images(input_dir), desc="Loading", file=sys.stdout)
]
@@ -571,7 +572,7 @@ def reload_images(self, group_method, img_list):
input_dir = self.args.input_dir
logger.info("Preparing to group...")
if group_method == 'group_blur':
temp_list = [[img, self.estimate_blur(cv2.imread(img))]
temp_list = [[img, self.estimate_blur(cv2_read_img(img, raise_error=True))]
for img in
tqdm(self.find_images(input_dir),
desc="Reloading",
@@ -599,7 +600,7 @@ def reload_images(self, group_method, img_list):
elif group_method == 'group_hist':
temp_list = [
[img,
cv2.calcHist([cv2.imread(img)], [0], None, [256], [0, 256])]
cv2.calcHist([cv2_read_img(img, raise_error=True)], [0], None, [256], [0, 256])]
for img in
tqdm(self.find_images(input_dir),
desc="Reloading",
@@ -652,12 +653,10 @@ def find_images(input_dir):
@staticmethod
def estimate_blur(image_file):
"""
Estimate the amount of blur an image has
with the variance of the Laplacian.
Normalize by pixel number to offset the effect
of image size on pixel gradients & variance
Estimate the amount of blur an image has with the variance of the Laplacian.
Normalize by pixel number to offset the effect of image size on pixel gradients & variance
"""
image = cv2.imread(image_file)
image = cv2_read_img(image_file, raise_error=True)
if image.ndim == 3:
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
blur_map = cv2.Laplacian(image, cv2.CV_32F)
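The tail of estimate_blur falls outside the lines shown; a sketch of the scoring its docstring describes (variance of the Laplacian, normalized for image size), with the exact normalization being an assumption rather than taken from the diff:

import cv2
import numpy as np
from lib.utils import cv2_read_img

def estimate_blur_sketch(image_file):
    # Focus measure: variance of the Laplacian, scaled by pixel count so that
    # large and small images produce comparable scores (per the docstring).
    image = cv2_read_img(image_file, raise_error=True)
    if image.ndim == 3:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    blur_map = cv2.Laplacian(image, cv2.CV_32F)
    return np.var(blur_map) / (image.shape[0] * image.shape[1])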
