Skip to content

Commit

Permalink
add resume from checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
kwea123 committed May 8, 2020
1 parent afbd375 commit 7a88a66
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 10 deletions.
2 changes: 1 addition & 1 deletion opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def get_opts():
parser.add_argument('--num_gpus', type=int, default=1,
help='number of gpus')

parser.add_argument('--ckpt_path', type=str, default='',
parser.add_argument('--ckpt_path', type=str, default=None,
help='pretrained checkpoint path to load')
parser.add_argument('--prefixes_to_ignore', nargs='+', type=str, default=['loss'],
help='the prefixes to ignore in the checkpoint state dict')
Expand Down
10 changes: 1 addition & 9 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,6 @@ def __init__(self, hparams):
self.nerf_fine = NeRF()
self.models += [self.nerf_fine]

# load model if checkpoint path is provided
if self.hparams.ckpt_path != '':
print('Load model from', self.hparams.ckpt_path)
load_ckpt(self.nerf_coarse, self.hparams.ckpt_path,
'nerf_coarse', self.hparams.prefixes_to_ignore)
if hparams.N_importance > 0:
load_ckpt(self.nerf_fine, self.hparams.ckpt_path,
'nerf_fine', self.hparams.prefixes_to_ignore)

def decode_batch(self, batch):
rays = batch['rays'] # (B, 8)
rgbs = batch['rgbs'] # (B, 3)
Expand Down Expand Up @@ -180,6 +171,7 @@ def validation_epoch_end(self, outputs):

trainer = Trainer(max_epochs=hparams.num_epochs,
checkpoint_callback=checkpoint_callback,
resume_from_checkpoint=hparams.ckpt_path,
logger=logger,
early_stop_callback=None,
weights_summary=None,
Expand Down

0 comments on commit 7a88a66

Please sign in to comment.