From 672fe1ba76aa30d2b1eff93c6b2e0818e8aa4712 Mon Sep 17 00:00:00 2001
From: Rishik Mourya
Date: Mon, 15 Feb 2021 11:01:53 +0530
Subject: [PATCH] updated api endpoint

---
 helper.py                                |  32 ++++++-
 inference.py                             | 110 +++-------------------
 model/__pycache__/modules.cpython-37.pyc | Bin 2291 -> 2927 bytes
 3 files changed, 43 insertions(+), 99 deletions(-)

diff --git a/helper.py b/helper.py
index 0048096..cebabad 100644
--- a/helper.py
+++ b/helper.py
@@ -1,11 +1,11 @@
 import numpy as np
 from PIL import Image
+from tqdm import tqdm
 import matplotlib.pyplot as plt
 
 import torch
 from torchvision.utils import make_grid
 
-
 def pil_to_np_array(pil_image):
     ar = np.array(pil_image)
     if len(ar.shape) == 3:
@@ -36,9 +36,35 @@ def get_image_grid(images, nrow = 3):
 
 def visualize_sample(*images_np, nrow = 3, size_factor = 10):
     c = max(x.shape[0] for x in images_np)
-    images_np = [x if (x.shape[0] == c) else np.concatenate([x, x, x], axis=0) for x in images_np]
+    images_np = [x if (x.shape[0] == c) else np.concatenate([x, x, x], axis = 0) for x in images_np]
     grid = get_image_grid(images_np, nrow)
     plt.figure(figsize = (len(images_np) + size_factor, 12 + size_factor))
     plt.axis('off')
     plt.imshow(grid.transpose(1, 2, 0))
-    plt.show()
\ No newline at end of file
+    plt.show()
+
+def max_dimension_resize(image_pil, mask_pil, max_dim):
+    w, h = image_pil.size
+    aspect_ratio = w / h
+    if w > max_dim:
+        h = int((h / w) * max_dim)
+        w = max_dim
+    elif h > max_dim:
+        w = int((w / h) * max_dim)
+        h = max_dim
+    return image_pil.resize((w, h)), mask_pil.resize((w, h))
+
+def preprocess_images(image_path, mask_path, max_dim):
+    image_pil = read_image(image_path).convert('RGB')
+    mask_pil = read_image(mask_path).convert('RGB')
+
+    image_pil, mask_pil = max_dimension_resize(image_pil, mask_pil, max_dim)
+
+    image_np = pil_to_np_array(image_pil)
+    mask_np = pil_to_np_array(mask_pil)
+
+    print('Visualizing mask overlap...')
+
+    visualize_sample(image_np, mask_np, image_np * mask_np, nrow = 3, size_factor = 10)
+
+    return image_np, mask_np
\ No newline at end of file
diff --git a/inference.py b/inference.py
index 1004f56..418f3be 100644
--- a/inference.py
+++ b/inference.py
@@ -1,108 +1,26 @@
-import torch
-from torch import nn, optim
-from torchsummary import summary
-
-import os
-import numpy as np
-from PIL import Image
-from tqdm.auto import tqdm
-
-from helper import *
-from model.generator import SkipEncoderDecoder, input_noise
-
 import argparse
+from api import remove_watermark
 
 parser = argparse.ArgumentParser(description = 'Removing Watermark')
-parser.add_argument('--image-path', type = str, default = './data/watermark-available/me.jpg', help = 'Path to the "watermarked" image.')
-parser.add_argument('--watermark-path', type = str, default = './data/watermark-available/watermark.png', help = 'Path to the "watermark" image.')
+parser.add_argument('--image-path', type = str, default = './data/watermark-unavailable/watermarked/watermarked0.png', help = 'Path to the "watermarked" image.')
+parser.add_argument('--mask-path', type = str, default = './data/watermark-unavailable/masks/mask0.png', help = 'Path to the "watermark mask" image.')
 parser.add_argument('--input-depth', type = int, default = 32, help = 'Max channel dimension of the noise input. Set it based on gpu/device memory you have available.')
 parser.add_argument('--lr', type = float, default = 0.01, help = 'Learning rate.')
 parser.add_argument('--training-steps', type = int, default = 3000, help = 'Number of training iterations.')
-parser.add_argument('--show-steps', type = int, default = 200, help = 'Interval for visualizing results.')
+parser.add_argument('--show-step', type = int, default = 200, help = 'Interval for visualizing results.')
 parser.add_argument('--reg-noise', type = float, default = 0.03, help = 'Hyper-parameter for regularized noise input.')
 parser.add_argument('--device', type = str, default = 'cuda', help = 'Device for pytorch, either "cpu" or "cuda".')
 parser.add_argument('--max-dim', type = float, default = 512, help = 'Max dimension of the final output image.')
 
 args = parser.parse_args()
 
-if args.device == 'cuda' and not torch.cuda.is_available():
-    args.device = 'cpu'
-    print('\nSetting device to "cpu", since torch is not built with "cuda" support...')
-    print('It is recommended to use GPU if possible...')
-
-DTYPE = torch.cuda.FloatTensor if args.device == "cuda" else torch.FloatTensor
-
-image_pil = read_image(args.image_path)
-image_pil = image_pil.convert('RGB')
-image_pil = image_pil.resize((128, 128))
-
-image_mask_pil = read_image(args.watermark_path)
-image_mask_pil = image_mask_pil.convert('RGB')
-image_mask_pil = image_mask_pil.resize((image_pil.size[0], image_pil.size[1]))
-
-image_np = pil_to_np_array(image_pil)
-image_mask_np = pil_to_np_array(image_mask_pil)
-image_mask_np[image_mask_np == 0.0] = 1.0
-
-image_var = np_to_torch_array(image_np).type(DTYPE)
-mask_var = np_to_torch_array(image_mask_np).type(DTYPE)
-
-visualize_sample(image_np, image_mask_np, image_mask_np * image_np, nrow = 3, size_factor = 12)
-
-print('Building model...\n')
-
-generator = SkipEncoderDecoder(
-    args.input_depth,
-    num_channels_down = [128] * 5,
-    num_channels_up = [128] * 5,
-    num_channels_skip = [128] * 5
-).type(DTYPE)
-generator_input = input_noise(args.input_depth, image_np.shape[1:]).type(DTYPE)
-summary(generator, generator_input.shape[1:])
-
-objective = nn.MSELoss()
-optimizer = optim.Adam(generator.parameters(), args.lr)
-
-generator_input_saved = generator_input.clone()
-noise = generator_input.clone()
-generator_input = generator_input_saved
-
-print('\nStarting training...\n')
-
-progress_bar = tqdm(range(TRAINING_STEPS), desc = 'Completed', ncols = 800)
-
-for step in progress_bar:
-    optimizer.zero_grad()
-    generator_input = generator_input_saved
-
-    if args.reg_noise > 0:
-        generator_input = generator_input_saved + (noise.normal_() * REG_NOISE)
-
-    output = generator(generator_input)
-
-    loss = objective(output * mask_var, watermarked_var * mask_var)
-    loss.backward()
-
-    if step % args.show_steps == 0:
-        output_image = torch_to_np_array(output)
-        visualize_sample(watermarked_np, output_image, nrow = 2, size_factor = 10)
-
-    progress_bar.set_postfix(Step = step, Loss = loss.item())
-
-    optimizer.step()
-
-# for step in tqdm(range(args.training_steps), desc = 'Completed', ncols = 100):
-#     optimizer.zero_grad()
-#     if args.reg_noise > 0:
-#         generator_input = generator_input_saved + (noise.normal_() * args.reg_noise)
-
-#     out = generator(generator_input)
-
-#     loss = objective(out * mask_var, image_var * mask_var)
-#     loss.backward()
-
-#     if step % args.show_steps == 0:
-#         out_np = torch_to_np_array(out)
-#         visualize_sample(np.clip(out_np, 0, 1), nrow = 1, size_factor = 5)
-
-#     optimizer.step()
\ No newline at end of file
+remove_watermark(
+    image_path = args.image_path,
+    mask_path = args.mask_path,
+    max_dim = args.max_dim,
+    show_step = args.show_step,
+    reg_noise = args.reg_noise,
+    input_depth = args.input_depth,
+    lr = args.lr,
+    training_steps = args.training_steps,
+)
\ No newline at end of file
diff --git a/model/__pycache__/modules.cpython-37.pyc b/model/__pycache__/modules.cpython-37.pyc
index 307883f731354df52ab01a63063554400914938e..4d8d21c20837c75e524c7d4b452c632ab3444149 100644
GIT binary patch
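
Note on the new entry point: inference.py now does `from api import remove_watermark`, but api.py itself is not part of this patch. The sketch below is only a hypothetical illustration of the interface that inference.py assumes after this change; it mirrors the deep-image-prior training loop deleted from inference.py and the preprocess_images() helper added above, and the actual api.py in the repository may differ.

# Hypothetical sketch of the remove_watermark() endpoint imported by the new
# inference.py (api.py is not included in this patch). It reuses only helpers
# that appear elsewhere in this repository; the real implementation may differ.
import torch
from torch import nn, optim
from tqdm.auto import tqdm

from helper import (preprocess_images, np_to_torch_array,
                    torch_to_np_array, visualize_sample)
from model.generator import SkipEncoderDecoder, input_noise


def remove_watermark(image_path, mask_path, max_dim, show_step, reg_noise,
                     input_depth, lr, training_steps):
    dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

    # Load, resize and visualize the watermarked image and its mask.
    image_np, mask_np = preprocess_images(image_path, mask_path, max_dim)
    image_var = np_to_torch_array(image_np).type(dtype)
    mask_var = np_to_torch_array(mask_np).type(dtype)

    # Deep-image-prior generator driven by a fixed noise input.
    generator = SkipEncoderDecoder(
        input_depth,
        num_channels_down = [128] * 5,
        num_channels_up = [128] * 5,
        num_channels_skip = [128] * 5
    ).type(dtype)

    generator_input_saved = input_noise(input_depth, image_np.shape[1:]).type(dtype).detach()
    noise = generator_input_saved.clone()

    objective = nn.MSELoss()
    optimizer = optim.Adam(generator.parameters(), lr)

    progress_bar = tqdm(range(training_steps), desc = 'Completed', ncols = 100)
    for step in progress_bar:
        optimizer.zero_grad()

        # Perturb the noise input slightly each step (regularized noise input).
        generator_input = generator_input_saved
        if reg_noise > 0:
            generator_input = generator_input_saved + (noise.normal_() * reg_noise)

        output = generator(generator_input)
        # Penalize only the unmasked pixels; the generator inpaints the rest.
        loss = objective(output * mask_var, image_var * mask_var)
        loss.backward()

        if step % show_step == 0:
            visualize_sample(image_np, torch_to_np_array(output), nrow = 2, size_factor = 10)

        progress_bar.set_postfix(Step = step, Loss = loss.item())
        optimizer.step()

With the patch applied, a run looks like `python inference.py --image-path <watermarked image> --mask-path <mask image>`, with all preprocessing and optimization delegated to remove_watermark().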