diff --git a/api.py b/api.py index 8948f7e..5370a7c 100644 --- a/api.py +++ b/api.py @@ -1,5 +1,5 @@ from torch import optim - +from tqdm.auto import tqdm from helper import * from model.generator import SkipEncoderDecoder, input_noise @@ -19,7 +19,7 @@ def remove_watermark(image_path, mask_path, max_dim, reg_noise, input_depth, lr, num_channels_skip = [128] * 5 ).type(DTYPE) - objective = torch.nn.MSELoss().type(DTYPE) + objective = torch.nn.Loss().type(DTYPE) optimizer = optim.Adam(generator.parameters(), lr) image_var = np_to_torch_array(image_np).type(DTYPE) @@ -30,7 +30,7 @@ def remove_watermark(image_path, mask_path, max_dim, reg_noise, input_depth, lr, generator_input_saved = generator_input.detach().clone() noise = generator_input.detach().clone() - print('Starting training...') + print('\nStarting training...\n') progress_bar = tqdm(range(training_steps), desc = 'Completed', ncols = 100) diff --git a/helper.py b/helper.py index cebabad..1336dbb 100644 --- a/helper.py +++ b/helper.py @@ -1,6 +1,5 @@ import numpy as np from PIL import Image -from tqdm import tqdm import matplotlib.pyplot as plt import torch