Merge pull request #66 from mmmwhy/wip-fy
feat(schedule): cosine schedule with warmup
mmmwhy authored Feb 12, 2022
2 parents 1ab7922 + 87c8886 commit 83eaa63
Showing 2 changed files with 47 additions and 1 deletion.
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -13,11 +13,12 @@
 * **decode:** add some transformer decode code ([52b044b](https://github.com/mmmwhy/pure_attention/commit/52b044b0fa79dcb3b9ba8fcd2747f05bc43de808))
 * **layers:** fix import for layerNorm ([eb61b31](https://github.com/mmmwhy/pure_attention/commit/eb61b313458ac18bf4b15271fee2cf7e39f8afde))
 * **nlp:** init basic bert code ([f9cb13a](https://github.com/mmmwhy/pure_attention/commit/f9cb13a3e811eb8c44ba8ff1373d688311426927))
+* **schedule:** cosine schedule with warmup ([085a36e](https://github.com/mmmwhy/pure_attention/commit/085a36e1a0d55c64daf2514d5dff6dec1d57b354))
 
 
 ### Performance Improvements
 
-* **runner:** logger missing_keys and unexpected_key in runner ([69c8c78](https://github.com/mmmwhy/pure_attention/commit/69c8c781c7053c066d947087e98814e6132c8847))
+* **runner:** logger missing_keys and unexpected_key in runner ([a9e3ff9](https://github.com/mmmwhy/pure_attention/commit/a9e3ff9ca7771fac648c427a98d0dd5414956cbd))



45 changes: 45 additions & 0 deletions pure_attention/common/schedule.py
@@ -0,0 +1,45 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# @author: mmmwhy <mmmwhy@mail.ustc.edu.cn>
# @date: 2022/02/12
#
""""""
import math

import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR


def get_cosine_schedule_with_warmup(
    optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
):
"""
搬运自: https://github.com/huggingface/transformers/blob/2e9af294940083915ccb2740a7c8d5b154194f15/src/transformers/optimization.py#L103-L134
Create a schedule with a learning rate that decreases following the values of the cosine function between the
initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
initial lr set in the optimizer.
Args:
optimizer (:class:`~torch.optim.Optimizer`):
The optimizer for which to schedule the learning rate.
num_warmup_steps (:obj:`int`):
The number of steps for the warmup phase.
num_training_steps (:obj:`int`):
The total number of training steps.
num_cycles (:obj:`float`, `optional`, defaults to 0.5):
The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
following a half-cosine).
last_epoch (:obj:`int`, `optional`, defaults to -1):
The index of the last epoch when resuming training.
Return:
:obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""

    def lr_lambda(current_step):
        # Warmup phase: the lr multiplier rises linearly from 0 to 1 over num_warmup_steps.
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        # Decay phase: progress runs from 0 to 1 over the remaining steps, and the multiplier
        # follows 0.5 * (1 + cos(2 * pi * num_cycles * progress)); with the default
        # num_cycles=0.5 this is a half-cosine from 1 down to 0, clamped at 0.
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))

    return LambdaLR(optimizer, lr_lambda, last_epoch)
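
For context, a minimal usage sketch (not part of the commit): the tiny linear model, the AdamW optimizer, and the step counts below are illustrative assumptions; only get_cosine_schedule_with_warmup comes from this file.

import torch

from pure_attention.common.schedule import get_cosine_schedule_with_warmup

# Illustrative setup (assumed, not from the commit): a tiny linear model and AdamW.
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

# 1000 total steps: the lr rises linearly from 0 to 3e-4 over the first 100 steps,
# then decays to 0 along a half-cosine (num_cycles defaults to 0.5).
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=100, num_training_steps=1000)

for step in range(1000):
    # ... forward pass and loss.backward() would go here ...
    optimizer.step()
    scheduler.step()  # advance the schedule once per optimizer step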
