Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

model_refactor (#571) #572

Merged
merged 157 commits into from
Feb 9, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
157 commits
Select commit Hold shift + click to select a range
56526f8
model_refactor (#571)
torzdf Jan 2, 2019
f0a044f
Add mask to dfaker (#573)
torzdf Jan 2, 2019
5d21061
dfl mask. Make masks selectable in config (#575)
torzdf Jan 3, 2019
d68c001
remove gan_v2_2
torzdf Jan 3, 2019
860403a
Creating Input Size config for models
kvrooman Jan 3, 2019
c248d4b
Add mask loss options to config
torzdf Jan 3, 2019
24e0f9a
Merge branch 'train_refactor' of https://github.com/deepfakes/faceswa…
torzdf Jan 3, 2019
3a3b190
MTCNN options to config.ini. Remove GAN config. Update USAGE.md
torzdf Jan 4, 2019
e5b52d2
Add sliders for numerical values in GUI
torzdf Jan 4, 2019
a75281d
Add config plugins menu to gui. Validate config
torzdf Jan 8, 2019
fae1c92
Only backup model if loss has dropped. Get training working again
torzdf Jan 8, 2019
70e93d6
bugfixes
torzdf Jan 9, 2019
be6dbf0
Standardise loss printing
torzdf Jan 9, 2019
6acd3a0
GUI idle cpu fixes. Graph loss fix.
torzdf Jan 10, 2019
07f9da2
mutli-gpu logging bugfix
torzdf Jan 10, 2019
a5d9bb0
Merge branch 'staging' into train_refactor
torzdf Jan 10, 2019
15d412d
backup state file
torzdf Jan 11, 2019
2400fd5
Crash protection: Only backup if both total losses have dropped
torzdf Jan 11, 2019
3817f16
Port OriginalHiRes_RC4 to train_refactor (OriginalHiRes)
torzdf Jan 11, 2019
a54a39a
Merge branch 'staging' into train_refactor
torzdf Jan 11, 2019
e1596ce
Load and save model structure with weights
torzdf Jan 13, 2019
4b77330
Slight code update
torzdf Jan 13, 2019
4c8f4ba
Improve config loader. Add subpixel opt to all models. Config to state
torzdf Jan 13, 2019
973f09e
Show samples... wrong input
kvrooman Jan 13, 2019
f11ed11
Remove AE topology. Add input/output shapes to State
torzdf Jan 14, 2019
ffb97c0
Merge branch 'train_refactor' of https://github.com/deepfakes/faceswa…
torzdf Jan 14, 2019
4a5db0e
Merge branch 'staging' into train_refactor
torzdf Jan 14, 2019
1127324
Port original_villain (birb/VillainGuy) model to faceswap
torzdf Jan 15, 2019
d09a2db
Merge branch 'staging' into train_refactor
torzdf Jan 15, 2019
1514429
Add plugin info to GUI config pages
torzdf Jan 15, 2019
1d1061a
Load input shape from state. IAE Config options.
torzdf Jan 15, 2019
6846be4
Merge branch 'staging' into train_refactor
torzdf Jan 16, 2019
61d8147
Fix transform_kwargs.
torzdf Jan 16, 2019
7d18589
Suppress keras userwarnings.
torzdf Jan 16, 2019
2052a2f
Consolidation of converters & refactor (#574)
kvrooman Jan 16, 2019
6f11519
Backwards compatibility fix for models
torzdf Jan 16, 2019
f2e0089
Convert:
torzdf Jan 17, 2019
277007b
mask fix
kvrooman Jan 17, 2019
b3b0269
convert fixes
kvrooman Jan 17, 2019
45ef5df
Update cli.py
kvrooman Jan 17, 2019
5e8a7cd
default for blur
kvrooman Jan 17, 2019
c55f5ce
Update masked.py
kvrooman Jan 17, 2019
0d98cda
added preliminary low_mem version of OriginalHighRes model plugin
Jan 17, 2019
dde403f
Code cleanup, minor fixes
torzdf Jan 17, 2019
352eec9
Update masked.py
kvrooman Jan 17, 2019
14064be
Update masked.py
kvrooman Jan 17, 2019
9bbdf9c
Add dfl mask to convert
torzdf Jan 17, 2019
4e36051
Fix merge conflict
torzdf Jan 17, 2019
fa95e70
Merge branch 'staging' into train_refactor
torzdf Jan 18, 2019
794ce14
histogram fix & seamless location
kvrooman Jan 18, 2019
5c5bb13
update
kvrooman Jan 18, 2019
d194ad5
revert
kvrooman Jan 18, 2019
dcf3bbe
bugfix: Load actual configuration in gui
torzdf Jan 18, 2019
4bef884
Merge branch 'train_refactor' of https://github.com/deepfakes/faceswa…
torzdf Jan 18, 2019
0baafd1
Standardize nn_blocks
torzdf Jan 18, 2019
8ed9496
Merge branch 'staging' into train_refactor
torzdf Jan 18, 2019
9364728
Update cli.py
kvrooman Jan 18, 2019
f83b1f3
Minor code amends
torzdf Jan 18, 2019
5097f67
Merge branch 'train_refactor' of https://github.com/deepfakes/faceswa…
torzdf Jan 18, 2019
f965423
Fix Original HiRes model
torzdf Jan 21, 2019
9df5a36
Add masks to preview output for mask trainers
torzdf Jan 21, 2019
d6b0778
Masked trainers converter support
torzdf Jan 21, 2019
0a3397d
convert bugfix
torzdf Jan 21, 2019
2a4a7a0
Bugfix: Converter for masked (dfl/dfaker) trainers
torzdf Jan 21, 2019
32bf1a0
Additional Losses (#592)
kvrooman Jan 21, 2019
34e8fee
default initializer = He instead of Glorot (#588)
kvrooman Jan 21, 2019
1115953
Allow kernel_initializer to be overridable
torzdf Jan 21, 2019
7b8335c
Add ICNR Initializer option for upscale on all models.
torzdf Jan 22, 2019
8b554fc
Hopefully fixes RSoDs with original-highres model plugin
Jan 22, 2019
31d1105
Merge branch 'train_refactor' of https://github.com/deepfakes/faceswa…
Jan 22, 2019
1df779d
remove debug line
torzdf Jan 22, 2019
18003e1
Merge branch 'train_refactor' of https://github.com/deepfakes/faceswa…
torzdf Jan 22, 2019
ccdb09d
Original-HighRes model plugin Red Screen of Death fix, take #2
Jan 23, 2019
a06665d
Move global options to _base. Rename Villain model
torzdf Jan 23, 2019
268ccf2
clipnorm and res block biases
kvrooman Jan 23, 2019
62f2b6f
scale the end of res block
kvrooman Jan 23, 2019
c9d6698
res block
kvrooman Jan 23, 2019
5bb2ef5
dfaker pre-activation res
kvrooman Jan 23, 2019
5d3815f
OHRES pre-activation
kvrooman Jan 23, 2019
b9e6040
villain pre-activation
kvrooman Jan 23, 2019
86fba3a
tabs/space in nn_blocks
kvrooman Jan 24, 2019
a5f7311
fix for histogram with mask all set to zero
kvrooman Jan 24, 2019
46fa813
fix to prevent two networks with same name
kvrooman Jan 24, 2019
179fdc7
Merge branch 'staging' into train_refactor
torzdf Jan 24, 2019
a96c588
GUI: Wider tooltips. Improve TQDM capture
torzdf Jan 24, 2019
e50e525
Fix regex bug
torzdf Jan 24, 2019
3b95517
Convert padding=48 to ratio of image size
torzdf Jan 24, 2019
4b2ee6f
Add size option to alignments tool extract
torzdf Jan 25, 2019
da74ccd
Pass through training image size to convert from model
torzdf Jan 25, 2019
a4d2653
Convert: Pull training coverage from model
torzdf Jan 25, 2019
3acfcab
convert: coverage, blur and erode to percent
torzdf Jan 25, 2019
6ab46ed
simplify matrix scaling
kvrooman Jan 26, 2019
ce84cf9
ordering of sliders in train
kvrooman Jan 26, 2019
92725d3
Add matrix scaling to utils. Use interpolation in lib.aligner transform
torzdf Jan 26, 2019
75867c8
masked.py Import get_matrix_scaling from utils
torzdf Jan 26, 2019
ea26992
fix circular import
torzdf Jan 26, 2019
3b63ecb
Update masked.py
kvrooman Jan 26, 2019
cff5ae2
quick fix for matrix scaling
kvrooman Jan 26, 2019
e290c09
testing thus for now
kvrooman Jan 26, 2019
dff452a
Merge branch 'staging' into train_refactor
torzdf Jan 26, 2019
6e2a91e
tqdm regex capture bugfix
torzdf Jan 26, 2019
873b0e5
Merge branch 'train_refactor' of https://github.com/deepfakes/faceswa…
torzdf Jan 26, 2019
8cc9249
Minor ammends
torzdf Jan 26, 2019
39692bc
blur size cleanup
kvrooman Jan 26, 2019
4afb6a0
Remove coverage option from convert (Now cascades from model)
torzdf Jan 26, 2019
ed9382a
Merge branch 'train_refactor' of https://github.com/deepfakes/faceswa…
torzdf Jan 26, 2019
161e276
Implement convert for all model types
torzdf Jan 26, 2019
8cb4fcf
Add mask option and coverage option to all existing models
torzdf Jan 27, 2019
316a6f9
bugfix for model loading on convert
torzdf Jan 27, 2019
ab053a4
debug print removal
kvrooman Jan 27, 2019
53ea7dc
Bugfix for masks in dfl_h128 and iae
torzdf Jan 28, 2019
d7fa42c
Merge branch 'train_refactor' of https://github.com/deepfakes/faceswa…
torzdf Jan 28, 2019
6d70c89
Update preview display. Add preview scaling to cli
torzdf Jan 29, 2019
77366a5
mask notes
kvrooman Jan 29, 2019
11011b3
Delete training_data_v2.py
kvrooman Jan 29, 2019
8f2bf88
training data variables
kvrooman Jan 29, 2019
d691b05
Fix timelapse function
torzdf Jan 29, 2019
7182805
Add new config items to state file for legacy purposes
torzdf Jan 29, 2019
9fcf81b
Slight GUI tweak
torzdf Jan 29, 2019
2ddf864
Raise exception if problem with loaded model
torzdf Jan 29, 2019
6cc5852
Add Tensorboard support (Logs stored in model directory)
torzdf Jan 29, 2019
3ac4b81
ICNR fix
kvrooman Jan 29, 2019
b3757a8
loss bugfix
torzdf Jan 30, 2019
27fa627
convert bugfix
torzdf Jan 30, 2019
806189f
Move ini files to config folder. Make TensorBoard optional
torzdf Jan 30, 2019
b2011c6
Fix training data for unbalanced inputs/outputs
torzdf Jan 31, 2019
16a53bb
Fix config "none" test
torzdf Jan 31, 2019
732bd54
Keep helptext in .ini files when saving config from GUI
torzdf Jan 31, 2019
61bc67d
Remove frame_dims from alignments
torzdf Jan 31, 2019
a8ba677
Add no-flip and warp-to-landmarks cli options
torzdf Feb 1, 2019
19a05b0
Revert OHR to RC4_fix version
torzdf Feb 1, 2019
7adc197
Fix lowmem mode on OHR model
torzdf Feb 1, 2019
83c2a3b
padding to variable
kvrooman Feb 2, 2019
e1006e8
Save models in parallel threads
torzdf Feb 3, 2019
9357926
Merge branch 'train_refactor' of https://github.com/deepfakes/faceswa…
torzdf Feb 3, 2019
b08b920
Speed-up of res_block stability
kvrooman Feb 3, 2019
568cce6
Automated Reflection Padding
kvrooman Feb 3, 2019
d37ab8f
Reflect Padding as a training option
kvrooman Feb 3, 2019
7d48ab5
rest of reflect padding
kvrooman Feb 4, 2019
6597492
Move TB logging to cli. Session info to state file
torzdf Feb 4, 2019
40b8226
Merge branch 'train_refactor' of https://github.com/deepfakes/faceswa…
torzdf Feb 4, 2019
0363a5f
Add session iterations to state file
torzdf Feb 5, 2019
4da4142
Add recent files to menu. GUI code tidy up
torzdf Feb 6, 2019
fb631a5
[GUI] Fix recent file list update issue
torzdf Feb 6, 2019
42608be
Add correct loss names to TensorBoard logs
torzdf Feb 6, 2019
d4efa15
Update live graph to use TensorBoard and remove animation
torzdf Feb 7, 2019
61c7eda
Fix analysis tab. GUI optimizations
torzdf Feb 7, 2019
3af6e60
Analysis Graph popup to Tensorboard Logs
torzdf Feb 8, 2019
e929754
Merge branch 'staging' into train_refactor
torzdf Feb 8, 2019
af49107
[GUI] Bug fix for graphing for models with hypens in name
torzdf Feb 8, 2019
c32be6e
[GUI] Correctly split loss to tabs during training
torzdf Feb 8, 2019
9303491
[GUI] Add loss type selection to analysis graph
torzdf Feb 8, 2019
60b58da
Fix store command name in recent files. Switch to correct tab on open
torzdf Feb 9, 2019
21d2834
[GUI] Disable training graph when 'no-logs' is selected
torzdf Feb 9, 2019
8707936
Merge branch 'master' into train_refactor
torzdf Feb 9, 2019
a768d6b
Fix graphing race condition
torzdf Feb 9, 2019
8bcf59f
rename original_hires model to unbalanced
torzdf Feb 9, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
dfl mask. Make masks selectable in config (#575)
* DFL H128 Mask. Mask type selectable in config.
  • Loading branch information
torzdf authored Jan 3, 2019
commit 5d21061726f8f571774e2ea1b9eed82761ec026d
1 change: 1 addition & 0 deletions lib/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ def validate_config(self):
opt["helptext"],
new_config)
self.config = new_config
self.config.optionxform = str
self.save_config()
logger.debug("Updated config")

Expand Down
38 changes: 36 additions & 2 deletions lib/model/masks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,22 @@
Masks from:
dfaker: https://github.com/dfaker/df"""

import logging

import cv2
import numpy as np

from lib.umeyama import umeyama
from lib.aligner import LANDMARKS_2D

logger = logging.getLogger(__name__) # pylint: disable=invalid-name


def dfaker_mask(landmarks, coverage, face):
""" Dfaker model mask """
def dfaker(landmarks, face, **kwargs):
""" Dfaker model mask
Embeds the mask into the face alpha channel """
coverage = kwargs["coverage"]
logger.trace("face_shape: %s, coverage: %s, landmarks: %s", face.shape, coverage, landmarks)
size = face.shape[0] - 1

mat = umeyama(landmarks[17:], LANDMARKS_2D, True)[0:2]
Expand All @@ -37,4 +44,31 @@ def dfaker_mask(landmarks, coverage, face):
mask = mask[:, :, 0]
merged = np.dstack([face, mask]).astype(np.uint8)

logger.trace("Returning: face_shape: %s", merged.shape)
return merged


def dfl_full(landmarks, face, **kwargs):
""" DFL Face Full Mask """
logger.trace("face_shape: %s, landmarks: %s", face.shape, landmarks)
hull_mask = np.zeros(face.shape[0:2] + (1, ), dtype=np.float32)
hull1 = cv2.convexHull(np.concatenate((landmarks[0:17], # pylint: disable=no-member
landmarks[48:],
[landmarks[0]],
[landmarks[8]],
[landmarks[16]])))
hull2 = cv2.convexHull(np.concatenate((landmarks[27:31], # pylint: disable=no-member
[landmarks[33]])))
hull3 = cv2.convexHull(np.concatenate((landmarks[17:27], # pylint: disable=no-member
[landmarks[0]],
[landmarks[27]],
[landmarks[16]],
[landmarks[33]])))

cv2.fillConvexPoly(hull_mask, hull1, (1, )) # pylint: disable=no-member
cv2.fillConvexPoly(hull_mask, hull2, (1, )) # pylint: disable=no-member
cv2.fillConvexPoly(hull_mask, hull3, (1, )) # pylint: disable=no-member

face = np.concatenate((face, hull_mask), -1)
logger.trace("Returning: face_shape: %s", face.shape)
return face
121 changes: 81 additions & 40 deletions lib/training_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,16 @@
import numpy as np
from scipy.interpolate import griddata

from lib.model import masks
from lib.multithreading import MultiThread
from lib.queue_manager import queue_manager
from lib.umeyama import umeyama
from lib.model.masks import dfaker_mask

logger = logging.getLogger(__name__) # pylint: disable=invalid-name

# TODO Add source, dest points to random warp half and ability to
# not have landmarks to random warp full


class TrainingDataGenerator():
""" Generate training data for models """
Expand All @@ -27,9 +30,22 @@ def __init__(self, transform_kwargs, training_opts=dict()):
{key: val for key, val in training_opts.items() if key != "landmarks"})
self.batchsize = 0
self.training_opts = training_opts
self.full_face = self.training_opts.get("full_face", False)
self.mask_function = self.set_mask_function()
self.processing = ImageManipulation(**transform_kwargs)
logger.debug("Initialized %s", self.__class__.__name__)

def set_mask_function(self):
""" Set the mask function to use if using mask """
mask_type = self.training_opts.get("mask_type", None)
if mask_type:
logger.debug("Mask type: '%s'", mask_type)
mask_func = getattr(masks, mask_type)
else:
mask_func = None
logger.debug("Mask function: %s", mask_func)
return mask_func

def minibatch_ab(self, images, batchsize, side, do_shuffle=True):
""" Keep a queue filled to 8x Batch Size """
logger.debug("Queue batches: (image_count: %s, batchsize: %s, side: '%s', do_shuffle: %s)",
Expand Down Expand Up @@ -94,32 +110,29 @@ def minibatch(self, q_name, load_thread):

def process_face(self, filename, side):
""" Load an image and perform transformation and warping """
logger.trace("Process face: (filename: '%s', side: '%s'", filename, side)
logger.trace("Process face: (filename: '%s', side: '%s')", filename, side)
try:
image = cv2.imread(filename) # pylint: disable=no-member
except TypeError:
raise Exception("Error while reading image", filename)

landmarks = self.training_opts.get("landmarks", None)
src_pts = self.get_landmarks(filename,
image,
side,
landmarks) if landmarks else None

if self.training_opts.get("use_mask", False):
image = dfaker_mask(src_pts, self.processing.coverage, image)
if self.mask_function:
landmarks = self.training_opts["landmarks"]
src_pts = self.get_landmarks(filename, image, side, landmarks)
image = self.mask_function(src_pts, image, coverage=self.processing.coverage)

image = self.processing.color_adjust(image)
image = self.processing.random_transform(image)

if landmarks:
if self.full_face:
dst_pts = self.get_closest_match(filename, side, landmarks, src_pts)
processed = self.processing.random_warp_landmarks(image, src_pts, dst_pts)
processed = self.processing.random_warp_full_face(image, src_pts, dst_pts)
else:
processed = self.processing.random_warp(image)

retval = self.processing.do_random_flip(processed)
logger.trace("Processed face: (filename: '%s', side: '%s'", filename, side)
logger.trace("Processed face: (filename: '%s', side: '%s', shapes: %s)",
filename, side, [img.shape for img in retval])
return retval

@staticmethod
Expand Down Expand Up @@ -153,7 +166,7 @@ def get_closest_match(filename, side, landmarks, src_points):
class ImageManipulation():
""" Manipulations to be performed on training images """
def __init__(self, rotation_range=10, zoom_range=0.05, shift_range=0.05, random_flip=0.4,
zoom=1, coverage=160, scale=5):
zoom=(1, 1), coverage=160, scale=5):
""" rotation_range: Used for random transform
zoom_range: Used for random transform
shift_range: Used for random transform
Expand Down Expand Up @@ -184,6 +197,16 @@ def color_adjust(img):
logger.trace("Color adjusting image")
return img / 255.0

@staticmethod
def seperate_mask(image):
""" Return the image and the mask from a 4 channel image """
mask = None
if image.shape[2] == 4:
logger.trace("Image contains mask")
mask = image[:, :, 3].reshape((image.shape[0], image.shape[1], 1))
image = image[:, :, :3]
return image, mask

def random_transform(self, image):
""" Randomly transform an image """
logger.trace("Randomly transforming image")
Expand All @@ -204,9 +227,10 @@ def random_transform(self, image):
logger.trace("Randomly transformed image")
return result

def random_warp(self, image):
def random_warp(self, image, src_points=None, dst_points=None):
""" get pair of random warped images from aligned face image """
logger.trace("Randomly warping image")
image, mask = self.seperate_mask(image)
height, width = image.shape[0:2]
assert height == width and height % 2 == 0

Expand All @@ -219,31 +243,45 @@ def random_warp(self, image):
mapy = mapy + np.random.normal(size=(self.scale, self.scale), scale=self.scale)

interp_mapx = cv2.resize( # pylint: disable=no-member
mapx, (80 * self.zoom, 80 * self.zoom))[8 * self.zoom:72 * self.zoom,
8 * self.zoom:72 * self.zoom].astype('float32')
mapx, (80 * self.zoom[0],
80 * self.zoom[0]))[8 * self.zoom[0]:72 * self.zoom[0],
8 * self.zoom[0]:72 * self.zoom[0]].astype('float32')
interp_mapy = cv2.resize( # pylint: disable=no-member
mapy, (80 * self.zoom, 80 * self.zoom))[8 * self.zoom:72 * self.zoom,
8 * self.zoom:72 * self.zoom].astype('float32')
mapy, (80 * self.zoom[0],
80 * self.zoom[0]))[8 * self.zoom[0]:72 * self.zoom[0],
8 * self.zoom[0]:72 * self.zoom[0]].astype('float32')

warped_image = cv2.remap( # pylint: disable=no-member
image, interp_mapx, interp_mapy, cv2.INTER_LINEAR) # pylint: disable=no-member
logger.trace("Warped image shape: %s", warped_image.shape)

src_points = np.stack([mapx.ravel(), mapy.ravel()], axis=-1)
dst_points = np.mgrid[0:65 * self.zoom:16 * self.zoom,
0:65 * self.zoom:16 * self.zoom].T.reshape(-1, 2)
dst_points = np.mgrid[0:65 * self.zoom[0]:16 * self.zoom[0],
0:65 * self.zoom[0]:16 * self.zoom[0]].T.reshape(-1, 2)

mat = umeyama(src_points, dst_points, True)[0:2]
target_image = cv2.warpAffine( # pylint: disable=no-member
image, mat, (64 * self.zoom, 64 * self.zoom))
image, mat, (64 * self.zoom[1], 64 * self.zoom[1]))
logger.trace("Target image shape: %s", target_image.shape)

retval = [warped_image, target_image]

if mask is not None:
target_mask = cv2.warpAffine( # pylint: disable=no-member
mask, mat, (64 * self.zoom[1], 64 * self.zoom[1]))
target_mask = target_mask.reshape((64 * self.zoom[1], 64 * self.zoom[1], 1))
logger.trace("Target mask shape: %s", target_mask.shape)

retval.append(target_mask)

retval = warped_image, target_image
logger.trace("Randomly warped image")
return retval

def random_warp_landmarks(self, image, src_points, dst_points):
def random_warp_full_face(self, image, src_points=None, dst_points=None):
""" get warped image, target image and target mask
From DFAKER plugin """
logger.trace("Randomly warping landmarks")
image, mask = self.seperate_mask(image)
size = image.shape[0]
p_mx = size - 1
p_hf = (size // 2) - 1
Expand Down Expand Up @@ -286,36 +324,39 @@ def random_warp_landmarks(self, image, src_points, dst_points):
map_x_32 = map_x.astype('float32')
map_y_32 = map_y.astype('float32')

warped_image = cv2.remap(image[:, :, :3], # pylint: disable=no-member
warped_image = cv2.remap(image, # pylint: disable=no-member
map_x_32,
map_y_32,
cv2.INTER_LINEAR, # pylint: disable=no-member
cv2.BORDER_TRANSPARENT) # pylint: disable=no-member
target_image = image[:, :, :3]
target_mask = None
if image.shape[2] == 4:
target_mask = image[:, :, 3].reshape((image.shape[0], image.shape[1], 1))
target_image = image

pad_lt = (64 * self.zoom) - (60 * self.zoom)
pad_rb = (64 * self.zoom) + (60 * self.zoom)
pad_lt = size // 32 # 8px on a 256px image
pad_rb = size - pad_lt

warped_image = cv2.resize( # pylint: disable=no-member
warped_image[pad_lt:pad_rb, pad_lt:pad_rb, :],
(64, 64),
(64 * self.zoom[0], 64 * self.zoom[0]),
cv2.INTER_AREA) # pylint: disable=no-member
logger.trace("Warped image shape: %s", warped_image.shape)
target_image = cv2.resize( # pylint: disable=no-member
target_image[pad_lt:pad_rb, pad_lt:pad_rb, :],
(64 * self.zoom, 64 * self.zoom),
(64 * self.zoom[1], 64 * self.zoom[1]),
cv2.INTER_AREA) # pylint: disable=no-member
if target_mask is None:
retval = warped_image, target_image
else:
logger.trace("Target image shape: %s", target_image.shape)

retval = [warped_image, target_image]

if mask is not None:
target_mask = cv2.resize( # pylint: disable=no-member
target_mask[pad_lt:pad_rb, pad_lt:pad_rb, :],
(64 * self.zoom, 64 * self.zoom),
mask[pad_lt:pad_rb, pad_lt:pad_rb, :],
(64 * self.zoom[1], 64 * self.zoom[1]),
cv2.INTER_AREA) # pylint: disable=no-member
target_mask = target_mask.reshape((64 * self.zoom, 64 * self.zoom, 1))
retval = warped_image, target_image, target_mask
target_mask = target_mask.reshape((64 * self.zoom[1], 64 * self.zoom[1], 1))
logger.trace("Target mask shape: %s", target_mask.shape)

retval.append(target_mask)

logger.trace("Randomly warped image")
return retval

Expand Down
15 changes: 15 additions & 0 deletions plugins/train/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ def set_defaults(self):
"be named 'alignments.<file extension>' (eg. "
"alignments.json)."
"\nChoose from: 'json', 'pickle' or 'yaml'")
self.add_item(
section=section, title="mask_type", datatype=str, default="dfaker",
info="The mask to be used for training."
"\nChoose from: 'dfaker' or 'dfl_full'")

# << DFL MODEL OPTIONS >> #
section = "dfl_h128"
Expand All @@ -92,6 +96,17 @@ def set_defaults(self):
info="Lower memory mode. Set to 'True' if having issues with VRAM useage.\nNB: Models "
"with a changed lowmem mode are not compatible with each other."
"\nChoose from: True, False")
self.add_item(
section=section, title="alignments_format", datatype=str, default="json",
info="DFL-H128 model requires the alignments for your training "
"images to be avalaible within the FACES folder.\nIt should "
"be named 'alignments.<file extension>' (eg. "
"alignments.json)."
"\nChoose from: 'json', 'pickle' or 'yaml'")
self.add_item(
section=section, title="mask_type", datatype=str, default="dfl_full",
info="The mask to be used for training."
"\nChoose from: 'dfaker' or 'dfl_full'")

# << GAN MODEL OPTIONS >> #
section = "gan_v2_2"
Expand Down
3 changes: 2 additions & 1 deletion plugins/train/model/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import sys

from json import JSONDecodeError
from keras import losses
from keras.optimizers import Adam
from keras.utils import multi_gpu_model

Expand Down Expand Up @@ -125,7 +126,7 @@ def loss_function(self):
if self.config["dssim_loss"]:
loss_func = DSSIMObjective()
else:
loss_func = "mean_absolute_error"
loss_func = losses.mean_absolute_error
logger.debug(loss_func)
return loss_func

Expand Down
7 changes: 3 additions & 4 deletions plugins/train/model/dfaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,9 @@ def set_training_data(self):
""" Set the dictionary for training """
logger.debug("Setting training data")
training_opts = dict()
serializer = self.config["alignments_format"]
training_opts["serializer"] = serializer
training_opts["use_mask"] = True
training_opts["use_alignments"] = True
training_opts["serializer"] = self.config["alignments_format"]
training_opts["mask_type"] = self.config["mask_type"]
training_opts["full_face"] = True
training_opts["preview_images"] = 10
logger.debug("Set training data: %s", training_opts)
return training_opts
Expand Down
4 changes: 1 addition & 3 deletions plugins/train/model/dfl_h128.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from .original import get_config, logger, Model as OriginalModel

# TODO Implement DFL loss function (currently using dfaker)
# TODO Get Mask working


class Model(OriginalModel):
Expand All @@ -31,8 +30,7 @@ def set_training_data(self):
""" Set the dictionary for training """
logger.debug("Setting training data")
training_opts = dict()
training_opts["use_mask"] = True
training_opts["remove_alpha"] = True
training_opts["mask_type"] = self.config["mask_type"]
training_opts["preview_images"] = 10
logger.debug("Set training data: %s", training_opts)
return training_opts
Expand Down
1 change: 0 additions & 1 deletion plugins/train/model/original.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def initialize(self):
inp,
self.networks["decoder_b"].network(self.networks["encoder"].network(inp)))
self.add_predictors(ae_a, ae_b)

logger.debug("Initialized model")

def encoder(self):
Expand Down
Loading