Further enhance Classification Reference #4444

Merged (11 commits) on Sep 21, 2021
references/classification/train.py (28 changes: 21 additions & 7 deletions)
@@ -186,17 +186,19 @@ def main(args):
         sampler=test_sampler, num_workers=args.workers, pin_memory=True)

     print("Creating model")
-    model = torchvision.models.__dict__[args.model](pretrained=args.pretrained)
+    model = torchvision.models.__dict__[args.model](pretrained=args.pretrained, num_classes=num_classes)
     model.to(device)

     if args.distributed and args.sync_bn:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

     criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)

     opt_name = args.opt.lower()
-    if opt_name == 'sgd':
+    if opt_name.startswith("sgd"):
Contributor:

Why this change? Are there other options starting with sgd?

Contributor (Author):

It's a nasty workaround to support Nesterov momentum. Instead of adding another args parameter that is useful only when SGD is selected, the user can request sgd_nesterov.

         optimizer = torch.optim.SGD(
-            model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
+            model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay,
+            nesterov="nesterov" in opt_name)
     elif opt_name == 'rmsprop':
         optimizer = torch.optim.RMSprop(model.parameters(), lr=args.lr, momentum=args.momentum,
                                         weight_decay=args.weight_decay, eps=0.0316, alpha=0.9)
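
To make the reviewer's question above concrete: with this change the optimizer name is matched by prefix, so both "sgd" and "sgd_nesterov" select the SGD branch, and the nesterov flag is derived from the name itself. Below is a minimal, standalone sketch of that pattern; the build_optimizer helper is hypothetical and only illustrates the logic in the diff, it is not part of the PR.

import torch


def build_optimizer(model, opt, lr, momentum, weight_decay):
    # Hypothetical helper mirroring the prefix-matching logic in the diff above.
    opt_name = opt.lower()
    if opt_name.startswith("sgd"):
        # "sgd" keeps Nesterov momentum off; "sgd_nesterov" turns it on.
        return torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum,
                               weight_decay=weight_decay,
                               nesterov="nesterov" in opt_name)
    raise RuntimeError(f"Invalid optimizer '{opt}'. Only SGD variants are covered in this sketch.")


model = torch.nn.Linear(4, 2)
optimizer = build_optimizer(model, "sgd_nesterov", lr=0.1, momentum=0.9, weight_decay=1e-4)
print(optimizer.defaults["nesterov"])  # True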
@@ -214,15 +216,25 @@ def main(args):
     elif args.lr_scheduler == 'cosineannealinglr':
         main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                                        T_max=args.epochs - args.lr_warmup_epochs)
+    elif args.lr_scheduler == 'exponentiallr':
+        main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
     else:
-        raise RuntimeError("Invalid lr scheduler '{}'. Only StepLR and CosineAnnealingLR "
+        raise RuntimeError("Invalid lr scheduler '{}'. Only StepLR, CosineAnnealingLR and ExponentialLR "
                            "are supported.".format(args.lr_scheduler))

     if args.lr_warmup_epochs > 0:
+        if args.lr_warmup_method == 'linear':
+            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=args.lr_warmup_decay,
+                                                                    total_iters=args.lr_warmup_epochs)
+        elif args.lr_warmup_method == 'constant':
+            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=args.lr_warmup_decay,
+                                                                      total_iters=args.lr_warmup_epochs)
+        else:
+            raise RuntimeError(f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant "
+                               "are supported.")
         lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
             optimizer,
-            schedulers=[torch.optim.lr_scheduler.ConstantLR(optimizer, factor=args.lr_warmup_decay,
-                                                            total_iters=args.lr_warmup_epochs), main_lr_scheduler],
+            schedulers=[warmup_lr_scheduler, main_lr_scheduler],
             milestones=[args.lr_warmup_epochs]
         )
     else:
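
For readers unfamiliar with the warmup chaining above: SequentialLR runs the warmup scheduler (LinearLR or ConstantLR) for the first lr_warmup_epochs epochs and then hands over to the main scheduler at the milestone. A minimal sketch of the same pattern, with illustrative hyperparameters rather than the reference script's defaults:

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

warmup_epochs = 5
# Warmup: ramp the lr from 1% of the base value up to the full base lr over 5 epochs.
warmup_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01,
                                                     total_iters=warmup_epochs)
# Main schedule: decay the lr by 10x every 30 epochs once warmup is over.
main_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
# SequentialLR switches from the warmup scheduler to the main one at the milestone epoch.
scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer,
                                                  schedulers=[warmup_scheduler, main_scheduler],
                                                  milestones=[warmup_epochs])

for epoch in range(10):
    # ... train for one epoch, calling optimizer.step() per batch ...
    optimizer.step()
    scheduler.step()
    print(epoch, scheduler.get_last_lr())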
@@ -307,7 +319,9 @@ def get_args_parser(add_help=True):
     parser.add_argument('--cutmix-alpha', default=0.0, type=float, help='cutmix alpha (default: 0.0)')
     parser.add_argument('--lr-scheduler', default="steplr", help='the lr scheduler (default: steplr)')
     parser.add_argument('--lr-warmup-epochs', default=0, type=int, help='the number of epochs to warmup (default: 0)')
-    parser.add_argument('--lr-warmup-decay', default=0.01, type=int, help='the decay for lr')
+    parser.add_argument('--lr-warmup-method', default="constant", type=str,
+                        help='the warmup method (default: constant)')
+    parser.add_argument('--lr-warmup-decay', default=0.01, type=float, help='the decay for lr')
     parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs')
     parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
     parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
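
The type=int to type=float change on --lr-warmup-decay matters because argparse applies type to the raw command-line string, so an int-typed argument rejects fractional values such as 0.01 at parse time. A small standalone illustration (not code from the PR):

import argparse

parser = argparse.ArgumentParser()
# With the old declaration, type=int, passing a fractional decay fails:
#   parser.add_argument('--lr-warmup-decay', default=0.01, type=int)
#   parser.parse_args(['--lr-warmup-decay', '0.01'])  # error: invalid int value: '0.01'
# With type=float the value parses as intended:
parser.add_argument('--lr-warmup-decay', default=0.01, type=float, help='the decay for lr')
args = parser.parse_args(['--lr-warmup-decay', '0.01'])
print(args.lr_warmup_decay)  # 0.01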
references/segmentation/train.py (2 changes: 1 addition & 1 deletion)
@@ -220,7 +220,7 @@ def get_args_parser(add_help=True):
                         dest='weight_decay')
     parser.add_argument('--lr-warmup-epochs', default=0, type=int, help='the number of epochs to warmup (default: 0)')
     parser.add_argument('--lr-warmup-method', default="linear", type=str, help='the warmup method (default: linear)')
-    parser.add_argument('--lr-warmup-decay', default=0.01, type=int, help='the decay for lr')
+    parser.add_argument('--lr-warmup-decay', default=0.01, type=float, help='the decay for lr')
     parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
     parser.add_argument('--output-dir', default='.', help='path where to save')
     parser.add_argument('--resume', default='', help='resume from checkpoint')
references/video_classification/train.py (2 changes: 1 addition & 1 deletion)
@@ -296,7 +296,7 @@ def parse_args():
     parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
     parser.add_argument('--lr-warmup-epochs', default=10, type=int, help='the number of epochs to warmup (default: 10)')
     parser.add_argument('--lr-warmup-method', default="linear", type=str, help='the warmup method (default: linear)')
-    parser.add_argument('--lr-warmup-decay', default=0.001, type=int, help='the decay for lr')
+    parser.add_argument('--lr-warmup-decay', default=0.001, type=float, help='the decay for lr')
     parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
     parser.add_argument('--output-dir', default='.', help='path where to save')
     parser.add_argument('--resume', default='', help='resume from checkpoint')