Skip to content
This repository has been archived by the owner on Jul 2, 2021. It is now read-only.

Commit

Permalink
Revert "add short option for ssd/train.py"
Browse files Browse the repository at this point in the history
This reverts commit 276ae9b.
  • Loading branch information
Hakuyume committed Oct 15, 2018
1 parent dfed5a8 commit 32434fe
Showing 1 changed file with 4 additions and 9 deletions.
13 changes: 4 additions & 9 deletions examples/ssd/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@ def main():
parser.add_argument('--gpu', type=int, default=-1)
parser.add_argument('--out', default='result')
parser.add_argument('--resume')
parser.add_argument('--short', action='store_true')
args = parser.parse_args()

if args.model == 'ssd300':
Expand Down Expand Up @@ -154,12 +153,10 @@ def main():

updater = training.updaters.StandardUpdater(
train_iter, optimizer, device=args.gpu)
trainer = training.Trainer(
updater, (120000 if not args.short else 12, 'iteration'), args.out)
trainer = training.Trainer(updater, (120000, 'iteration'), args.out)
trainer.extend(
extensions.ExponentialShift('lr', 0.1, init=1e-3),
trigger=triggers.ManualScheduleTrigger(
[80000, 100000] if not args.short else [8, 10], 'iteration'))
trigger=triggers.ManualScheduleTrigger([80000, 100000], 'iteration'))

trainer.extend(
DetectionVOCEvaluator(
Expand All @@ -177,12 +174,10 @@ def main():
trigger=log_interval)
trainer.extend(extensions.ProgressBar(update_interval=10))

trainer.extend(
extensions.snapshot(),
trigger=(10000 if not args.short else 1, 'iteration'))
trainer.extend(extensions.snapshot(), trigger=(10000, 'iteration'))
trainer.extend(
extensions.snapshot_object(model, 'model_iter_{.updater.iteration}'),
trigger=(120000 if not args.short else 12, 'iteration'))
trigger=(120000, 'iteration'))

if args.resume:
serializers.load_npz(args.resume, trainer)
Expand Down

0 comments on commit 32434fe

Please sign in to comment.