add polynomial scheduler #7260

Merged · 16 commits · Jan 19, 2022
3 changes: 2 additions & 1 deletion docs/source/optim.rst
@@ -19,4 +19,5 @@ Optimizers
StepLR,
MultiStepLR,
ExponentialLR,
ReduceLROnPlateau
ReduceLROnPlateau,
PolynomialLR
109 changes: 109 additions & 0 deletions python/oneflow/nn/optimizer/polynomial_lr.py
@@ -0,0 +1,109 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import math

from .lr_scheduler import LrScheduler


class PolynomialLR(LrScheduler):
r"""
This scheduler applies a polynomial decay to the learning rate.
The learning rate is updated as follows:

If cycle is `True`, the equation is:
Contributor

You could add blank lines between the paragraphs of the docstring. See https://github.com/Oneflow-Inc/OneTeam/issues/94 for how to build the docs and check whether the output looks as expected.

Contributor Author

OK, I'll give it a try.

.. math::
\begin{aligned}
& decay\_batch = decay\_batch \cdot \left\lceil \frac{current\_batch}{decay\_batch} \right\rceil \\
& learning\_rate = (base\_lr - end\_lr) \cdot (1 - \frac{current\_batch}{decay\_batch})^{power} + end\_lr
\end{aligned}

If cycle is `False`, the equation is:

.. math::
\begin{aligned}
& decay\_batch = \min(decay\_batch, current\_batch) \\
& learning\_rate = (base\_lr - end\_lr) \cdot (1 - \frac{current\_batch}{decay\_batch})^{power} + end\_lr
\end{aligned}

Args:
optimizer (Optimizer): Wrapped optimizer.
steps (int): The number of decay steps.
end_learning_rate (float, optional): The final learning rate. Defaults to 0.0001.
power (float, optional): The power of the polynomial. Defaults to 1.0.
cycle (bool, optional): If True, the decay restarts in cycles: the decay period is extended to the smallest multiple of ``steps`` that covers the current step, so the learning rate rises and decays again after each period. If False, the learning rate stays at ``end_learning_rate`` once ``steps`` is reached. Defaults to False.

For example:

.. code-block:: python

import oneflow as flow

...
polynomial_scheduler = flow.optim.lr_scheduler.PolynomialLR(
optimizer, steps=5, end_learning_rate=0.00001, power=2
)

for epoch in range(num_epoch):
train(...)
polynomial_scheduler.step()
"""

def __init__(
self,
optimizer,
steps: int,
end_learning_rate: float = 0.0001,
power: float = 1.0,
cycle: bool = False,
last_step: int = -1,
verbose: bool = False,
):
assert steps > 0, f"steps must be greater than zero, but got {steps}"
self.max_decay_steps = steps
self.end_learning_rate = end_learning_rate
self.power = power
self.cycle = cycle
super().__init__(optimizer, last_step, verbose)

def get_lr(self):
decay_batch = self.max_decay_steps
cur_batch = self.last_step
if self.cycle:
if cur_batch == 0:
cur_batch = 1
decay_batch = decay_batch * math.ceil(cur_batch / decay_batch)
else:
cur_batch = min(cur_batch, decay_batch)
return [
(base_lr - self.end_learning_rate)
* ((1 - cur_batch / decay_batch) ** (self.power))
+ self.end_learning_rate
for base_lr in self.base_lrs
]

def _generate_conf_for_graph(self, opt_confs):
for opt_conf in opt_confs:
learning_rate_decay_conf = opt_conf.mutable_learning_rate_decay()
learning_rate_decay_conf.mutable_polynomial_conf().set_decay_batches(
self.max_decay_steps
)
learning_rate_decay_conf.mutable_polynomial_conf().set_end_learning_rate(
self.end_learning_rate
)
learning_rate_decay_conf.mutable_polynomial_conf().set_power(self.power)
learning_rate_decay_conf.mutable_polynomial_conf().set_cycle(self.cycle)
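
A quick way to sanity-check the two decay modes described in the docstring above is a standalone sketch that mirrors the arithmetic in `get_lr`. The `poly_lr` helper below is not part of this PR; it is a hypothetical, pure-Python illustration of the formula only:

import math

def poly_lr(base_lr, end_lr, step, decay_steps, power, cycle):
    # Mirrors PolynomialLR.get_lr for a single parameter group.
    if cycle:
        # Treat step 0 as step 1, then stretch the decay window to the
        # smallest multiple of decay_steps that covers the current step.
        step = max(step, 1)
        decay_steps = decay_steps * math.ceil(step / decay_steps)
    else:
        # Clamp the step so the LR stays at end_lr once decay_steps is reached.
        step = min(step, decay_steps)
    return (base_lr - end_lr) * (1 - step / decay_steps) ** power + end_lr

for s in (0, 3, 5, 7, 10):
    print(f"step={s:2d}  cycle=False: {poly_lr(0.001, 1e-5, s, 5, 2, False):.6f}"
          f"  cycle=True: {poly_lr(0.001, 1e-5, s, 5, 2, True):.6f}")

With steps=5 and power=2, cycle=False holds the rate at end_learning_rate from step 5 onward, while cycle=True stretches the decay window to 10, 15, ... so the rate jumps back up after each boundary and decays toward end_learning_rate again.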
1 change: 1 addition & 0 deletions python/oneflow/optim/lr_scheduler.py
@@ -22,3 +22,4 @@
from oneflow.nn.optimizer.exponential_lr import ExponentialLR
from oneflow.nn.optimizer.warm_up_lr import WarmUpLR
from oneflow.nn.optimizer.reduce_lr_on_plateau import ReduceLROnPlateau
from oneflow.nn.optimizer.polynomial_lr import PolynomialLR
13 changes: 13 additions & 0 deletions python/oneflow/test/graph/test_graph_lrs.py
@@ -178,6 +178,19 @@ def _lr_fn(parameters):
_test_linear_graph_train_with_lr_sch(test_case, 21, flow.device("cuda"), _lr_fn)
_test_linear_graph_train_with_lr_sch(test_case, 21, flow.device("cpu"), _lr_fn)

def test_polynomial_lr(test_case):
def _lr_fn(parameters):
of_sgd = flow.optim.SGD(parameters, lr=0.001)

lr = flow.optim.lr_scheduler.PolynomialLR(
of_sgd, steps=10, end_learning_rate=0.00001, power=2, cycle=True
)
return of_sgd, lr

_test_linear_graph_train_with_lr_sch(test_case, 21, flow.device("cuda"), _lr_fn)

_test_linear_graph_train_with_lr_sch(test_case, 21, flow.device("cpu"), _lr_fn)


if __name__ == "__main__":
unittest.main()
51 changes: 49 additions & 2 deletions python/oneflow/test/modules/test_lr_scheduler.py
@@ -28,7 +28,7 @@
from test_util import GenArgDict


def compare_with_troch_reduce_lr(
def compare_with_torch_reduce_lr(
test_case, mode, factor, patience, threshold, threshold_mode, cooldown, min_lr, eps,
):
optimizer_flow = flow.optim.SGD(
@@ -218,6 +218,53 @@ def lambda_lr_step(base_lrs, current_step):
for (lr1, lr2) in zip(lambda_lr.get_last_lr(), new_lrs):
test_case.assertAlmostEqual(lr1, lr2, places=5)

def test_polynomial_lr(test_case):
optimizer = flow.optim.SGD(
[{"params": [Parameter(flow.Tensor([1.0]))]}], lr=TestLrScheduler.base_lr
)

def polynomial_lr_step(base_lr, end_lr, step, decay_steps, power, cycle):
if cycle:
if step == 0:
step = 1
decay_steps = decay_steps * math.ceil(step / decay_steps)
step = min(step, decay_steps)
return (base_lr - end_lr) * (1 - step / decay_steps) ** power + end_lr

decay_steps = 100
end_learning_rate = 1e-5
power = 2
cycle = True
poly_decay_lr = flow.optim.lr_scheduler.PolynomialLR(
optimizer, decay_steps, end_learning_rate, power, cycle
)
# step(0) will be invoked in LrScheduler.__init__
new_lr = polynomial_lr_step(
TestLrScheduler.base_lr, end_learning_rate, 0, decay_steps, power, cycle
)
test_case.assertAlmostEqual(poly_decay_lr.get_last_lr()[0], new_lr, places=4)
for i in range(1, 21):
poly_decay_lr.step()
new_lr = polynomial_lr_step(
TestLrScheduler.base_lr, end_learning_rate, i, decay_steps, power, cycle
)
test_case.assertAlmostEqual(
poly_decay_lr.get_last_lr()[0], new_lr, places=4
)

cycle = False
poly_decay_lr = flow.optim.lr_scheduler.PolynomialLR(
optimizer, decay_steps, end_learning_rate, power, cycle
)
for i in range(1, 21):
poly_decay_lr.step()
new_lr = polynomial_lr_step(
TestLrScheduler.base_lr, end_learning_rate, i, decay_steps, power, cycle
)
test_case.assertAlmostEqual(
poly_decay_lr.get_last_lr()[0], new_lr, places=4
)

def test_reduce_lr_on_plateau(test_case):
arg_dict = OrderedDict()
arg_dict["mode"] = ["min", "max"]
@@ -229,7 +276,7 @@ def test_reduce_lr_on_plateau(test_case):
arg_dict["min_lr"] = [0, 1e-3]
arg_dict["eps"] = [1e-5, 1e-8]
for arg in GenArgDict(arg_dict):
compare_with_troch_reduce_lr(test_case, **arg)
compare_with_torch_reduce_lr(test_case, **arg)

def test_warmup_scheduler_save_and_load(test_case):
param = flow.nn.Parameter(flow.ones(3, 4))