
Commit

Save extra information to resume training
hm-ysjiang committed Jun 5, 2023
1 parent 621b99a commit 7d6c2b5
Showing 6 changed files with 84 additions and 41 deletions.
8 changes: 7 additions & 1 deletion demo.py
@@ -37,7 +37,9 @@ def viz(img, flo):
 
 def demo(args):
     model = torch.nn.DataParallel(RAFT(args))
-    model.load_state_dict(torch.load(args.model))
+    checkpoint = torch.load(args.model)
+    weight = checkpoint['model'] if 'model' in checkpoint else checkpoint
+    model.load_state_dict(weight)
 
     model = model.module
     model.to(DEVICE)
@@ -70,6 +70,10 @@ def demo(args):
                         action='store_true', help='use mixed precision')
     parser.add_argument('--alternate_corr', action='store_true',
                         help='use efficent correlation implementation')
+    parser.add_argument('--hidden', type=int, default=128,
+                        help='The hidden size of the updater')
+    parser.add_argument('--context', type=int, default=128,
+                        help='The context size of the updater')
     args = parser.parse_args()
 
     os.makedirs('visualization', exist_ok=True)
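The three-line pattern added above (load, unwrap if wrapped, then load the state dict) recurs in evaluate.py below. A minimal sketch of it as a reusable helper; `load_flexible` is a hypothetical name, not part of the commit:

import torch

def load_flexible(model, path):
    # Checkpoints written before this commit are raw state dicts;
    # checkpoints written after it wrap the weights under a 'model' key.
    checkpoint = torch.load(path)
    weight = checkpoint['model'] if 'model' in checkpoint else checkpoint
    model.load_state_dict(weight)
    return checkpoint  # resumable checkpoints also carry 'epoch', 'optimizer', 'scheduler'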
10 changes: 8 additions & 2 deletions evaluate.py
@@ -172,16 +172,22 @@ def validate_kitti(model, iters=24):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--model', help="restore checkpoint")
-    parser.add_argument('--dataset', help="dataset for evaluation")
+    parser.add_argument('--dataset', default='sintel', help="dataset for evaluation")
     parser.add_argument('--small', action='store_true', help='use small model')
     parser.add_argument('--mixed_precision',
                         action='store_true', help='use mixed precision')
     parser.add_argument('--alternate_corr', action='store_true',
                         help='use efficent correlation implementation')
+    parser.add_argument('--hidden', type=int, default=128,
+                        help='The hidden size of the updater')
+    parser.add_argument('--context', type=int, default=128,
+                        help='The context size of the updater')
     args = parser.parse_args()
 
     model = torch.nn.DataParallel(RAFT(args))
-    model.load_state_dict(torch.load(args.model))
+    checkpoint = torch.load(args.model)
+    weight = checkpoint['model'] if 'model' in checkpoint else checkpoint
+    model.load_state_dict(weight)
 
     model.cuda()
     model.eval()
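A sketch of driving the updated evaluation without the CLI, reusing the hypothetical `load_flexible` helper above. `validate_sintel` is defined in RAFT's evaluate.py; the checkpoint path, and the assumption that `RAFT(args)` reads `args.hidden` and `args.context`, are illustrative:

from argparse import Namespace

import torch

import evaluate
from raft import RAFT

args = Namespace(model='checkpoints/raft-sintel/model.pth', dataset='sintel',
                 small=False, mixed_precision=False, alternate_corr=False,
                 hidden=128, context=128)

model = torch.nn.DataParallel(RAFT(args))
load_flexible(model, args.model)  # handles wrapped and raw checkpoints alike

model.cuda()
model.eval()

with torch.no_grad():
    evaluate.validate_sintel(model.module)  # same call the script dispatches to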
45 changes: 30 additions & 15 deletions train-selfsupervised.py
@@ -99,37 +99,42 @@ def fetch_optimizer(args, model, steps):
 
 
 def train(args):
-    train_loader = datasets.fetch_dataloader(args)
     model = nn.DataParallel(RAFT(args), device_ids=args.gpus)
     print("Parameter Count: %d" % count_parameters(model))
 
-    optimizer, scheduler = fetch_optimizer(args, model,
-                                           len(train_loader) * args.num_epochs)
-    model.cuda()
-    model.train()
-
+    epoch_start = 0
     if args.restore_ckpt is not None:
-        checkpoint: OrderedDict[str, Any] = torch.load(args.restore_ckpt)
+        checkpoint = torch.load(args.restore_ckpt)
+        weight: OrderedDict[str, Any] = checkpoint['model']
         if args.reset_context:
-            weight = OrderedDict()
+            _weight = OrderedDict()
             for key, val in checkpoint.items():
                 if '.cnet.' not in key:
-                    weight[key] = val
-            checkpoint = weight
-        model.load_state_dict(checkpoint, strict=(not args.allow_nonstrict))
+                    _weight[key] = val
+            weight = _weight
+        model.load_state_dict(weight, strict=(not args.allow_nonstrict))
+
+        model.cuda()
+        model.train()
+        optimizer.load_state_dict(checkpoint['optimizer'])
+        scheduler = checkpoint['scheduler']
+        epoch_start = checkpoint['epoch']
 
     if args.freeze_bn:
         model.module.freeze_bn()
 
+    train_loader = datasets.fetch_dataloader(args)
+    optimizer, scheduler = fetch_optimizer(args, model,
+                                           len(train_loader) * args.num_epochs)
+
     scaler = GradScaler(enabled=args.mixed_precision)
-    logger = Logger(args.name)
+    logger = Logger(args.name, len(train_loader) * epoch_start)
 
     VAL_FREQ = 5000
     add_noise = True
     best_evaluation = None
 
-    for epoch in range(args.num_epochs):
+    for epoch in range(epoch_start, args.num_epochs):
         logger.initPbar(len(train_loader), epoch + 1)
         for batch_idx, data_blob in enumerate(train_loader):
             optimizer.zero_grad()
@@ -159,7 +164,12 @@ def train(args):
 
         logger.closePbar()
         PATH = 'checkpoints/%s/model.pth' % args.name
-        torch.save(model.state_dict(), PATH)
+        torch.save({
+            'epoch': epoch + 1,
+            'model': model.state_dict(),
+            'optimizer': optimizer.state_dict(),
+            'scheduler': scheduler
+        }, PATH)
 
         results = {}
         for val_dataset in args.validation:
@@ -175,7 +185,12 @@ def train(args):
         if best_evaluation is None or evaluation_score < best_evaluation:
             best_evaluation = evaluation_score
             PATH = 'checkpoints/%s/model-best.pth' % args.name
-            torch.save(model.state_dict(), PATH)
+            torch.save({
+                'epoch': epoch + 1,
+                'model': model.state_dict(),
+                'optimizer': optimizer.state_dict(),
+                'scheduler': scheduler
+            }, PATH)
 
         model.train()
         if args.freeze_bn:
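Condensed, the save/resume cycle this script now implements looks like the following sketch. Names are hypothetical; the optimizer is constructed before its saved state is restored, since load_state_dict needs an existing optimizer to act on:

import torch

def save_checkpoint(path, epoch, model, optimizer, scheduler):
    # Same payload as the torch.save({...}) calls above; 'epoch' stores the
    # number of completed epochs, so a resumed run starts at that index.
    torch.save({
        'epoch': epoch + 1,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler,
    }, path)

def restore_checkpoint(path, model, optimizer):
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return checkpoint['epoch'], checkpoint['scheduler']

# epoch_start, scheduler = restore_checkpoint(path, model, optimizer)
# for epoch in range(epoch_start, args.num_epochs):
#     ...

Saving the scheduler object itself, rather than scheduler.state_dict(), makes the restore a single assignment, at the cost of tying the checkpoint file to the scheduler class's pickle layout.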
12 changes: 6 additions & 6 deletions train-selfsupervised.sh
@@ -4,7 +4,7 @@ mkdir -p checkpoints
 cmd_scratch="python -u train-selfsupervised.py \
 --name raft-sintel-selfsupervised-scratch \
 --validation sintel \
---num_epochs 200 \
+--num_epochs 250 \
 --batch_size 6 \
 --lr 0.0004 \
 --wdecay 0.00001"
@@ -14,14 +14,14 @@ cmd_transfer="python -u train-selfsupervised.py \
 --validation sintel \
 --restore_ckpt checkpoints/raft-things.pth \
 --freeze_bn \
---num_epochs 200 \
+--num_epochs 250 \
 --batch_size 6 \
 --lr 0.000125 \
 --wdecay 0.00001 \
 --gamma=0.85"
 
-# echo ${cmd_scratch}
-# eval ${cmd_scratch}
+echo ${cmd_scratch}
+eval ${cmd_scratch}
 
-echo ${cmd_transfer}
-eval ${cmd_transfer}
+# echo ${cmd_transfer}
+# eval ${cmd_transfer}
46 changes: 31 additions & 15 deletions train-supervised.py
@@ -92,37 +92,43 @@ def fetch_optimizer(args, model, steps):
 
 
 def train(args):
-    train_loader = datasets.fetch_dataloader(args)
     model = nn.DataParallel(RAFT(args), device_ids=args.gpus)
     print("Parameter Count: %d" % count_parameters(model))
 
-    optimizer, scheduler = fetch_optimizer(args, model,
-                                           len(train_loader) * args.num_epochs)
-
-    model.cuda()
-    model.train()
-
+    epoch_start = 0
    if args.restore_ckpt is not None:
-        checkpoint: OrderedDict[str, Any] = torch.load(args.restore_ckpt)
+        checkpoint = torch.load(args.restore_ckpt)
+        weight: OrderedDict[str, Any] = checkpoint['model']
         if args.reset_context:
-            weight = OrderedDict()
+            _weight = OrderedDict()
             for key, val in checkpoint.items():
                 if '.cnet.' not in key:
-                    weight[key] = val
-            checkpoint = weight
-        model.load_state_dict(checkpoint, strict=(not args.allow_nonstrict))
+                    _weight[key] = val
+            weight = _weight
+        model.load_state_dict(weight, strict=(not args.allow_nonstrict))
+
+        model.cuda()
+        model.train()
+        optimizer.load_state_dict(checkpoint['optimizer'])
+        scheduler = checkpoint['scheduler']
+        epoch_start = checkpoint['epoch']
 
     if args.freeze_bn:
         model.module.freeze_bn()
 
+    train_loader = datasets.fetch_dataloader(args)
+    optimizer, scheduler = fetch_optimizer(args, model,
+                                           len(train_loader) * args.num_epochs)
+
     scaler = GradScaler(enabled=args.mixed_precision)
-    logger = Logger(args.name)
+    logger = Logger(args.name, len(train_loader) * epoch_start)
 
     VAL_FREQ = 5000
     add_noise = True
     best_evaluation = None
 
-    for epoch in range(args.num_epochs):
+    for epoch in range(epoch_start, args.num_epochs):
         logger.initPbar(len(train_loader), epoch + 1)
         for batch_idx, data_blob in enumerate(train_loader):
             optimizer.zero_grad()
@@ -152,7 +158,12 @@ def train(args):
 
         logger.closePbar()
         PATH = 'checkpoints/%s/model.pth' % args.name
-        torch.save(model.state_dict(), PATH)
+        torch.save({
+            'epoch': epoch + 1,
+            'model': model.state_dict(),
+            'optimizer': optimizer.state_dict(),
+            'scheduler': scheduler
+        }, PATH)
 
         results = {}
         for val_dataset in args.validation:
@@ -168,7 +179,12 @@ def train(args):
         if best_evaluation is None or evaluation_score < best_evaluation:
             best_evaluation = evaluation_score
             PATH = 'checkpoints/%s/model-best.pth' % args.name
-            torch.save(model.state_dict(), PATH)
+            torch.save({
+                'epoch': epoch + 1,
+                'model': model.state_dict(),
+                'optimizer': optimizer.state_dict(),
+                'scheduler': scheduler
+            }, PATH)
 
         model.train()
         if args.freeze_bn:
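Both training scripts share the --reset_context branch, which drops the context-encoder weights so that loading leaves them freshly initialized. A standalone sketch of that filter, with hypothetical names; it pairs with strict=False, which is what --allow_nonstrict toggles:

from collections import OrderedDict

def drop_context_weights(weight):
    # Skip every parameter whose name contains '.cnet.' (RAFT's context
    # encoder); the remaining entries load normally.
    filtered = OrderedDict()
    for key, val in weight.items():
        if '.cnet.' not in key:
            filtered[key] = val
    return filtered

# model.load_state_dict(drop_context_weights(checkpoint['model']), strict=False)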
4 changes: 2 additions & 2 deletions train-supervised.sh
@@ -4,7 +4,7 @@ mkdir -p checkpoints
 cmd_scratch="python -u train-supervised.py \
 --name raft-sintel-supervised-scratch \
 --validation sintel \
---num_epochs 200 \
+--num_epochs 250 \
 --batch_size 6 \
 --lr 0.0004 \
 --wdecay 0.00001"
@@ -14,7 +14,7 @@ cmd_transfer="python -u train-supervised.py \
 --validation sintel \
 --restore_ckpt checkpoints/raft-things.pth \
 --freeze_bn \
---num_epochs 200 \
+--num_epochs 250 \
 --batch_size 6 \
 --lr 0.000125 \
 --wdecay 0.00001 \
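One consequence of the epoch bump in both shell scripts: since the training loop now runs for epoch in range(epoch_start, args.num_epochs), --num_epochs is an absolute target rather than an increment, so raising it from 200 to 250 lets a finished 200-epoch run resume for 50 more epochs. Illustrative numbers only:

epoch_start = 200  # restored from checkpoint['epoch']
num_epochs = 250   # the new value in these scripts
for epoch in range(epoch_start, num_epochs):
    pass           # 50 additional epochs run on resume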
