diff --git a/main.py b/main.py
index 59b4679..9219689 100644
--- a/main.py
+++ b/main.py
@@ -1,7 +1,9 @@
 import argparse
 
+from PIL import Image
 import cv2
 import numpy as np
+from preprocess_image import preprocess_image
 import tensorflow as tf
 import neuralgym as ng
 
@@ -12,8 +14,10 @@
                     help='The filename of image to be completed.')
 parser.add_argument('--output', default='output.png', type=str,
                     help='Where to write output.')
+parser.add_argument('--watermark_type', default='istock', type=str,
+                    help='The watermark type')
 parser.add_argument('--checkpoint_dir', default='model/', type=str,
-                    help='The directory of tensorflow checkpoint.')
+                    help='The directory of tensorflow checkpoint.')
 
 #checkpoint_dir = 'model/'
 
@@ -24,40 +28,31 @@
     args, unknown = parser.parse_known_args()
 
     model = InpaintCAModel()
-    image = cv2.imread(args.image)
-    mask = cv2.imread('assets/mask.png')
-    # mask = cv2.resize(mask, (0,0), fx=0.5, fy=0.5)
-
-    assert image.shape == mask.shape
-
-    h, w, _ = image.shape
-    grid = 8
-    image = image[:h//grid*grid, :w//grid*grid, :]
-    mask = mask[:h//grid*grid, :w//grid*grid, :]
-    print('Shape of image: {}'.format(image.shape))
-
-    image = np.expand_dims(image, 0)
-    mask = np.expand_dims(mask, 0)
-    input_image = np.concatenate([image, mask], axis=2)
+    image = Image.open(args.image)
+    input_image = preprocess_image(image, args.watermark_type)
+    tf.reset_default_graph()
 
     sess_config = tf.ConfigProto()
     sess_config.gpu_options.allow_growth = True
-    with tf.Session(config=sess_config) as sess:
-        input_image = tf.constant(input_image, dtype=tf.float32)
-        output = model.build_server_graph(FLAGS, input_image)
-        output = (output + 1.) * 127.5
-        output = tf.reverse(output, [-1])
-        output = tf.saturate_cast(output, tf.uint8)
-        # load pretrained model
-        vars_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
-        assign_ops = []
-        for var in vars_list:
-            vname = var.name
-            from_name = vname
-            var_value = tf.contrib.framework.load_variable(args.checkpoint_dir, from_name)
-            assign_ops.append(tf.assign(var, var_value))
-        sess.run(assign_ops)
-        print('Model loaded.')
-        result = sess.run(output)
-        cv2.imwrite(args.output, result[0][:, :, ::-1])
-        print('image saved to {}'.format(args.output))
+    # Identity test: "!= None" on a NumPy array compares elementwise.
+    if input_image is not None:
+        with tf.Session(config=sess_config) as sess:
+            input_image = tf.constant(input_image, dtype=tf.float32)
+            output = model.build_server_graph(FLAGS, input_image)
+            output = (output + 1.) * 127.5
+            output = tf.reverse(output, [-1])
+            output = tf.saturate_cast(output, tf.uint8)
+            # load pretrained model
+            vars_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
+            assign_ops = []
+            for var in vars_list:
+                vname = var.name
+                from_name = vname
+                var_value = tf.contrib.framework.load_variable(
+                    args.checkpoint_dir, from_name)
+                assign_ops.append(tf.assign(var, var_value))
+            sess.run(assign_ops)
+            print('Model loaded.')
+            result = sess.run(output)
+            cv2.imwrite(args.output, result[0][:, :, ::-1])
+            print('image saved to {}'.format(args.output))
diff --git a/preprocess_image.py b/preprocess_image.py
new file mode 100644
index 0000000..9905cea
--- /dev/null
+++ b/preprocess_image.py
@@ -0,0 +1,70 @@
+from typing import Any, Optional
+
+import cv2
+import numpy as np
+from PIL import Image
+
+# Image height and width are cropped down to a multiple of this value.
+GRID = 8
+# The image aspect ratio must lie within these factors of the mask's ratio.
+MIN_ASPECT_FACTOR = 0.95
+MAX_ASPECT_FACTOR = 1.05
+
+
+def preprocess_image(image: Any,
+                     watermark_type: str) -> Optional[np.ndarray]:
+    """Pair an image with its watermark mask to form the model input.
+
+    Args:
+        image: PIL image to process; converted to RGB when needed.
+        watermark_type: Name of the folder under ``utils/`` that holds
+            the masks for this watermark.
+
+    Returns:
+        An array of shape (1, H, 2 * W, 3) holding the BGR image and the
+        resized mask side by side, with H and W cropped down to
+        multiples of ``GRID``. ``None`` when the mask file is missing or
+        the image aspect ratio is not within 5% of the mask's.
+    """
+    if image.mode != "RGB":
+        image = image.convert("RGB")
+    # main.py reverses the channel axis twice before cv2.imwrite, so the
+    # saved file keeps the input's channel order; cv2.imwrite wants BGR.
+    image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
+    image_h, image_w = image.shape[:2]
+    print("image size: {}".format(image.shape))
+
+    # Square images share the landscape mask.
+    # NOTE(review): "potrait" is left as written because it is a folder
+    # name on disk; fix the spelling together with the folder.
+    image_type = "landscape" if image_w >= image_h else "potrait"
+    mask_path = "utils/{}/{}/mask.png".format(watermark_type, image_type)
+    try:
+        with Image.open(mask_path) as mask_file:
+            # Force three channels so the mask matches the image layout
+            # even when the PNG is grayscale, paletted or has alpha.
+            mask_image = np.array(mask_file.convert("RGB"))
+    except FileNotFoundError:
+        print("Mask not found: {}".format(mask_path))
+        return None
+    print("mask image size: {}".format(mask_image.shape))
+
+    image_ratio = image_w / image_h
+    mask_ratio = mask_image.shape[1] / mask_image.shape[0]
+    lower_bound = MIN_ASPECT_FACTOR * mask_ratio
+    upper_bound = MAX_ASPECT_FACTOR * mask_ratio
+    if not lower_bound <= image_ratio <= upper_bound:
+        print("Image size not supported!!!")
+        return None
+
+    mask_image = cv2.resize(mask_image, (image_w, image_h))
+    print(mask_image.shape)
+
+    cropped_h = image_h // GRID * GRID
+    cropped_w = image_w // GRID * GRID
+    image = image[:cropped_h, :cropped_w, :]
+    mask_image = mask_image[:cropped_h, :cropped_w, :]
+
+    # Add the batch axis, then join image and mask along the width axis.
+    return np.concatenate(
+        [np.expand_dims(image, 0), np.expand_dims(mask_image, 0)], axis=2)
diff --git a/utils/istock/landscape/mask.png b/utils/istock/landscape/mask.png
new file mode 100644
index 0000000..d793e98
Binary files /dev/null and b/utils/istock/landscape/mask.png differ