From 139de4aa02c86fe2d434e92adbbf26b86ea8f999 Mon Sep 17 00:00:00 2001
From: Rishik Mourya
Date: Mon, 15 Feb 2021 09:53:41 +0530
Subject: [PATCH] updated the model for faster inference

---
 inference.py     | 45 +++++++++++++++++++++++++++++++++++----------
 model/modules.py | 15 ++++++++++++++-
 2 files changed, 49 insertions(+), 11 deletions(-)

diff --git a/inference.py b/inference.py
index 9b089f2..1004f56 100644
--- a/inference.py
+++ b/inference.py
@@ -17,15 +17,17 @@
 parser.add_argument('--watermark-path', type = str, default = './data/watermark-available/watermark.png', help = 'Path to the "watermark" image.')
 parser.add_argument('--input-depth', type = int, default = 32, help = 'Max channel dimension of the noise input. Set it based on gpu/device memory you have available.')
 parser.add_argument('--lr', type = float, default = 0.01, help = 'Learning rate.')
-parser.add_argument('--training-steps', type = int, default = 4000, help = 'Number of training iterations.')
-parser.add_argument('--show-steps', type = int, default = 50, help = 'Interval for visualizing results.')
+parser.add_argument('--training-steps', type = int, default = 3000, help = 'Number of training iterations.')
+parser.add_argument('--show-steps', type = int, default = 200, help = 'Interval for visualizing results.')
 parser.add_argument('--reg-noise', type = float, default = 0.03, help = 'Hyper-parameter for regularized noise input.')
 parser.add_argument('--device', type = str, default = 'cuda', help = 'Device for pytorch, either "cpu" or "cuda".')
+parser.add_argument('--max-dim', type = float, default = 512, help = 'Max dimension of the final output image')
 args = parser.parse_args()

 if args.device == 'cuda' and not torch.cuda.is_available():
     args.device = 'cpu'
     print('\nSetting device to "cpu", since torch is not built with "cuda" support...')
+    print('It is recommended to use GPU if possible...')

 DTYPE = torch.cuda.FloatTensor if args.device == "cuda" else torch.FloatTensor

@@ -47,6 +49,7 @@
 visualize_sample(image_np, image_mask_np, image_mask_np * image_np, nrow = 3, size_factor = 12)

 print('Building model...\n')
+
 generator = SkipEncoderDecoder(
     args.input_depth,
     num_channels_down = [128] * 5,
@@ -65,19 +68,41 @@

 print('\nStarting training...\n')

-for step in tqdm(range(args.training_steps), desc = 'Completed', ncols = 100):
+progress_bar = tqdm(range(TRAINING_STEPS), desc = 'Completed', ncols = 800)
+
+for step in progress_bar:
     optimizer.zero_grad()
+    generator_input = generator_input_saved

     if args.reg_noise > 0:
-        generator_input = generator_input_saved + (noise.normal_() * args.reg_noise)
+        generator_input = generator_input_saved + (noise.normal_() * REG_NOISE)

-    out = generator(generator_input)
+    output = generator(generator_input)

-    loss = objective(out * mask_var, image_var * mask_var)
+    loss = objective(output * mask_var, watermarked_var * mask_var)
     loss.backward()
-
+
     if step % args.show_steps == 0:
-        out_np = torch_to_np_array(out)
-        visualize_sample(np.clip(out_np, 0, 1), nrow = 1, size_factor = 5)
+        output_image = torch_to_np_array(output)
+        visualize_sample(watermarked_np, output_image, nrow = 2, size_factor = 10)
+
+        progress_bar.set_postfix(Step = step, Loss = loss.item())
+
+    optimizer.step()
+
+# for step in tqdm(range(args.training_steps), desc = 'Completed', ncols = 100):
+#     optimizer.zero_grad()
+
+#     if args.reg_noise > 0:
+#         generator_input = generator_input_saved + (noise.normal_() * args.reg_noise)
+
+#     out = generator(generator_input)
+
+#     loss = objective(out * mask_var, image_var * mask_var)
+#     loss.backward()
+
+#     if step % args.show_steps == 0:
+#         out_np = torch_to_np_array(out)
+#         visualize_sample(np.clip(out_np, 0, 1), nrow = 1, size_factor = 5)

-    optimizer.step()
\ No newline at end of file
+#     optimizer.step()
\ No newline at end of file
diff --git a/model/modules.py b/model/modules.py
index ecb4126..27ba1de 100644
--- a/model/modules.py
+++ b/model/modules.py
@@ -2,13 +2,26 @@
 from torch import nn
 import numpy as np

+class DepthwiseSeperableConv2d(nn.Module):
+    def __init__(self, input_channels, output_channels, **kwargs):
+        super(DepthwiseSeperableConv2d, self).__init__()
+
+        self.depthwise = nn.Conv2d(input_channels, input_channels, groups = input_channels, **kwargs)
+        self.pointwise = nn.Conv2d(input_channels, output_channels, kernel_size = 1)
+
+    def forward(self, x):
+        x = self.depthwise(x)
+        x = self.pointwise(x)
+
+        return x
+
 class Conv2dBlock(nn.Module):
     def __init__(self, in_channels, out_channels, kernel_size, stride = 1, bias = False):
         super(Conv2dBlock, self).__init__()

         self.model = nn.Sequential(
             nn.ReflectionPad2d(int((kernel_size - 1) / 2)),
-            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding = 0, bias = bias),
+            DepthwiseSeperableConv2d(in_channels, out_channels, kernel_size = kernel_size, stride = stride, padding = 0, bias = bias),
             nn.BatchNorm2d(out_channels),
             nn.LeakyReLU(0.2)
         )
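
A note on the model/modules.py change: a standard k x k convolution from C_in to C_out channels costs C_in * C_out * k^2 weights, while the depthwise + pointwise pair costs C_in * k^2 + C_in * C_out, which is where the inference speedup comes from. Below is a minimal sketch of that parameter comparison for a hypothetical 128 -> 128, 3 x 3 layer (matching the [128] * 5 channel widths passed to SkipEncoderDecoder above); it is built from plain nn.Conv2d for illustration, not from the repo's DepthwiseSeperableConv2d class.

import torch
from torch import nn

# standard 3x3 convolution: 128 * 128 * 3 * 3 = 147,456 weights
standard = nn.Conv2d(128, 128, kernel_size = 3, padding = 1, bias = False)

# depthwise (one 3x3 filter per channel) + pointwise (1x1 channel mixer):
# 128 * 3 * 3 + 128 * 128 + 128 = 17,664 parameters, roughly 8x fewer
separable = nn.Sequential(
    nn.Conv2d(128, 128, kernel_size = 3, padding = 1, groups = 128, bias = False),
    nn.Conv2d(128, 128, kernel_size = 1),
)

count_params = lambda m: sum(p.numel() for p in m.parameters())
print(count_params(standard), count_params(separable))  # 147456 17664

# both map (N, 128, H, W) -> (N, 128, H, W)
x = torch.rand(1, 128, 64, 64)
assert standard(x).shape == separable(x).shape

One behavioral detail of the patched Conv2dBlock: the bias = bias keyword travels through **kwargs to the depthwise convolution only, so the pointwise 1x1 convolution keeps its default bias = True even when the block is constructed with bias = False. Since it feeds straight into BatchNorm2d, that extra bias is redundant but harmless.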
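Also worth flagging in the new inference.py training loop: the added lines reference TRAINING_STEPS, REG_NOISE, watermarked_var, and watermarked_np, but none of those names are defined in the hunks above — the visible context only defines args.training_steps, args.reg_noise, image_var, and image_np, so the loop as written would raise a NameError unless those aliases exist elsewhere in the file. Below is a minimal, self-contained sketch of the same loop shape using the argparse-style names; the one-layer generator, random tensors, and hard-coded hyperparameters are stand-ins for illustration, not the repo's SkipEncoderDecoder or its image pipeline.

import torch
from torch import nn
from tqdm import tqdm

device = 'cuda' if torch.cuda.is_available() else 'cpu'

generator = nn.Conv2d(32, 3, kernel_size = 1).to(device)              # stand-in for SkipEncoderDecoder
generator_input_saved = torch.rand(1, 32, 64, 64, device = device)    # fixed noise input
noise = generator_input_saved.clone()
image_var = torch.rand(1, 3, 64, 64, device = device)                 # stand-in watermarked image
mask_var = (torch.rand(1, 3, 64, 64, device = device) > 0.1).float()  # 1 outside the watermark

objective = nn.MSELoss()
optimizer = torch.optim.Adam(generator.parameters(), lr = 0.01)
training_steps, show_steps, reg_noise = 3000, 200, 0.03               # the patch's new defaults

progress_bar = tqdm(range(training_steps), desc = 'Completed', ncols = 100)

for step in progress_bar:
    optimizer.zero_grad()
    generator_input = generator_input_saved

    if reg_noise > 0:
        # perturb the fixed input every step to regularize the optimization
        generator_input = generator_input_saved + (noise.normal_() * reg_noise)

    output = generator(generator_input)

    # the loss only sees pixels outside the watermark mask, so the network
    # must reconstruct the masked region from the image prior alone
    loss = objective(output * mask_var, image_var * mask_var)
    loss.backward()

    if step % show_steps == 0:
        progress_bar.set_postfix(Step = step, Loss = loss.item())

    optimizer.step()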