diff --git a/python/paddle/optimizer/lr.py b/python/paddle/optimizer/lr.py
index 0d24286fb40cd..37a46f53707f1 100644
--- a/python/paddle/optimizer/lr.py
+++ b/python/paddle/optimizer/lr.py
@@ -45,6 +45,7 @@
     'MultiplicativeDecay',
     'OneCycleLR',
     'CyclicLR',
+    'LinearLR',
 ]


@@ -2229,6 +2230,125 @@ def get_lr(self):
         return lr


+class LinearLR(LRScheduler):
+    r"""
+    Set the learning rate according to a linear scheduler.
+    The learning rate is first scaled by ``start_factor`` and then changes linearly until it reaches ``end_factor * learning_rate`` after ``total_steps`` steps.
+
+    Args:
+        learning_rate (float): The initial learning rate. It is a python float number.
+        total_steps (int): Number of iterations over which the learning rate changes from the start learning rate to the end learning rate.
+        start_factor (float, optional): The start learning rate is defined by ``start_factor * learning_rate``. Default: 1./3.
+        end_factor (float, optional): The end learning rate is defined by ``end_factor * learning_rate``. Default: 1.0.
+        last_epoch (int, optional): The index of the last epoch. Can be set to restart training. Default: -1, which means the initial learning rate.
+        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False``.
+
+    Returns:
+        ``LinearLR`` instance to schedule learning rate.
+
+    Examples:
+        .. code-block:: python
+            :name: code-dynamic
+
+            >>> # Example1: train on default dynamic graph mode
+            >>> import paddle
+            >>> import numpy as np
+
+            >>> # train on default dynamic graph mode
+            >>> linear = paddle.nn.Linear(10, 10)
+            >>> scheduler = paddle.optimizer.lr.LinearLR(learning_rate=0.5, total_steps=5, verbose=True)
+            >>> sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
+            >>> for epoch in range(5):
+            ...     for batch_id in range(20):
+            ...         x = paddle.uniform([10, 10])
+            ...         out = linear(x)
+            ...         loss = paddle.mean(out)
+            ...         loss.backward()
+            ...         sgd.step()
+            ...         sgd.clear_gradients()
+            ...         scheduler.step()
+
+        .. code-block:: python
+            :name: code-static
+
+            >>> # Example2: train on static graph mode
+            >>> import paddle
+            >>> import numpy as np
+            >>> paddle.enable_static()
+            >>> main_prog = paddle.static.Program()
+            >>> start_prog = paddle.static.Program()
+            >>> with paddle.static.program_guard(main_prog, start_prog):
+            ...     x = paddle.static.data(name='x', shape=[None, 4, 5])
+            ...     y = paddle.static.data(name='y', shape=[None, 4, 5])
+            ...     z = paddle.static.nn.fc(x, 100)
+            ...     loss = paddle.mean(z)
+            ...     scheduler = paddle.optimizer.lr.LinearLR(learning_rate=0.5,
+            ...         total_steps=5, verbose=True)
+            ...     sgd = paddle.optimizer.SGD(learning_rate=scheduler)
+            ...     sgd.minimize(loss)
+            ...
+            >>> exe = paddle.static.Executor()
+            >>> exe.run(start_prog)
+            >>> for epoch in range(5):
+            ...     for batch_id in range(20):
+            ...         out = exe.run(
+            ...             main_prog,
+            ...             feed={
+            ...                 'x': np.random.randn(3, 4, 5).astype('float32'),
+            ...                 'y': np.random.randn(3, 4, 5).astype('float32')
+            ...             },
+            ...             fetch_list=loss.name)
+            ...     scheduler.step()
+    """
+
+    def __init__(
+        self,
+        learning_rate,
+        total_steps,
+        start_factor=1.0 / 3,
+        end_factor=1.0,
+        last_epoch=-1,
+        verbose=False,
+    ):
+        if start_factor > 1.0 or start_factor <= 0:
+            raise ValueError(
+                "`start_factor` must be greater than 0 and less than or equal to 1, but got {}".format(
+                    start_factor
+                )
+            )
+
+        if end_factor > 1.0 or end_factor < 0:
+            raise ValueError(
+                "`end_factor` must be greater than or equal to 0 and less than or equal to 1, but got {}".format(
+                    end_factor
+                )
+            )
+
+        if total_steps <= 0:
+            raise ValueError(
+                f"`total_steps` must be greater than 0, but got {total_steps}"
+            )
+
+        self.start_factor = start_factor
+        self.end_factor = end_factor
+        self.total_steps = total_steps
+
+        super().__init__(learning_rate, last_epoch, verbose)
+
+    def get_lr(self):
+        if self.last_epoch == 0:
+            return self.base_lr * self.start_factor
+        elif self.last_epoch > self.total_steps:
+            return self.last_lr
+        else:
+            base_lr = self.total_steps * self.start_factor
+            cur_factor = self.end_factor - self.start_factor
+            factor = 1.0 + cur_factor / (
+                base_lr + (self.last_epoch - 1) * cur_factor
+            )
+            return self.last_lr * factor
+
+
 def autoincreased_step_counter(counter_name=None, begin=1, step=1):
     """
     :api_attr: Static Graph
diff --git a/test/legacy_test/test_lr_scheduler.py b/test/legacy_test/test_lr_scheduler.py
index 54484ecc6ad2c..ba1f712dce2fd 100644
--- a/test/legacy_test/test_lr_scheduler.py
+++ b/test/legacy_test/test_lr_scheduler.py
@@ -464,6 +464,31 @@ def exp_range(x):
     return base_learning_rate + base_height * scale_fn(eval(scale_mode))


+linear_last_lr = None
+
+
+def linear_lr(
+    epoch_num,
+    learning_rate,
+    total_steps,
+    start_factor=1.0 / 3,
+    end_factor=1.0,
+    verbose=False,
+):
+    global linear_last_lr
+    if epoch_num == 0:
+        linear_last_lr = learning_rate * start_factor
+        return linear_last_lr
+    elif epoch_num > total_steps:
+        return linear_last_lr
+    else:
+        base_lr = total_steps * start_factor
+        cur_factor = end_factor - start_factor
+        factor = 1.0 + cur_factor / (base_lr + (epoch_num - 1) * cur_factor)
+        linear_last_lr *= factor
+        return linear_last_lr
+
+
 class TestLRScheduler(unittest.TestCase):
     def _test_static(self, python_func, paddle_api, kwarg, place):
         scheduler = paddle_api(**kwarg)
@@ -711,6 +736,19 @@ def test_scheduler(self):
             paddle.optimizer.lr.PiecewiseDecay(
                 boundaries=[100, 200], values=[0.5, 0.1]
             )
+        # check negative total_steps
+        with self.assertRaises(ValueError):
+            paddle.optimizer.lr.LinearLR(learning_rate=1, total_steps=-1)
+        # check start_factor
+        with self.assertRaises(ValueError):
+            paddle.optimizer.lr.LinearLR(
+                learning_rate=1, total_steps=5, start_factor=2
+            )
+        # check end_factor
+        with self.assertRaises(ValueError):
+            paddle.optimizer.lr.LinearLR(
+                learning_rate=1, total_steps=5, end_factor=2
+            )

         func_api_kwargs = [
             (
@@ -944,6 +982,28 @@ def test_scheduler(self):
                 "verbose": False,
             },
         ),
+        (
+            linear_lr,
+            paddle.optimizer.lr.LinearLR,
+            {
+                "learning_rate": 0.2,
+                "total_steps": 40,
+                "start_factor": 0.5,
+                "end_factor": 1,
+                "verbose": False,
+            },
+        ),
+        (
+            linear_lr,
+            paddle.optimizer.lr.LinearLR,
+            {
+                "learning_rate": 0.2,
+                "total_steps": 5,
+                "start_factor": 0.2,
+                "end_factor": 0.5,
+                "verbose": False,
+            },
+        ),
     ]

     for python_func, paddle_api, kwarg in func_api_kwargs: