Skip to content

Commit

Permalink
Adjust training code
Browse files Browse the repository at this point in the history
  • Loading branch information
hm-ysjiang committed Jun 3, 2023
1 parent 57a27c2 commit 122087d
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 19 deletions.
41 changes: 24 additions & 17 deletions train-supervised.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import division, print_function # nopep8
from __future__ import division, print_function
from pathlib import Path # nopep8

import sys # nopep8

Expand Down Expand Up @@ -96,28 +97,29 @@ def fetch_optimizer(args, model, steps):


def train(args):

model = nn.DataParallel(RAFT(args), device_ids=args.gpus)
print("Parameter Count: %d" % count_parameters(model))

if args.restore_ckpt is not None:
model.load_state_dict(torch.load(args.restore_ckpt), strict=False)
model.load_state_dict(torch.load(args.restore_ckpt),
strict=(not args.allow_nonstrict))

model.cuda()
model.train()

if args.restore_ckpt:
if args.freeze_bn:
model.module.freeze_bn()

train_loader = datasets.fetch_dataloader(args)
optimizer, scheduler = fetch_optimizer(
args, model, len(train_loader) * args.num_epochs)
optimizer, scheduler = fetch_optimizer(args, model,
len(train_loader) * args.num_epochs)

scaler = GradScaler(enabled=args.mixed_precision)
logger = Logger(args.name)

VAL_FREQ = 5000
add_noise = True
best_evaluation = None

for epoch in range(args.num_epochs):
logger.initPbar(len(train_loader), epoch + 1)
Expand All @@ -134,8 +136,9 @@ def train(args):

flow_predictions = model(image1, image2, iters=args.iters)

loss, metrics = sequence_loss(
flow_predictions, flow, valid, args.gamma)
loss, metrics = sequence_loss(flow_predictions,
flow, valid,
args.gamma)
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
Expand All @@ -147,8 +150,6 @@ def train(args):
logger.push({'loss': loss.item()})

logger.closePbar()
PATH = 'checkpoints/%s_%d.pth' % (args.name, epoch + 1)
torch.save(model.state_dict(), PATH)

results = {}
for val_dataset in args.validation:
Expand All @@ -158,15 +159,20 @@ def train(args):
results.update(evaluate.validate_sintel(model.module))
elif val_dataset == 'kitti':
results.update(evaluate.validate_kitti(model.module))

logger.write_dict(results, 'epoch')

evaluation_score = np.mean(list(results.values()))
if best_evaluation is None or evaluation_score < best_evaluation:
best_evaluation = evaluation_score
PATH = 'checkpoints/%s/model-best.pth' % args.name
torch.save(model.state_dict(), PATH)

model.train()
if args.restore_ckpt:
if args.freeze_bn:
model.module.freeze_bn()

logger.close()
PATH = 'checkpoints/%s.pth' % args.name
PATH = 'checkpoints/%s/model.pth' % args.name
torch.save(model.state_dict(), PATH)

return PATH
Expand All @@ -175,9 +181,11 @@ def train(args):
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--name', default='raft', help="name your experiment")
parser.add_argument(
'--stage', help="determines which dataset to use for training")
parser.add_argument('--freeze_bn', action='store_true',
help="freeze the batch norm layer")
parser.add_argument('--restore_ckpt', help="restore checkpoint")
parser.add_argument('--allow_nonstrict', action='store_true',
help='allow non-strict loading')
parser.add_argument('--small', action='store_true', help='use small model')
parser.add_argument('--validation', type=str, nargs='+')

Expand Down Expand Up @@ -205,7 +213,6 @@ def train(args):

cudnn_backend.benchmark = True

if not os.path.isdir('checkpoints'):
os.mkdir('checkpoints')
os.makedirs(Path(__file__).parent.joinpath('checkpoints', args.name))

train(args)
8 changes: 6 additions & 2 deletions train-supervised.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ cmd_transfer="python -u train-supervised.py \
--name raft-sintel-supervised-transfer \
--validation sintel \
--restore_ckpt checkpoints/raft-things.pth \
--freeze_bn \
--gpus 0 \
--num_epochs 100 \
--batch_size 6 \
Expand All @@ -23,5 +24,8 @@ cmd_transfer="python -u train-supervised.py \
--wdecay 0.00001 \
--gamma=0.85"

echo ${cmd_scratch}
eval ${cmd_scratch}
# echo ${cmd_scratch}
# eval ${cmd_scratch}

echo ${cmd_transfer}
eval ${cmd_transfer}

0 comments on commit 122087d

Please sign in to comment.