Skip to content

Commit

Permalink
added load weights to cyclegan
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrewzh112 committed Dec 8, 2020
1 parent 40f9c4d commit e7dbefb
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
13 changes: 12 additions & 1 deletion cyclegan/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,16 @@
pool_A = ReplayBuffer()
pool_B = ReplayBuffer()

if args.continue_train:
args.start_epoch = load_weights(state_dict_path=args.check_point,
models=[D_A, D_B, G_AB, G_BA],
model_names=['D_A', 'D_B', 'G_AB', 'G_BA'],
optimizers=[optimizer_G, optimizer_D],
optimizer_names=['optimizer_G', 'optimizer_D'],
return_val='start_epoch')

pbar = tqdm(range(start_epoch, args.epochs))

pbar = tqdm(
range(args.starting_epoch, args.n_epochs),
total=(args.n_epochs - args.starting_epoch)
Expand Down Expand Up @@ -220,7 +230,8 @@
'optimizer_G': optimizer_G.state_dict(),
'D_A': D_A.state_dict(),
'D_B': D_B.state_dict(),
'optimizer_D': optimizer_D.state_dict()
'optimizer_D': optimizer_D.state_dict(),
'start_epoch': epoch + 1
}, f"{args.checkpoint_dir}/{args.data_root.split('/')[-1]}_{epoch}.pth")

# saving space, only saving latest weights
Expand Down
8 changes: 4 additions & 4 deletions simsiam/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,10 @@
start_epoch = 0
if args.continue_train:
start_epoch = load_weights(state_dict_path=args.check_point,
models=[f_q, f_k],
model_names=['f_q', 'f_k'],
optimizers=[optimizer],
optimizer_names=['optimizer'],
models=model,
model_names='model',
optimizers=optimizer,
optimizer_names='optimizer',
return_val='start_epoch')

pbar = tqdm(range(start_epoch, args.epochs))
Expand Down

0 comments on commit e7dbefb

Please sign in to comment.