This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

Add warmup scheduler to pytext
Summary: Warmup scheduler that linearly increases the learning rate from 0 to its final value over a specified number of training steps.
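
For illustration only (not part of this commit), a minimal sketch of the linear warmup rule described above; the helper name is illustrative:

    def warmup_lr(base_lr, current_step, warmup_steps):
        # Scale the base LR linearly until warmup_steps is reached, then hold it constant.
        multiplier = min(1.0, current_step / warmup_steps)
        return base_lr * multiplier

    # e.g. base_lr=0.001 and warmup_steps=10000 gives 0.00025 at step 2500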

Differential Revision: D14609360

fbshipit-source-id: fa6e7068807cbccb8e2d2de27de7fc44cd8d356c
Michael Wu authored and facebook-github-bot committed Mar 26, 2019
1 parent c669f8a commit 4d80fae
Showing 1 changed file with 37 additions and 0 deletions.
37 changes: 37 additions & 0 deletions pytext/optimizer/scheduler.py
@@ -1,6 +1,7 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import math
from typing import Optional

import torch
from pytext.config import ConfigBase
@@ -300,3 +301,39 @@ def from_config(cls, config: Config, optimizer: Optimizer):

    def step_epoch(self, metrics=None, epoch=None):
        self.step(epoch)


class WarmupScheduler(_LRScheduler, BatchScheduler):
    class Config(BatchScheduler.Config):
        warmup_steps: Optional[int] = 10000
        warmup_epochs: Optional[float] = None

    @classmethod
    def from_config(cls, config: Config, optimizer: Optimizer):
        return cls(optimizer, config.warmup_steps, config.warmup_epochs)

    def __init__(self, optimizer, warmup_steps, warmup_epochs):
        assert (warmup_steps is None) ^ (warmup_epochs is None)
        self.warmup_steps = warmup_steps
        self.warmup_epochs = warmup_epochs
        self.current_steps = 0
        super().__init__(optimizer)

    def prepare(self, train_iter, total_epochs):
        super().prepare(train_iter, total_epochs)
        if self.warmup_epochs is not None:
            self.warmup_steps = int(self.warmup_epochs * self.steps_per_epoch)
        # Initialize first LR. In epoch-based schedulers this happens in __init__(),
        # but here it happens after prepare().
        self.step_batch()

    def step_batch(self):
        self.current_steps += 1
        self.step()

    def get_lr(self):
        if self.warmup_steps is None or self.current_steps >= self.warmup_steps:
            lr_multiplier = 1.0
        else:
            lr_multiplier = self.current_steps / self.warmup_steps
        return [lr_multiplier * base_lr for base_lr in self.base_lrs]
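
A hypothetical usage sketch (not part of this commit; in PyText the trainer normally builds the scheduler via from_config and drives prepare() and step_batch() itself):

    import torch
    from torch.optim import SGD
    from pytext.optimizer.scheduler import WarmupScheduler

    model = torch.nn.Linear(10, 2)
    optimizer = SGD(model.parameters(), lr=0.1)

    # warmup_steps is given directly, so warmup_epochs must be None (the assert enforces exactly one).
    scheduler = WarmupScheduler(optimizer, warmup_steps=100, warmup_epochs=None)

    for step in range(200):
        # ... forward pass, backward pass, and optimizer.step() would go here ...
        scheduler.step_batch()
        # The LR ramps linearly from 0 toward 0.1 over the first 100 batches,
        # then holds at 0.1 for the rest of training.

Since warmup_steps is fixed in this sketch, the prepare() call (which converts warmup_epochs into steps and applies the initial LR) is omitted.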
