
Commit

updated api endpoint
braindotai committed Feb 15, 2021
1 parent c4434cf · commit 672fe1b
Showing 3 changed files with 43 additions and 99 deletions.
32 changes: 29 additions & 3 deletions helper.py
@@ -1,11 +1,11 @@
import numpy as np
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
from torchvision.utils import make_grid


def pil_to_np_array(pil_image):
    ar = np.array(pil_image)
    if len(ar.shape) == 3:
@@ -36,9 +36,35 @@ def get_image_grid(images, nrow = 3):

def visualize_sample(*images_np, nrow = 3, size_factor = 10):
    c = max(x.shape[0] for x in images_np)
    images_np = [x if (x.shape[0] == c) else np.concatenate([x, x, x], axis=0) for x in images_np]
    images_np = [x if (x.shape[0] == c) else np.concatenate([x, x, x], axis = 0) for x in images_np]
    grid = get_image_grid(images_np, nrow)
    plt.figure(figsize = (len(images_np) + size_factor, 12 + size_factor))
    plt.axis('off')
    plt.imshow(grid.transpose(1, 2, 0))
    plt.show()
    plt.show()

def max_dimension_resize(image_pil, mask_pil, max_dim):
    w, h = image_pil.size
    aspect_ratio = w / h
    if w > max_dim:
        h = int((h / w) * max_dim)
        w = max_dim
    elif h > max_dim:
        w = int((w / h) * max_dim)
        h = max_dim
    return image_pil.resize((w, h)), mask_pil.resize((w, h))

def preprocess_images(image_path, mask_path, max_dim):
    image_pil = read_image(image_path).convert('RGB')
    mask_pil = read_image(mask_path).convert('RGB')

    image_pil, mask_pil = max_dimension_resize(image_pil, mask_pil, max_dim)

    image_np = pil_to_np_array(image_pil)
    mask_np = pil_to_np_array(mask_pil)

    print('Visualizing mask overlap...')

    visualize_sample(image_np, mask_np, image_np * mask_np, nrow = 3, size_factor = 10)

    return image_np, mask_np
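
The two helpers added above compose with the file's existing PIL utilities: max_dimension_resize caps the longer side at max_dim while preserving aspect ratio, and preprocess_images returns the numpy arrays used downstream. A minimal usage sketch follows, assuming read_image (called inside preprocess_images) is the repo's thin wrapper around PIL.Image.open, with the default paths taken from inference.py below:

# Usage sketch for the new helpers; paths are the defaults from inference.py.
from helper import preprocess_images

# max_dimension_resize preserves aspect ratio: for a 1024x768 input and
# max_dim = 512, w > max_dim, so h = int((768 / 1024) * 512) = 384, w = 512.
image_np, mask_np = preprocess_images(
    image_path = './data/watermark-unavailable/watermarked/watermarked0.png',
    mask_path = './data/watermark-unavailable/masks/mask0.png',
    max_dim = 512,
)

# Both arrays come back channels-first from pil_to_np_array, ready to be
# passed to np_to_torch_array.
print(image_np.shape, mask_np.shape)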
110 changes: 14 additions & 96 deletions inference.py
@@ -1,108 +1,26 @@
import torch
from torch import nn, optim
from torchsummary import summary

import os
import numpy as np
from PIL import Image
from tqdm.auto import tqdm

from helper import *
from model.generator import SkipEncoderDecoder, input_noise

import argparse
from api import remove_watermark

parser = argparse.ArgumentParser(description = 'Removing Watermark')
parser.add_argument('--image-path', type = str, default = './data/watermark-available/me.jpg', help = 'Path to the "watermarked" image.')
parser.add_argument('--watermark-path', type = str, default = './data/watermark-available/watermark.png', help = 'Path to the "watermark" image.')
parser.add_argument('--image-path', type = str, default = './data/watermark-unavailable/watermarked/watermarked0.png', help = 'Path to the "watermarked" image.')
parser.add_argument('--mask-path', type = str, default = './data/watermark-unavailable/masks/mask0.png', help = 'Path to the "watermark" image.')
parser.add_argument('--input-depth', type = int, default = 32, help = 'Max channel dimension of the noise input. Set it based on gpu/device memory you have available.')
parser.add_argument('--lr', type = float, default = 0.01, help = 'Learning rate.')
parser.add_argument('--training-steps', type = int, default = 3000, help = 'Number of training iterations.')
parser.add_argument('--show-steps', type = int, default = 200, help = 'Interval for visualizing results.')
parser.add_argument('--show-step', type = int, default = 200, help = 'Interval for visualizing results.')
parser.add_argument('--reg-noise', type = float, default = 0.03, help = 'Hyper-parameter for regularized noise input.')
parser.add_argument('--device', type = str, default = 'cuda', help = 'Device for pytorch, either "cpu" or "cuda".')
parser.add_argument('--max-dim', type = float, default = 512, help = 'Max dimension of the final output image')

args = parser.parse_args()
if args.device == 'cuda' and not torch.cuda.is_available():
    args.device = 'cpu'
    print('\nSetting device to "cpu", since torch is not built with "cuda" support...')
    print('It is recommended to use GPU if possible...')

DTYPE = torch.cuda.FloatTensor if args.device == "cuda" else torch.FloatTensor

image_pil = read_image(args.image_path)
image_pil = image_pil.convert('RGB')
image_pil = image_pil.resize((128, 128))

image_mask_pil = read_image(args.watermark_path)
image_mask_pil = image_mask_pil.convert('RGB')
image_mask_pil = image_mask_pil.resize((image_pil.size[0], image_pil.size[1]))

image_np = pil_to_np_array(image_pil)
image_mask_np = pil_to_np_array(image_mask_pil)
image_mask_np[image_mask_np == 0.0] = 1.0

image_var = np_to_torch_array(image_np).type(DTYPE)
mask_var = np_to_torch_array(image_mask_np).type(DTYPE)

visualize_sample(image_np, image_mask_np, image_mask_np * image_np, nrow = 3, size_factor = 12)

print('Building model...\n')

generator = SkipEncoderDecoder(
    args.input_depth,
    num_channels_down = [128] * 5,
    num_channels_up = [128] * 5,
    num_channels_skip = [128] * 5
).type(DTYPE)
generator_input = input_noise(args.input_depth, image_np.shape[1:]).type(DTYPE)
summary(generator, generator_input.shape[1:])

objective = nn.MSELoss()
optimizer = optim.Adam(generator.parameters(), args.lr)

generator_input_saved = generator_input.clone()
noise = generator_input.clone()
generator_input = generator_input_saved

print('\nStarting training...\n')

progress_bar = tqdm(range(TRAINING_STEPS), desc = 'Completed', ncols = 800)

for step in progress_bar:
    optimizer.zero_grad()
    generator_input = generator_input_saved

    if args.reg_noise > 0:
        generator_input = generator_input_saved + (noise.normal_() * REG_NOISE)

    output = generator(generator_input)

    loss = objective(output * mask_var, watermarked_var * mask_var)
    loss.backward()

    if step % args.show_steps == 0:
        output_image = torch_to_np_array(output)
        visualize_sample(watermarked_np, output_image, nrow = 2, size_factor = 10)

    progress_bar.set_postfix(Step = step, Loss = loss.item())

    optimizer.step()

# for step in tqdm(range(args.training_steps), desc = 'Completed', ncols = 100):
#     optimizer.zero_grad()

#     if args.reg_noise > 0:
#         generator_input = generator_input_saved + (noise.normal_() * args.reg_noise)

#     out = generator(generator_input)

#     loss = objective(out * mask_var, image_var * mask_var)
#     loss.backward()

#     if step % args.show_steps == 0:
#         out_np = torch_to_np_array(out)
#         visualize_sample(np.clip(out_np, 0, 1), nrow = 1, size_factor = 5)

#     optimizer.step()
remove_watermark(
    image_path = args.image_path,
    mask_path = args.mask_path,
    max_dim = args.max_dim,
    show_step = args.show_step,
    reg_noise = args.reg_noise,
    input_depth = args.input_depth,
    lr = args.lr,
    training_steps = args.training_steps,
)
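
api.py itself is not shown in this commit, so only the call signature of remove_watermark is certain. Below is a hedged sketch of what it plausibly implements, reassembled from the Deep Image Prior training loop that this commit deletes from inference.py; the internal variable names, the device selection, and the mask convention (mask assumed to be 0 over the watermark) are assumptions:

# Hypothetical sketch of api.remove_watermark, inferred from the call site
# above and from the removed training loop; treat the details as assumptions.
import torch
from torch import nn, optim
from tqdm.auto import tqdm

from helper import (preprocess_images, np_to_torch_array,
                    torch_to_np_array, visualize_sample)
from model.generator import SkipEncoderDecoder, input_noise

def remove_watermark(image_path, mask_path, max_dim, show_step,
                     reg_noise, input_depth, lr, training_steps):
    dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

    image_np, mask_np = preprocess_images(image_path, mask_path, max_dim)
    image_var = np_to_torch_array(image_np).type(dtype)
    mask_var = np_to_torch_array(mask_np).type(dtype)

    generator = SkipEncoderDecoder(
        input_depth,
        num_channels_down = [128] * 5,
        num_channels_up = [128] * 5,
        num_channels_skip = [128] * 5
    ).type(dtype)

    objective = nn.MSELoss()
    optimizer = optim.Adam(generator.parameters(), lr)

    # Fixed noise input; a fresh perturbation of it is fed to the generator
    # at every step (the regularized-noise trick from the old loop).
    generator_input_saved = input_noise(input_depth, image_np.shape[1:]).type(dtype)
    noise = generator_input_saved.clone()

    for step in tqdm(range(training_steps), desc = 'Completed', ncols = 100):
        optimizer.zero_grad()
        generator_input = generator_input_saved
        if reg_noise > 0:
            generator_input = generator_input_saved + (noise.normal_() * reg_noise)

        output = generator(generator_input)
        # Penalize reconstruction error only where the mask keeps pixels, so
        # the generator inpaints the masked (watermarked) region from its prior.
        loss = objective(output * mask_var, image_var * mask_var)
        loss.backward()

        if step % show_step == 0:
            visualize_sample(image_np, torch_to_np_array(output), nrow = 2, size_factor = 10)

        optimizer.step()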
Binary file modified model/__pycache__/modules.cpython-37.pyc
Binary file not shown.
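
With this refactor, inference.py reduces to argument parsing plus one API call, so a run with the script's defaults can also be reproduced programmatically (note the script parses --max-dim as a float even though it is used as a pixel dimension):

# Programmatic equivalent of running inference.py with its default arguments.
from api import remove_watermark

remove_watermark(
    image_path = './data/watermark-unavailable/watermarked/watermarked0.png',
    mask_path = './data/watermark-unavailable/masks/mask0.png',
    max_dim = 512,
    show_step = 200,
    reg_noise = 0.03,
    input_depth = 32,
    lr = 0.01,
    training_steps = 3000,
)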
