Skip to content

Commit

Permalink
updated the model for faster inference
Browse files Browse the repository at this point in the history
  • Loading branch information
braindotai committed Feb 15, 2021
1 parent a3e38cf commit 139de4a
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 11 deletions.
45 changes: 35 additions & 10 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,17 @@
parser.add_argument('--watermark-path', type = str, default = './data/watermark-available/watermark.png', help = 'Path to the "watermark" image.')
parser.add_argument('--input-depth', type = int, default = 32, help = 'Max channel dimension of the noise input. Set it based on gpu/device memory you have available.')
parser.add_argument('--lr', type = float, default = 0.01, help = 'Learning rate.')
parser.add_argument('--training-steps', type = int, default = 4000, help = 'Number of training iterations.')
parser.add_argument('--show-steps', type = int, default = 50, help = 'Interval for visualizing results.')
parser.add_argument('--training-steps', type = int, default = 3000, help = 'Number of training iterations.')
parser.add_argument('--show-steps', type = int, default = 200, help = 'Interval for visualizing results.')
parser.add_argument('--reg-noise', type = float, default = 0.03, help = 'Hyper-parameter for regularized noise input.')
parser.add_argument('--device', type = str, default = 'cuda', help = 'Device for pytorch, either "cpu" or "cuda".')
parser.add_argument('--max-dim', type = float, default = 512, help = 'Max dimension of the final output image')

args = parser.parse_args()
if args.device == 'cuda' and not torch.cuda.is_available():
args.device = 'cpu'
print('\nSetting device to "cpu", since torch is not built with "cuda" support...')
print('It is recommended to use GPU if possible...')

DTYPE = torch.cuda.FloatTensor if args.device == "cuda" else torch.FloatTensor

Expand All @@ -47,6 +49,7 @@
visualize_sample(image_np, image_mask_np, image_mask_np * image_np, nrow = 3, size_factor = 12)

print('Building model...\n')

generator = SkipEncoderDecoder(
args.input_depth,
num_channels_down = [128] * 5,
Expand All @@ -65,19 +68,41 @@

print('\nStarting training...\n')

for step in tqdm(range(args.training_steps), desc = 'Completed', ncols = 100):
progress_bar = tqdm(range(TRAINING_STEPS), desc = 'Completed', ncols = 800)

for step in progress_bar:
optimizer.zero_grad()
generator_input = generator_input_saved

if args.reg_noise > 0:
generator_input = generator_input_saved + (noise.normal_() * args.reg_noise)
generator_input = generator_input_saved + (noise.normal_() * REG_NOISE)

out = generator(generator_input)
output = generator(generator_input)

loss = objective(out * mask_var, image_var * mask_var)
loss = objective(output * mask_var, watermarked_var * mask_var)
loss.backward()

if step % args.show_steps == 0:
out_np = torch_to_np_array(out)
visualize_sample(np.clip(out_np, 0, 1), nrow = 1, size_factor = 5)
output_image = torch_to_np_array(output)
visualize_sample(watermarked_np, output_image, nrow = 2, size_factor = 10)

progress_bar.set_postfix(Step = step, Loss = loss.item())

optimizer.step()

# for step in tqdm(range(args.training_steps), desc = 'Completed', ncols = 100):
# optimizer.zero_grad()

# if args.reg_noise > 0:
# generator_input = generator_input_saved + (noise.normal_() * args.reg_noise)

# out = generator(generator_input)

# loss = objective(out * mask_var, image_var * mask_var)
# loss.backward()

# if step % args.show_steps == 0:
# out_np = torch_to_np_array(out)
# visualize_sample(np.clip(out_np, 0, 1), nrow = 1, size_factor = 5)

optimizer.step()
# optimizer.step()
15 changes: 14 additions & 1 deletion model/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,26 @@
from torch import nn
import numpy as np

class DepthwiseSeperableConv2d(nn.Module):
def __init__(self, input_channels, output_channels, **kwargs):
super(DepthwiseSeperableConv2d, self).__init__()

self.depthwise = nn.Conv2d(input_channels, input_channels, groups = input_channels, **kwargs)
self.pointwise = nn.Conv2d(input_channels, output_channels, kernel_size = 1)

def forward(self, x):
x = self.depthwise(x)
x = self.pointwise(x)

return x

class Conv2dBlock(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride = 1, bias = False):
super(Conv2dBlock, self).__init__()

self.model = nn.Sequential(
nn.ReflectionPad2d(int((kernel_size - 1) / 2)),
nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding = 0, bias = bias),
DepthwiseSeperableConv2d(in_channels, out_channels, kernel_size = kernel_size, stride = stride, padding = 0, bias = bias),
nn.BatchNorm2d(out_channels),
nn.LeakyReLU(0.2)
)
Expand Down

0 comments on commit 139de4a

Please sign in to comment.