Add best_evaluation in stored checkpoints
hm-ysjiang committed Jun 7, 2023
1 parent 41bbb35 commit 0c7a1e2
Showing 3 changed files with 24 additions and 14 deletions.
18 changes: 11 additions & 7 deletions train-selfsupervised.py
@@ -109,9 +109,10 @@ def train(args):
     model.train()
 
     epoch_start = 0
+    best_evaluation = None
     if args.restore_ckpt is not None:
         checkpoint = torch.load(args.restore_ckpt)
-        weight: OrderedDict[str, Any] = checkpoint['model']
+        weight: OrderedDict[str, Any] = checkpoint['model'] if 'model' in checkpoint else checkpoint
         if args.reset_context:
             _weight = OrderedDict()
             for key, val in checkpoint.items():
@@ -120,9 +121,11 @@ def train(args):
             weight = _weight
         model.load_state_dict(weight, strict=(not args.allow_nonstrict))
 
-        optimizer.load_state_dict(checkpoint['optimizer'])
-        scheduler = checkpoint['scheduler']
-        epoch_start = checkpoint['epoch']
+        if 'epoch' in checkpoint:
+            optimizer.load_state_dict(checkpoint['optimizer'])
+            scheduler = checkpoint['scheduler']
+            epoch_start = checkpoint['epoch']
+            best_evaluation = checkpoint.get('best_evaluation', None)
 
     if args.freeze_bn:
         model.module.freeze_bn()
@@ -132,7 +135,6 @@ def train(args):
 
     VAL_FREQ = 5000
     add_noise = True
-    best_evaluation = None
 
     for epoch in range(epoch_start, args.num_epochs):
         logger.initPbar(len(train_loader), epoch + 1)
@@ -168,7 +170,8 @@
                 'epoch': epoch + 1,
                 'model': model.state_dict(),
                 'optimizer': optimizer.state_dict(),
-                'scheduler': scheduler
+                'scheduler': scheduler,
+                'best_evaluation': best_evaluation
             }, PATH)
 
             results = {}
@@ -189,7 +192,8 @@
                     'epoch': epoch + 1,
                     'model': model.state_dict(),
                     'optimizer': optimizer.state_dict(),
-                    'scheduler': scheduler
+                    'scheduler': scheduler,
+                    'best_evaluation': best_evaluation
                 }, PATH)
 
                 model.train()
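The resume path above is designed to stay backward compatible: older checkpoints may be a bare state dict (no 'model' key) and, if they predate this commit, they carry no 'best_evaluation' entry. A minimal standalone sketch of that loading pattern is shown below; the load_checkpoint helper and its signature are illustrative, not part of the commit.

# Illustrative sketch of the backward-compatible resume logic; the helper
# name and signature are hypothetical, the key handling mirrors the diff above.
from collections import OrderedDict
from typing import Any

import torch


def load_checkpoint(path, model, optimizer):
    checkpoint = torch.load(path, map_location='cpu')

    # Old checkpoints may be a bare state_dict; newer ones wrap it under 'model'.
    weight: OrderedDict[str, Any] = checkpoint['model'] if 'model' in checkpoint else checkpoint
    model.load_state_dict(weight, strict=False)

    epoch_start, best_evaluation, scheduler = 0, None, None
    if 'epoch' in checkpoint:  # full training-state checkpoint
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler = checkpoint['scheduler']
        epoch_start = checkpoint['epoch']
        # Key is absent in checkpoints written before this commit, hence .get()
        best_evaluation = checkpoint.get('best_evaluation', None)
    return epoch_start, best_evaluation, scheduler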
18 changes: 11 additions & 7 deletions train-supervised.py
@@ -103,9 +103,10 @@ def train(args):
     model.train()
 
     epoch_start = 0
+    best_evaluation = None
     if args.restore_ckpt is not None:
         checkpoint = torch.load(args.restore_ckpt)
-        weight: OrderedDict[str, Any] = checkpoint['model']
+        weight: OrderedDict[str, Any] = checkpoint['model'] if 'model' in checkpoint else checkpoint
         if args.reset_context:
             _weight = OrderedDict()
             for key, val in checkpoint.items():
@@ -114,9 +115,11 @@ def train(args):
             weight = _weight
         model.load_state_dict(weight, strict=(not args.allow_nonstrict))
 
-        optimizer.load_state_dict(checkpoint['optimizer'])
-        scheduler = checkpoint['scheduler']
-        epoch_start = checkpoint['epoch']
+        if 'epoch' in checkpoint:
+            optimizer.load_state_dict(checkpoint['optimizer'])
+            scheduler = checkpoint['scheduler']
+            epoch_start = checkpoint['epoch']
+            best_evaluation = checkpoint.get('best_evaluation', None)
 
     if args.freeze_bn:
         model.module.freeze_bn()
@@ -126,7 +129,6 @@ def train(args):
 
     VAL_FREQ = 5000
     add_noise = True
-    best_evaluation = None
 
     for epoch in range(epoch_start, args.num_epochs):
         logger.initPbar(len(train_loader), epoch + 1)
@@ -162,7 +164,8 @@
                 'epoch': epoch + 1,
                 'model': model.state_dict(),
                 'optimizer': optimizer.state_dict(),
-                'scheduler': scheduler
+                'scheduler': scheduler,
+                'best_evaluation': best_evaluation
             }, PATH)
 
             results = {}
@@ -183,7 +186,8 @@
                     'epoch': epoch + 1,
                     'model': model.state_dict(),
                     'optimizer': optimizer.state_dict(),
-                    'scheduler': scheduler
+                    'scheduler': scheduler,
+                    'best_evaluation': best_evaluation
                 }, PATH)
 
                 model.train()
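On the save side, both scripts now persist best_evaluation alongside the rest of the training state, so the best score so far survives an interrupted run. A minimal sketch of that pattern follows; the maybe_save_best helper is illustrative, and the metric is assumed to be lower-is-better.

# Illustrative save-side pattern (helper name assumed; metric assumed lower-is-better).
import torch


def maybe_save_best(state, result, best_evaluation, path):
    """Persist the checkpoint dict `state` whenever `result` improves on the best so far."""
    if best_evaluation is None or result < best_evaluation:
        best_evaluation = result
        state['best_evaluation'] = best_evaluation  # key added by this commit
        torch.save(state, path)
    return best_evaluation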
2 changes: 2 additions & 0 deletions train.sh
@@ -27,6 +27,7 @@ cmd_supervised_transfer="python -u train-supervised.py \
   --lr 0.0002 \
   --wdecay 0.00001 \
   --gamma=0.85 \
+  --allow_nonstrict \
   --reset_context \
   --context 128"
 
@@ -39,6 +40,7 @@ cmd_selfsupervised_transfer="python -u train-selfsupervised.py \
   --lr 0.0002 \
   --wdecay 0.00001 \
   --gamma=0.85 \
+  --allow_nonstrict \
   --reset_context \
   --context 128"
 
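The --allow_nonstrict flag added to both transfer commands pairs with the non-strict model.load_state_dict call in the training scripts. A hypothetical sketch of how these boolean flags could be declared with argparse is shown below; the actual parser setup lives in train-supervised.py and train-selfsupervised.py and is only assumed here.

# Hypothetical argparse declarations backing the flags used in train.sh;
# the real parser in the training scripts may differ.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--restore_ckpt', type=str, default=None,
                    help='checkpoint to resume or transfer from')
parser.add_argument('--allow_nonstrict', action='store_true',
                    help='tolerate missing/unexpected keys when loading weights')
parser.add_argument('--reset_context', action='store_true',
                    help='drop context-encoder weights so --context can change')
parser.add_argument('--context', type=int, default=128)
args = parser.parse_args()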
