diff --git a/api.py b/api.py index 5370a7c..c5b9400 100644 --- a/api.py +++ b/api.py @@ -1,5 +1,5 @@ from torch import optim -from tqdm.auto import tqdm + from helper import * from model.generator import SkipEncoderDecoder, input_noise @@ -19,7 +19,7 @@ def remove_watermark(image_path, mask_path, max_dim, reg_noise, input_depth, lr, num_channels_skip = [128] * 5 ).type(DTYPE) - objective = torch.nn.Loss().type(DTYPE) + objective = torch.nn.MSELoss().type(DTYPE) optimizer = optim.Adam(generator.parameters(), lr) image_var = np_to_torch_array(image_np).type(DTYPE)