Add warmup scheduler to pytext (facebookresearch#423)
Summary:
Pull Request resolved: facebookresearch#423

Warmup scheduler that linearly increases the learning rate from 0 to its final value over a specified number of training steps (sketched below).

Reviewed By: borguz

Differential Revision: D14609360

fbshipit-source-id: b41a6bfcf558588b3e86c3e305d864e97437992a
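
As a quick illustration of the warmup rule described in the summary (a standalone sketch, not part of this diff; the function name and the warmup_steps value of 4 are invented for the example, the config default is 10000), the multiplier applied to each base learning rate ramps linearly and then holds at 1.0:

def warmup_multiplier(current_steps: int, warmup_steps: int) -> float:
    # Same rule as WarmupScheduler.get_lr() below: ramp linearly, then hold at 1.0.
    return min(current_steps / warmup_steps, 1.0)

print([warmup_multiplier(s, 4) for s in range(1, 7)])
# -> [0.25, 0.5, 0.75, 1.0, 1.0, 1.0]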
Michael Wu authored and facebook-github-bot committed Mar 29, 2019
1 parent f0f2543 commit 962f01e
Showing 1 changed file with 37 additions and 1 deletion.
pytext/optimizer/scheduler.py
@@ -53,7 +53,7 @@ def prepare(self, train_iter, total_epochs):
 class BatchScheduler(Scheduler):
     def prepare(self, train_iter, total_epochs):
         self.num_epochs = total_epochs
-        self.steps_per_epoch = train_iter.total_num_batches
+        self.steps_per_epoch = getattr(train_iter, "total_num_batches", None)


 class LmFineTuning(_LRScheduler, BatchScheduler):
@@ -300,3 +300,39 @@ def from_config(cls, config: Config, optimizer: Optimizer):

     def step_epoch(self, metrics=None, epoch=None):
         self.step(epoch)
+
+
+class WarmupScheduler(_LRScheduler, BatchScheduler):
+    """
+    Scheduler to linearly increase learning rate from 0 to final value at the beginning
+    of training.
+    """
+
+    class Config(BatchScheduler.Config):
+        #: number of training steps over which to increase learning rate
+        warmup_steps: int = 10000
+
+    @classmethod
+    def from_config(cls, config: Config, optimizer: Optimizer):
+        return cls(optimizer, config.warmup_steps)
+
+    def __init__(self, optimizer, warmup_steps):
+        assert warmup_steps > 0
+        self.warmup_steps = warmup_steps
+        self.current_steps = 0
+        super().__init__(optimizer)
+
+    def prepare(self, train_iter, total_epochs):
+        super().prepare(train_iter, total_epochs)
+        self.step_batch()  # initialize learning rate
+
+    def step_batch(self):
+        self.current_steps += 1
+        self.step()
+
+    def get_lr(self):
+        if self.current_steps >= self.warmup_steps:
+            lr_multiplier = 1.0
+        else:
+            lr_multiplier = self.current_steps / self.warmup_steps
+        return [lr_multiplier * base_lr for base_lr in self.base_lrs]
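
For context, a minimal usage sketch of the new class (hypothetical: in PyText the Trainer normally calls prepare() and step_batch() for you, and the model, optimizer, and warmup_steps values below are invented for illustration):

import torch

# Hypothetical setup: any torch optimizer works; warmup_steps=4 is illustrative.
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
scheduler = WarmupScheduler(optimizer, warmup_steps=4)

# prepare() takes the training iterator; a plain object suffices here because
# BatchScheduler.prepare() reads total_num_batches via getattr() with a None default.
scheduler.prepare(train_iter=object(), total_epochs=1)  # also takes the first warmup step

for _ in range(5):
    optimizer.step()        # usual parameter update
    scheduler.step_batch()  # advance the warmup by one batch
    print(optimizer.param_groups[0]["lr"])  # 0.5, 0.75, 1.0, 1.0, 1.0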
