Skip to content

Commit

Permalink
add preprocess
Browse files Browse the repository at this point in the history
  • Loading branch information
zuruoke committed Jan 22, 2022
1 parent 5987d29 commit 170c84c
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 35 deletions.
64 changes: 29 additions & 35 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import argparse

from PIL import Image
import cv2
import numpy as np
from preprocess_image import preprocess_image
import tensorflow as tf
import neuralgym as ng

Expand All @@ -12,8 +14,10 @@
help='The filename of image to be completed.')
parser.add_argument('--output', default='output.png', type=str,
help='Where to write output.')
parser.add_argument('--watermark_type', default='istock', type=str,
help='The watermark type')
parser.add_argument('--checkpoint_dir', default='model/', type=str,
help='The directory of tensorflow checkpoint.')
help='The directory of tensorflow checkpoint.')

#checkpoint_dir = 'model/'

Expand All @@ -24,40 +28,30 @@
args, unknown = parser.parse_known_args()

model = InpaintCAModel()
image = cv2.imread(args.image)
mask = cv2.imread('assets/mask.png')
# mask = cv2.resize(mask, (0,0), fx=0.5, fy=0.5)

assert image.shape == mask.shape

h, w, _ = image.shape
grid = 8
image = image[:h//grid*grid, :w//grid*grid, :]
mask = mask[:h//grid*grid, :w//grid*grid, :]
print('Shape of image: {}'.format(image.shape))

image = np.expand_dims(image, 0)
mask = np.expand_dims(mask, 0)
input_image = np.concatenate([image, mask], axis=2)
image = Image.open(args.image)
input_image = preprocess_image(image, args.watermark_type)
tf.reset_default_graph()

sess_config = tf.ConfigProto()
sess_config.gpu_options.allow_growth = True
with tf.Session(config=sess_config) as sess:
input_image = tf.constant(input_image, dtype=tf.float32)
output = model.build_server_graph(FLAGS, input_image)
output = (output + 1.) * 127.5
output = tf.reverse(output, [-1])
output = tf.saturate_cast(output, tf.uint8)
# load pretrained model
vars_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
assign_ops = []
for var in vars_list:
vname = var.name
from_name = vname
var_value = tf.contrib.framework.load_variable(args.checkpoint_dir, from_name)
assign_ops.append(tf.assign(var, var_value))
sess.run(assign_ops)
print('Model loaded.')
result = sess.run(output)
cv2.imwrite(args.output, result[0][:, :, ::-1])
print('image saved to {}'.format(args.output))
if (input_image != None):
with tf.Session(config=sess_config) as sess:
input_image = tf.constant(input_image, dtype=tf.float32)
output = model.build_server_graph(FLAGS, input_image)
output = (output + 1.) * 127.5
output = tf.reverse(output, [-1])
output = tf.saturate_cast(output, tf.uint8)
# load pretrained model
vars_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
assign_ops = []
for var in vars_list:
vname = var.name
from_name = vname
var_value = tf.contrib.framework.load_variable(
args.checkpoint_dir, from_name)
assign_ops.append(tf.assign(var, var_value))
sess.run(assign_ops)
print('Model loaded.')
result = sess.run(output)
cv2.imwrite(args.output, result[0][:, :, ::-1])
print('image saved to {}'.format(args.output))
52 changes: 52 additions & 0 deletions preprocess_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from typing import Any
import numpy as np
from PIL import Image
import cv2


def preprocess_image(image: Any, watermark_type: str) -> Any | None:
image_type: str = ''
preprocessed_mask_image = None
if image.mode != "RGB":
image = image.convert("RGB")
image = np.array(image)
image_h = image.shape[0]
image_w = image.shape[1]
aspectRatioImage = image_w / image_h
print("image size: {}".format(image.shape))

if image_w > image_h:
image_type = "landscape"
elif image_w == image_h:
image_type = "landscape"
else:
image_type = "potrait"

mask_image = Image.open(
"utils/{}/{}/mask.png".format(watermark_type, image_type))
mask_image = np.array(mask_image)
print("mask image size: {}".format(mask_image.shape))

aspectRatioMaskImage = mask_image.shape[1] / mask_image.shape[0]
upperBoundAspectRatio = 1.05 * aspectRatioMaskImage
lowerBoundAspectRatio = 0.95 * aspectRatioMaskImage

if aspectRatioImage >= lowerBoundAspectRatio and aspectRatioImage <= upperBoundAspectRatio:
preprocessed_mask_image = cv2.resize(mask_image, (image_w, image_h))
print(preprocessed_mask_image.shape)
else:
print("Image size not supported!!!")

if preprocessed_mask_image != None:
assert image.shape == preprocessed_mask_image
grid = 8
image = image[:image_h//grid*grid, :image_w//grid*grid, :]
preprocessed_mask_image = preprocessed_mask_image[:image_h //
grid*grid, :image_w//grid*grid, :]
image = np.expand_dims(image, 0)
preprocessed_mask_image = np.expand_dims(preprocessed_mask_image, 0)
input_image = np.concatenate([image, preprocessed_mask_image], axis=2)
return input_image

else:
return preprocessed_mask_image
Binary file added utils/istock/landscape/mask.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 170c84c

Please sign in to comment.