Skip to content

Commit

Permalink
fixed notebook tqdm progress
Browse files Browse the repository at this point in the history
  • Loading branch information
braindotai committed Feb 15, 2021
1 parent 6320743 commit a930c49
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 3 deletions.
4 changes: 2 additions & 2 deletions api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from helper import *
from model.generator import SkipEncoderDecoder, input_noise

def remove_watermark(image_path, mask_path, max_dim, reg_noise, input_depth, lr, show_step, training_steps, tqdm = tqdm):
def remove_watermark(image_path, mask_path, max_dim, reg_noise, input_depth, lr, show_step, training_steps, tqdm_length = 100):
DTYPE = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
if not torch.cuda.is_available():
print('\nSetting device to "cpu", since torch is not built with "cuda" support...')
Expand Down Expand Up @@ -32,7 +32,7 @@ def remove_watermark(image_path, mask_path, max_dim, reg_noise, input_depth, lr,

print('\nStarting training...\n')

progress_bar = tqdm(range(training_steps), desc = 'Completed', ncols = 100)
progress_bar = tqdm(range(training_steps), desc = 'Completed', ncols = tqdm_length)

for step in progress_bar:
optimizer.zero_grad()
Expand Down
1 change: 0 additions & 1 deletion inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
parser.add_argument('--training-steps', type = int, default = 3000, help = 'Number of training iterations.')
parser.add_argument('--show-step', type = int, default = 200, help = 'Interval for visualizing results.')
parser.add_argument('--reg-noise', type = float, default = 0.03, help = 'Hyper-parameter for regularized noise input.')
parser.add_argument('--device', type = str, default = 'cuda', help = 'Device for pytorch, either "cpu" or "cuda".')
parser.add_argument('--max-dim', type = float, default = 512, help = 'Max dimension of the final output image')

args = parser.parse_args()
Expand Down

0 comments on commit a930c49

Please sign in to comment.