From e7dbefb26dc3e9bf1544c2eef1960f8a9dd7834e Mon Sep 17 00:00:00 2001 From: Andrew Zhao Date: Tue, 8 Dec 2020 17:31:11 +0800 Subject: [PATCH] added load weights to cyclegan --- cyclegan/train.py | 13 ++++++++++++- simsiam/main.py | 8 ++++---- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/cyclegan/train.py b/cyclegan/train.py index 1bcc268..85770a2 100644 --- a/cyclegan/train.py +++ b/cyclegan/train.py @@ -90,6 +90,16 @@ pool_A = ReplayBuffer() pool_B = ReplayBuffer() + if args.continue_train: + args.start_epoch = load_weights(state_dict_path=args.check_point, + models=[D_A, D_B, G_AB, G_BA], + model_names=['D_A', 'D_B', 'G_AB', 'G_BA'], + optimizers=[optimizer_G, optimizer_D], + optimizer_names=['optimizer_G', 'optimizer_D'], + return_val='start_epoch') + + pbar = tqdm(range(start_epoch, args.epochs)) + pbar = tqdm( range(args.starting_epoch, args.n_epochs), total=(args.n_epochs - args.starting_epoch) @@ -220,7 +230,8 @@ 'optimizer_G': optimizer_G.state_dict(), 'D_A': D_A.state_dict(), 'D_B': D_B.state_dict(), - 'optimizer_D': optimizer_D.state_dict() + 'optimizer_D': optimizer_D.state_dict(), + 'start_epoch': epoch + 1 }, f"{args.checkpoint_dir}/{args.data_root.split('/')[-1]}_{epoch}.pth") # saving space, only saving latest weights diff --git a/simsiam/main.py b/simsiam/main.py index e636cee..52080ee 100644 --- a/simsiam/main.py +++ b/simsiam/main.py @@ -83,10 +83,10 @@ start_epoch = 0 if args.continue_train: start_epoch = load_weights(state_dict_path=args.check_point, - models=[f_q, f_k], - model_names=['f_q', 'f_k'], - optimizers=[optimizer], - optimizer_names=['optimizer'], + models=model, + model_names='model', + optimizers=optimizer, + optimizer_names='optimizer', return_val='start_epoch') pbar = tqdm(range(start_epoch, args.epochs))