From 59cf7ea621d12ae5eeca3c170fff1db21760490d Mon Sep 17 00:00:00 2001
From: Michael Wu
Date: Tue, 26 Mar 2019 15:19:35 -0700
Subject: [PATCH] Add warmup scheduler to pytext (#423)

Summary:
Pull Request resolved: https://github.com/facebookresearch/pytext/pull/423

Warmup scheduler that linearly increases the learning rate from 0 to its final
value over a specified number of training steps.

Differential Revision: D14609360

fbshipit-source-id: b3082af64fe6b2ebb1ac83a0a4fbf0c381a56d24
---
 pytext/optimizer/scheduler.py | 45 +++++++++++++++++++++++++++++++++++
 1 file changed, 45 insertions(+)

diff --git a/pytext/optimizer/scheduler.py b/pytext/optimizer/scheduler.py
index 752ed6d6c..ee07a663b 100644
--- a/pytext/optimizer/scheduler.py
+++ b/pytext/optimizer/scheduler.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 import math
+from typing import Optional
 
 import torch
 from pytext.config import ConfigBase
@@ -300,3 +301,47 @@ def from_config(cls, config: Config, optimizer: Optimizer):
 
     def step_epoch(self, metrics=None, epoch=None):
         self.step(epoch)
+
+
+class WarmupScheduler(_LRScheduler, BatchScheduler):
+    """
+    Scheduler to linearly increase learning rate from 0 to final value at the beginning
+    of training.
+    """
+
+    class Config(BatchScheduler.Config):
+        #: number of training steps over which to increase learning rate
+        warmup_steps: Optional[int] = 10000
+        #: number of epochs (can be fractional) over which to increase learning rate.
+        #: Mutually exclusive with `warmup_steps`.
+        warmup_epochs: Optional[float] = None
+
+    @classmethod
+    def from_config(cls, config: Config, optimizer: Optimizer):
+        return cls(optimizer, config.warmup_steps, config.warmup_epochs)
+
+    def __init__(self, optimizer, warmup_steps, warmup_epochs):
+        assert (warmup_steps is None) ^ (warmup_epochs is None)
+        self.warmup_steps = warmup_steps
+        self.warmup_epochs = warmup_epochs
+        self.current_steps = 0
+        super().__init__(optimizer)
+
+    def prepare(self, train_iter, total_epochs):
+        super().prepare(train_iter, total_epochs)
+        if self.warmup_epochs is not None:
+            self.warmup_steps = int(self.warmup_epochs * self.steps_per_epoch)
+        # Initialize first LR. In epoch-based schedulers this happens in __init__(),
+        # but in this case this happens after prepare().
+        self.step_batch()
+
+    def step_batch(self):
+        self.current_steps += 1
+        self.step()
+
+    def get_lr(self):
+        if self.warmup_steps is None or self.current_steps >= self.warmup_steps:
+            lr_multiplier = 1.0
+        else:
+            lr_multiplier = self.current_steps / self.warmup_steps
+        return [lr_multiplier * base_lr for base_lr in self.base_lrs]
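
A minimal sketch of the same linear warmup idea, assuming only plain PyTorch's
LambdaLR rather than pytext's Scheduler/BatchScheduler machinery (wiring up a
pytext Trainer is out of scope here); the model, optimizer, and warmup_steps
value below are arbitrary placeholders, not part of the patch.

# Illustrative only: linear LR warmup with vanilla PyTorch, not pytext's API.
import torch
from torch.optim.lr_scheduler import LambdaLR

warmup_steps = 1000  # placeholder; the Config above defaults to 10000

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Multiplier ramps linearly toward 1.0 over warmup_steps, then stays at 1.0,
# mirroring the ramp in WarmupScheduler.get_lr().
scheduler = LambdaLR(
    optimizer, lr_lambda=lambda step: min(1.0, (step + 1) / warmup_steps)
)

for step in range(3):
    optimizer.step()  # one training step (forward/backward omitted)
    scheduler.step()  # analogous to WarmupScheduler.step_batch()

In the patch itself, step_batch() is the per-batch hook: each call increments
current_steps, and get_lr() scales every base_lr by current_steps / warmup_steps
until warmup completes, after which the multiplier stays at 1.0.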