diff --git a/Face_Detection/align_warp_back_multiple_dlib_HR.py b/Face_Detection/align_warp_back_multiple_dlib_HR.py new file mode 100644 index 00000000..f3711c96 --- /dev/null +++ b/Face_Detection/align_warp_back_multiple_dlib_HR.py @@ -0,0 +1,437 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +import numpy as np +import skimage.io as io + +# from face_sdk import FaceDetection +import matplotlib.pyplot as plt +from matplotlib.patches import Rectangle +from skimage.transform import SimilarityTransform +from skimage.transform import warp +from PIL import Image, ImageFilter +import torch.nn.functional as F +import torchvision as tv +import torchvision.utils as vutils +import time +import cv2 +import os +from skimage import img_as_ubyte +import json +import argparse +import dlib + + +def calculate_cdf(histogram): + """ + This method calculates the cumulative distribution function + :param array histogram: The values of the histogram + :return: normalized_cdf: The normalized cumulative distribution function + :rtype: array + """ + # Get the cumulative sum of the elements + cdf = histogram.cumsum() + + # Normalize the cdf + normalized_cdf = cdf / float(cdf.max()) + + return normalized_cdf + + +def calculate_lookup(src_cdf, ref_cdf): + """ + This method creates the lookup table + :param array src_cdf: The cdf for the source image + :param array ref_cdf: The cdf for the reference image + :return: lookup_table: The lookup table + :rtype: array + """ + lookup_table = np.zeros(256) + lookup_val = 0 + for src_pixel_val in range(len(src_cdf)): + lookup_val + for ref_pixel_val in range(len(ref_cdf)): + if ref_cdf[ref_pixel_val] >= src_cdf[src_pixel_val]: + lookup_val = ref_pixel_val + break + lookup_table[src_pixel_val] = lookup_val + return lookup_table + + +def match_histograms(src_image, ref_image): + """ + This method matches the source image histogram to the + reference signal + :param image src_image: The original source image + :param image ref_image: The reference image + :return: image_after_matching + :rtype: image (array) + """ + # Split the images into the different color channels + # b means blue, g means green and r means red + src_b, src_g, src_r = cv2.split(src_image) + ref_b, ref_g, ref_r = cv2.split(ref_image) + + # Compute the b, g, and r histograms separately + # The flatten() Numpy method returns a copy of the array c + # collapsed into one dimension. + src_hist_blue, bin_0 = np.histogram(src_b.flatten(), 256, [0, 256]) + src_hist_green, bin_1 = np.histogram(src_g.flatten(), 256, [0, 256]) + src_hist_red, bin_2 = np.histogram(src_r.flatten(), 256, [0, 256]) + ref_hist_blue, bin_3 = np.histogram(ref_b.flatten(), 256, [0, 256]) + ref_hist_green, bin_4 = np.histogram(ref_g.flatten(), 256, [0, 256]) + ref_hist_red, bin_5 = np.histogram(ref_r.flatten(), 256, [0, 256]) + + # Compute the normalized cdf for the source and reference image + src_cdf_blue = calculate_cdf(src_hist_blue) + src_cdf_green = calculate_cdf(src_hist_green) + src_cdf_red = calculate_cdf(src_hist_red) + ref_cdf_blue = calculate_cdf(ref_hist_blue) + ref_cdf_green = calculate_cdf(ref_hist_green) + ref_cdf_red = calculate_cdf(ref_hist_red) + + # Make a separate lookup table for each color + blue_lookup_table = calculate_lookup(src_cdf_blue, ref_cdf_blue) + green_lookup_table = calculate_lookup(src_cdf_green, ref_cdf_green) + red_lookup_table = calculate_lookup(src_cdf_red, ref_cdf_red) + + # Use the lookup function to transform the colors of the original + # source image + blue_after_transform = cv2.LUT(src_b, blue_lookup_table) + green_after_transform = cv2.LUT(src_g, green_lookup_table) + red_after_transform = cv2.LUT(src_r, red_lookup_table) + + # Put the image back together + image_after_matching = cv2.merge([blue_after_transform, green_after_transform, red_after_transform]) + image_after_matching = cv2.convertScaleAbs(image_after_matching) + + return image_after_matching + + +def _standard_face_pts(): + pts = ( + np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32) / 256.0 + - 1.0 + ) + + return np.reshape(pts, (5, 2)) + + +def _origin_face_pts(): + pts = np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32) + + return np.reshape(pts, (5, 2)) + + +def compute_transformation_matrix(img, landmark, normalize, target_face_scale=1.0): + + std_pts = _standard_face_pts() # [-1,1] + target_pts = (std_pts * target_face_scale + 1) / 2 * 512.0 + + # print(target_pts) + + h, w, c = img.shape + if normalize == True: + landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0 + landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0 + + # print(landmark) + + affine = SimilarityTransform() + + affine.estimate(target_pts, landmark) + + return affine + + +def compute_inverse_transformation_matrix(img, landmark, normalize, target_face_scale=1.0): + + std_pts = _standard_face_pts() # [-1,1] + target_pts = (std_pts * target_face_scale + 1) / 2 * 512.0 + + # print(target_pts) + + h, w, c = img.shape + if normalize == True: + landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0 + landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0 + + # print(landmark) + + affine = SimilarityTransform() + + affine.estimate(landmark, target_pts) + + return affine + + +def show_detection(image, box, landmark): + plt.imshow(image) + print(box[2] - box[0]) + plt.gca().add_patch( + Rectangle( + (box[1], box[0]), box[2] - box[0], box[3] - box[1], linewidth=1, edgecolor="r", facecolor="none" + ) + ) + plt.scatter(landmark[0][0], landmark[0][1]) + plt.scatter(landmark[1][0], landmark[1][1]) + plt.scatter(landmark[2][0], landmark[2][1]) + plt.scatter(landmark[3][0], landmark[3][1]) + plt.scatter(landmark[4][0], landmark[4][1]) + plt.show() + + +def affine2theta(affine, input_w, input_h, target_w, target_h): + # param = np.linalg.inv(affine) + param = affine + theta = np.zeros([2, 3]) + theta[0, 0] = param[0, 0] * input_h / target_h + theta[0, 1] = param[0, 1] * input_w / target_h + theta[0, 2] = (2 * param[0, 2] + param[0, 0] * input_h + param[0, 1] * input_w) / target_h - 1 + theta[1, 0] = param[1, 0] * input_h / target_w + theta[1, 1] = param[1, 1] * input_w / target_w + theta[1, 2] = (2 * param[1, 2] + param[1, 0] * input_h + param[1, 1] * input_w) / target_w - 1 + return theta + + +def blur_blending(im1, im2, mask): + + mask *= 255.0 + + kernel = np.ones((10, 10), np.uint8) + mask = cv2.erode(mask, kernel, iterations=1) + + mask = Image.fromarray(mask.astype("uint8")).convert("L") + im1 = Image.fromarray(im1.astype("uint8")) + im2 = Image.fromarray(im2.astype("uint8")) + + mask_blur = mask.filter(ImageFilter.GaussianBlur(20)) + im = Image.composite(im1, im2, mask) + + im = Image.composite(im, im2, mask_blur) + + return np.array(im) / 255.0 + + +def blur_blending_cv2(im1, im2, mask): + + mask *= 255.0 + + kernel = np.ones((9, 9), np.uint8) + mask = cv2.erode(mask, kernel, iterations=3) + + mask_blur = cv2.GaussianBlur(mask, (25, 25), 0) + mask_blur /= 255.0 + + im = im1 * mask_blur + (1 - mask_blur) * im2 + + im /= 255.0 + im = np.clip(im, 0.0, 1.0) + + return im + + +# def Poisson_blending(im1,im2,mask): + + +# Image.composite( +def Poisson_blending(im1, im2, mask): + + # mask=1-mask + mask *= 255 + kernel = np.ones((10, 10), np.uint8) + mask = cv2.erode(mask, kernel, iterations=1) + mask /= 255 + mask = 1 - mask + mask *= 255 + + mask = mask[:, :, 0] + width, height, channels = im1.shape + center = (int(height / 2), int(width / 2)) + result = cv2.seamlessClone( + im2.astype("uint8"), im1.astype("uint8"), mask.astype("uint8"), center, cv2.MIXED_CLONE + ) + + return result / 255.0 + + +def Poisson_B(im1, im2, mask, center): + + mask *= 255 + + result = cv2.seamlessClone( + im2.astype("uint8"), im1.astype("uint8"), mask.astype("uint8"), center, cv2.NORMAL_CLONE + ) + + return result / 255 + + +def seamless_clone(old_face, new_face, raw_mask): + + height, width, _ = old_face.shape + height = height // 2 + width = width // 2 + + y_indices, x_indices, _ = np.nonzero(raw_mask) + y_crop = slice(np.min(y_indices), np.max(y_indices)) + x_crop = slice(np.min(x_indices), np.max(x_indices)) + y_center = int(np.rint((np.max(y_indices) + np.min(y_indices)) / 2 + height)) + x_center = int(np.rint((np.max(x_indices) + np.min(x_indices)) / 2 + width)) + + insertion = np.rint(new_face[y_crop, x_crop] * 255.0).astype("uint8") + insertion_mask = np.rint(raw_mask[y_crop, x_crop] * 255.0).astype("uint8") + insertion_mask[insertion_mask != 0] = 255 + prior = np.rint(np.pad(old_face * 255.0, ((height, height), (width, width), (0, 0)), "constant")).astype( + "uint8" + ) + # if np.sum(insertion_mask) == 0: + n_mask = insertion_mask[1:-1, 1:-1, :] + n_mask = cv2.copyMakeBorder(n_mask, 1, 1, 1, 1, cv2.BORDER_CONSTANT, 0) + print(n_mask.shape) + x, y, w, h = cv2.boundingRect(n_mask[:, :, 0]) + if w < 4 or h < 4: + blended = prior + else: + blended = cv2.seamlessClone( + insertion, # pylint: disable=no-member + prior, + insertion_mask, + (x_center, y_center), + cv2.NORMAL_CLONE, + ) # pylint: disable=no-member + + blended = blended[height:-height, width:-width] + + return blended.astype("float32") / 255.0 + + +def get_landmark(face_landmarks, id): + part = face_landmarks.part(id) + x = part.x + y = part.y + + return (x, y) + + +def search(face_landmarks): + + x1, y1 = get_landmark(face_landmarks, 36) + x2, y2 = get_landmark(face_landmarks, 39) + x3, y3 = get_landmark(face_landmarks, 42) + x4, y4 = get_landmark(face_landmarks, 45) + + x_nose, y_nose = get_landmark(face_landmarks, 30) + + x_left_mouth, y_left_mouth = get_landmark(face_landmarks, 48) + x_right_mouth, y_right_mouth = get_landmark(face_landmarks, 54) + + x_left_eye = int((x1 + x2) / 2) + y_left_eye = int((y1 + y2) / 2) + x_right_eye = int((x3 + x4) / 2) + y_right_eye = int((y3 + y4) / 2) + + results = np.array( + [ + [x_left_eye, y_left_eye], + [x_right_eye, y_right_eye], + [x_nose, y_nose], + [x_left_mouth, y_left_mouth], + [x_right_mouth, y_right_mouth], + ] + ) + + return results + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument("--origin_url", type=str, default="./", help="origin images") + parser.add_argument("--replace_url", type=str, default="./", help="restored faces") + parser.add_argument("--save_url", type=str, default="./save") + opts = parser.parse_args() + + origin_url = opts.origin_url + replace_url = opts.replace_url + save_url = opts.save_url + + if not os.path.exists(save_url): + os.makedirs(save_url) + + face_detector = dlib.get_frontal_face_detector() + landmark_locator = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat") + + count = 0 + + for x in os.listdir(origin_url): + img_url = os.path.join(origin_url, x) + pil_img = Image.open(img_url).convert("RGB") + + origin_width, origin_height = pil_img.size + image = np.array(pil_img) + + start = time.time() + faces = face_detector(image) + done = time.time() + + if len(faces) == 0: + print("Warning: There is no face in %s" % (x)) + continue + + blended = image + for face_id in range(len(faces)): + + current_face = faces[face_id] + face_landmarks = landmark_locator(image, current_face) + current_fl = search(face_landmarks) + + forward_mask = np.ones_like(image).astype("uint8") + affine = compute_transformation_matrix(image, current_fl, False, target_face_scale=1.3) + aligned_face = warp(image, affine, output_shape=(512, 512, 3), preserve_range=True) + forward_mask = warp( + forward_mask, affine, output_shape=(512, 512, 3), order=0, preserve_range=True + ) + + affine_inverse = affine.inverse + cur_face = aligned_face + if replace_url != "": + + face_name = x[:-4] + "_" + str(face_id + 1) + ".png" + cur_url = os.path.join(replace_url, face_name) + restored_face = Image.open(cur_url).convert("RGB") + restored_face = np.array(restored_face) + cur_face = restored_face + + ## Histogram Color matching + A = cv2.cvtColor(aligned_face.astype("uint8"), cv2.COLOR_RGB2BGR) + B = cv2.cvtColor(cur_face.astype("uint8"), cv2.COLOR_RGB2BGR) + B = match_histograms(B, A) + cur_face = cv2.cvtColor(B.astype("uint8"), cv2.COLOR_BGR2RGB) + + warped_back = warp( + cur_face, + affine_inverse, + output_shape=(origin_height, origin_width, 3), + order=3, + preserve_range=True, + ) + + backward_mask = warp( + forward_mask, + affine_inverse, + output_shape=(origin_height, origin_width, 3), + order=0, + preserve_range=True, + ) ## Nearest neighbour + + blended = blur_blending_cv2(warped_back, blended, backward_mask) + blended *= 255.0 + + io.imsave(os.path.join(save_url, x), img_as_ubyte(blended / 255.0)) + + count += 1 + + if count % 1000 == 0: + print("%d have finished ..." % (count)) + diff --git a/Face_Detection/detect_all_dlib_HR.py b/Face_Detection/detect_all_dlib_HR.py new file mode 100644 index 00000000..f52e149b --- /dev/null +++ b/Face_Detection/detect_all_dlib_HR.py @@ -0,0 +1,184 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +import numpy as np +import skimage.io as io + +# from FaceSDK.face_sdk import FaceDetection +# from face_sdk import FaceDetection +import matplotlib.pyplot as plt +from matplotlib.patches import Rectangle +from skimage.transform import SimilarityTransform +from skimage.transform import warp +from PIL import Image +import torch.nn.functional as F +import torchvision as tv +import torchvision.utils as vutils +import time +import cv2 +import os +from skimage import img_as_ubyte +import json +import argparse +import dlib + + +def _standard_face_pts(): + pts = ( + np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32) / 256.0 + - 1.0 + ) + + return np.reshape(pts, (5, 2)) + + +def _origin_face_pts(): + pts = np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32) + + return np.reshape(pts, (5, 2)) + + +def get_landmark(face_landmarks, id): + part = face_landmarks.part(id) + x = part.x + y = part.y + + return (x, y) + + +def search(face_landmarks): + + x1, y1 = get_landmark(face_landmarks, 36) + x2, y2 = get_landmark(face_landmarks, 39) + x3, y3 = get_landmark(face_landmarks, 42) + x4, y4 = get_landmark(face_landmarks, 45) + + x_nose, y_nose = get_landmark(face_landmarks, 30) + + x_left_mouth, y_left_mouth = get_landmark(face_landmarks, 48) + x_right_mouth, y_right_mouth = get_landmark(face_landmarks, 54) + + x_left_eye = int((x1 + x2) / 2) + y_left_eye = int((y1 + y2) / 2) + x_right_eye = int((x3 + x4) / 2) + y_right_eye = int((y3 + y4) / 2) + + results = np.array( + [ + [x_left_eye, y_left_eye], + [x_right_eye, y_right_eye], + [x_nose, y_nose], + [x_left_mouth, y_left_mouth], + [x_right_mouth, y_right_mouth], + ] + ) + + return results + + +def compute_transformation_matrix(img, landmark, normalize, target_face_scale=1.0): + + std_pts = _standard_face_pts() # [-1,1] + target_pts = (std_pts * target_face_scale + 1) / 2 * 512.0 + + # print(target_pts) + + h, w, c = img.shape + if normalize == True: + landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0 + landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0 + + # print(landmark) + + affine = SimilarityTransform() + + affine.estimate(target_pts, landmark) + + return affine.params + + +def show_detection(image, box, landmark): + plt.imshow(image) + print(box[2] - box[0]) + plt.gca().add_patch( + Rectangle( + (box[1], box[0]), box[2] - box[0], box[3] - box[1], linewidth=1, edgecolor="r", facecolor="none" + ) + ) + plt.scatter(landmark[0][0], landmark[0][1]) + plt.scatter(landmark[1][0], landmark[1][1]) + plt.scatter(landmark[2][0], landmark[2][1]) + plt.scatter(landmark[3][0], landmark[3][1]) + plt.scatter(landmark[4][0], landmark[4][1]) + plt.show() + + +def affine2theta(affine, input_w, input_h, target_w, target_h): + # param = np.linalg.inv(affine) + param = affine + theta = np.zeros([2, 3]) + theta[0, 0] = param[0, 0] * input_h / target_h + theta[0, 1] = param[0, 1] * input_w / target_h + theta[0, 2] = (2 * param[0, 2] + param[0, 0] * input_h + param[0, 1] * input_w) / target_h - 1 + theta[1, 0] = param[1, 0] * input_h / target_w + theta[1, 1] = param[1, 1] * input_w / target_w + theta[1, 2] = (2 * param[1, 2] + param[1, 0] * input_h + param[1, 1] * input_w) / target_w - 1 + return theta + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument("--url", type=str, default="/home/jingliao/ziyuwan/celebrities", help="input") + parser.add_argument( + "--save_url", type=str, default="/home/jingliao/ziyuwan/celebrities_detected_face_reid", help="output" + ) + opts = parser.parse_args() + + url = opts.url + save_url = opts.save_url + + ### If the origin url is None, then we don't need to reid the origin image + + os.makedirs(url, exist_ok=True) + os.makedirs(save_url, exist_ok=True) + + face_detector = dlib.get_frontal_face_detector() + landmark_locator = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat") + + count = 0 + + map_id = {} + for x in os.listdir(url): + img_url = os.path.join(url, x) + pil_img = Image.open(img_url).convert("RGB") + + image = np.array(pil_img) + + start = time.time() + faces = face_detector(image) + done = time.time() + + if len(faces) == 0: + print("Warning: There is no face in %s" % (x)) + continue + + print(len(faces)) + + if len(faces) > 0: + for face_id in range(len(faces)): + current_face = faces[face_id] + face_landmarks = landmark_locator(image, current_face) + current_fl = search(face_landmarks) + + affine = compute_transformation_matrix(image, current_fl, False, target_face_scale=1.3) + aligned_face = warp(image, affine, output_shape=(512, 512, 3)) + img_name = x[:-4] + "_" + str(face_id + 1) + io.imsave(os.path.join(save_url, img_name + ".png"), img_as_ubyte(aligned_face)) + + count += 1 + + if count % 1000 == 0: + print("%d have finished ..." % (count)) + diff --git a/Face_Enhancement/models/networks/generator.py b/Face_Enhancement/models/networks/generator.py index c9f168dc..6e24cadc 100644 --- a/Face_Enhancement/models/networks/generator.py +++ b/Face_Enhancement/models/networks/generator.py @@ -97,7 +97,7 @@ def compute_latent_vector_size(self, opt): else: raise ValueError("opt.num_upsampling_layers [%s] not recognized" % opt.num_upsampling_layers) - sw = opt.crop_size // (2 ** num_up_layers) + sw = opt.load_size // (2 ** num_up_layers) sh = round(sw / opt.aspect_ratio) return sw, sh diff --git a/Global/options/test_options.py b/Global/options/test_options.py index 08904552..67e2e3a7 100755 --- a/Global/options/test_options.py +++ b/Global/options/test_options.py @@ -97,4 +97,4 @@ def initialize(self): self.parser.add_argument( "--Scratch_and_Quality_restore", action="store_true", help="For scratched images" ) - + self.parser.add_argument("--HR", action='store_true',help='Large input size with scratches') diff --git a/Global/test.py b/Global/test.py index 5ef76655..9fb4386c 100644 --- a/Global/test.py +++ b/Global/test.py @@ -86,6 +86,11 @@ def parameter_set(opt): opt.name = "mapping_scratch" opt.load_pretrainA = os.path.join(opt.checkpoints_dir, "VAE_A_quality") opt.load_pretrainB = os.path.join(opt.checkpoints_dir, "VAE_B_scratch") + if opt.HR: + opt.mapping_exp = 1 + opt.inference_optimize = True + opt.mask_dilation = 3 + opt.name = "mapping_Patch_Attention" if __name__ == "__main__": @@ -135,6 +140,11 @@ def parameter_set(opt): if opt.NL_use_mask: mask_name = mask_loader[i] mask = Image.open(os.path.join(opt.test_mask, mask_name)).convert("RGB") + if opt.mask_dilation!=0: + kernel=np.ones((3,3),np.uint8) + mask=np.array(mask) + mask=cv2.dilate(mask,kernel,iterations=opt.mask_dilation) + mask=Image.fromarray(mask.astype('uint8')) origin = input input = irregular_hole_synthesize(input, mask) mask = mask_transform(mask) diff --git a/README.md b/README.md index 8a8a9c2f..4a64c3c4 100755 --- a/README.md +++ b/README.md @@ -23,6 +23,10 @@ The code originates from our research project and the aim is to demonstrate the **We are improving the algorithm so as to process high resolution photos. It takes time and please stay tuned.** ## News +The framework now supports the restoration of high-resolution input. + + + Training code is available and welcome to have a try and learn the training details. You can now play with our [Colab](https://colab.research.google.com/drive/1NEm6AsybIiC5TwTU_4DqDkQO0nFRB-uA?usp=sharing) and try it on your photos. @@ -101,6 +105,16 @@ python run.py --input_folder [test_image_folder_path] \ --with_scratch ``` +For high-resolution images with scratches: + +``` +python run.py --input_folder [test_image_folder_path] \ + --output_folder [output_path] \ + --GPU 0 \ + --with_scratch \ + --HR +``` + Note: Please try to use the absolute path. The final results will be saved in `./output_path/final_output/`. You could also check the produced results of different steps in `output_path`. ### 2) Scratch Detection @@ -132,8 +146,8 @@ python test.py --Scratch_and_Quality_restore \ --outputs_dir [output_path] python test.py --Quality_restore \ - --test_input [test_image_folder_path] \ - --outputs_dir [output_path] + --test_input [test_image_folder_path] \ + --outputs_dir [output_path] ``` @@ -203,14 +217,17 @@ Traing the mapping with scraches: python train_mapping.py --no_TTUR --NL_res --random_hole --use_SN --correlation_renormalize --training_dataset mapping --NL_use_mask --NL_fusion_method combine --non_local Setting_42 --use_v2_degradation --use_vae_which_epoch 200 --continue_train --name mapping_scratch --label_nc 0 --loadSize 256 --fineSize 256 --dataroot [your_data_folder] --no_instance --resize_or_crop crop_only --batchSize 36 --no_html --gpu_ids 0,1,2,3 --nThreads 8 --load_pretrainA [ckpt_of_domainA_SR_old_photos] --load_pretrainB [ckpt_of_domainB_old_photos] --l2_feat 60 --n_downsample_global 3 --mc 64 --k_size 4 --start_r 1 --mapping_n_block 6 --map_mc 512 --use_l1_feat --niter 150 --niter_decay 100 --outputs_dir [your_output_folder] --checkpoints_dir [your_ckpt_folder] --irregular_mask [absolute_path_of_mask_file] ``` - +Traing the mapping with scraches (Multi-Scale Patch Attention for HR input): +``` +python train_mapping.py --no_TTUR --NL_res --random_hole --use_SN --correlation_renormalize --training_dataset mapping --NL_use_mask --NL_fusion_method combine --non_local Setting_42 --use_v2_degradation --use_vae_which_epoch 200 --continue_train --name mapping_Pathc_Attention --label_nc 0 --loadSize 256 --fineSize 256 --dataroot [your_data_folder] --no_instance --resize_or_crop crop_only --batchSize 36 --no_html --gpu_ids 0,1,2,3 --nThreads 8 --load_pretrainA [ckpt_of_domainA_SR_old_photos] --load_pretrainB [ckpt_of_domainB_old_photos] --l2_feat 60 --n_downsample_global 3 --mc 64 --k_size 4 --start_r 1 --mapping_n_block 6 --map_mc 512 --use_l1_feat --niter 150 --niter_decay 100 --outputs_dir [your_output_folder] --checkpoints_dir [your_ckpt_folder] --irregular_mask [absolute_path_of_mask_file] --mapping_exp 1 +``` ## To Do - [x] Clean testing code - [x] Release pretrained model - [x] Collab demo -- [ ] Replace face detection module (dlib) with RetinaFace - [x] Release training code +- [x] Processing of high-resolution input ## Citation diff --git a/imgs/HR.png b/imgs/HR.png new file mode 100644 index 00000000..f0b363c5 Binary files /dev/null and b/imgs/HR.png differ diff --git a/run.py b/run.py index aaf892b3..d7844860 100644 --- a/run.py +++ b/run.py @@ -29,6 +29,7 @@ def run_cmd(command): "--checkpoint_name", type=str, default="Setting_9_epoch_100", help="choose which checkpoint" ) parser.add_argument("--with_scratch", action="store_true") + parser.add_argument("--HR", action='store_true') opts = parser.parse_args() gpu1 = opts.GPU @@ -73,6 +74,12 @@ def run_cmd(command): + " --GPU " + gpu1 ) + + if opts.HR: + HR_suffix=" --HR" + else: + HR_suffix="" + stage_1_command_2 = ( "python test.py --Scratch_and_Quality_restore --test_input " + new_input @@ -81,7 +88,7 @@ def run_cmd(command): + " --outputs_dir " + stage_1_output_dir + " --gpu_ids " - + gpu1 + + gpu1 + HR_suffix ) run_cmd(stage_1_command_1) @@ -107,9 +114,14 @@ def run_cmd(command): stage_2_output_dir = os.path.join(opts.output_folder, "stage_2_detection_output") if not os.path.exists(stage_2_output_dir): os.makedirs(stage_2_output_dir) - stage_2_command = ( - "python detect_all_dlib.py --url " + stage_2_input_dir + " --save_url " + stage_2_output_dir - ) + if opts.HR: + stage_2_command = ( + "python detect_all_dlib_HR.py --url " + stage_2_input_dir + " --save_url " + stage_2_output_dir + ) + else: + stage_2_command = ( + "python detect_all_dlib.py --url " + stage_2_input_dir + " --save_url " + stage_2_output_dir + ) run_cmd(stage_2_command) print("Finish Stage 2 ...") print("\n") @@ -122,19 +134,36 @@ def run_cmd(command): stage_3_output_dir = os.path.join(opts.output_folder, "stage_3_face_output") if not os.path.exists(stage_3_output_dir): os.makedirs(stage_3_output_dir) - stage_3_command = ( - "python test_face.py --old_face_folder " - + stage_3_input_face - + " --old_face_label_folder " - + stage_3_input_mask - + " --tensorboard_log --name " - + opts.checkpoint_name - + " --gpu_ids " - + gpu1 - + " --load_size 256 --label_nc 18 --no_instance --preprocess_mode resize --batchSize 4 --results_dir " - + stage_3_output_dir - + " --no_parsing_map" - ) + + if opts.HR: + opts.checkpoint_name='FaceSR_512' + stage_3_command = ( + "python test_face.py --old_face_folder " + + stage_3_input_face + + " --old_face_label_folder " + + stage_3_input_mask + + " --tensorboard_log --name " + + opts.checkpoint_name + + " --gpu_ids " + + gpu1 + + " --load_size 512 --label_nc 18 --no_instance --preprocess_mode resize --batchSize 1 --results_dir " + + stage_3_output_dir + + " --no_parsing_map" + ) + else: + stage_3_command = ( + "python test_face.py --old_face_folder " + + stage_3_input_face + + " --old_face_label_folder " + + stage_3_input_mask + + " --tensorboard_log --name " + + opts.checkpoint_name + + " --gpu_ids " + + gpu1 + + " --load_size 256 --label_nc 18 --no_instance --preprocess_mode resize --batchSize 4 --results_dir " + + stage_3_output_dir + + " --no_parsing_map" + ) run_cmd(stage_3_command) print("Finish Stage 3 ...") print("\n") @@ -147,14 +176,24 @@ def run_cmd(command): stage_4_output_dir = os.path.join(opts.output_folder, "final_output") if not os.path.exists(stage_4_output_dir): os.makedirs(stage_4_output_dir) - stage_4_command = ( - "python align_warp_back_multiple_dlib.py --origin_url " - + stage_4_input_image_dir - + " --replace_url " - + stage_4_input_face_dir - + " --save_url " - + stage_4_output_dir - ) + if opts.HR: + stage_4_command = ( + "python align_warp_back_multiple_dlib_HR.py --origin_url " + + stage_4_input_image_dir + + " --replace_url " + + stage_4_input_face_dir + + " --save_url " + + stage_4_output_dir + ) + else: + stage_4_command = ( + "python align_warp_back_multiple_dlib.py --origin_url " + + stage_4_input_image_dir + + " --replace_url " + + stage_4_input_face_dir + + " --save_url " + + stage_4_output_dir + ) run_cmd(stage_4_command) print("Finish Stage 4 ...") print("\n")