
Commit

[Typing][A-43] Add type annotations for paddle/optimizer/lamb.py (P…
gouzil authored Jun 18, 2024
1 parent 2f9d803 commit 2b677f3
Showing 1 changed file with 26 additions and 15 deletions.
41 changes: 26 additions & 15 deletions python/paddle/optimizer/lamb.py
@@ -12,13 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Callable, Sequence
+
from paddle import _C_ops
from paddle.base.executor import global_scope

from ..base import core, framework
from ..base.framework import Variable
from .optimizer import Optimizer

+if TYPE_CHECKING:
+    from paddle import Tensor
+    from paddle.nn.clip import GradientClipBase
+
+    from .optimizer import _ParameterConfig
+
__all__ = []


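The hunk above applies the usual postponed-evaluation pattern: with "from __future__ import annotations" every annotation is stored as a string and never evaluated at runtime, so imports needed only for type checking can live under "if TYPE_CHECKING:" with no runtime import cost and no circular-import risk. A minimal sketch of the pattern (the scale function is illustrative, not part of lamb.py):

from __future__ import annotations  # annotations become strings, evaluated lazily

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static type checkers (mypy, pyright); skipped at runtime.
    from paddle import Tensor


def scale(x: Tensor, factor: float) -> Tensor:
    # "Tensor" above is just a string at runtime, so the guarded import
    # is never needed when this module is actually imported.
    return x * factor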
@@ -60,7 +70,7 @@ class Lamb(Optimizer):
beta2 (float, optional): The exponential decay rate for the 2nd moment estimates.
Default 0.999.
epsilon (float, optional): A small float value for numerical stability. Default 1e-6.
-parameters (Iterable, optional): Iterable of ``Variable`` names to update to minimize ``loss``. \
+parameters (list|tuple|None, optional): Iterable of ``Variable`` names to update to minimize ``loss``. \
This parameter is required in dygraph mode. And you can specify different options for \
different parameter groups such as the learning rate, weight decay, etc, \
then the parameters are list of dict. Note that the learning_rate in parameter groups \
@@ -71,10 +81,11 @@ class Lamb(Optimizer):
( :ref:`api_paddle_base_clip_ClipGradByGlobalNorm` , :ref:`api_paddle_base_clip_ClipGradByNorm` ,
:ref:`api_paddle_base_clip_ClipGradByValue` ). If you want better convergence, it is recommended
to use :ref:`api_paddle_base_clip_ClipGradByGlobalNorm` . Default None, meaning there is no gradient clipping.
-exclude_from_weight_decay_fn (function, optional): whether to skip weight decay for a parameter when this function returns True while take the parameter as input.
+exclude_from_weight_decay_fn (Callable|None, optional): whether to skip weight decay for a parameter when this function returns True while take the parameter as input.
+multi_precision (bool, optional) - Whether to use it during weight updates multi-precision, Default False。
always_adapt (bool, optional): whether to use Layer-wise LR adaptation. By default, skip adaptation on parameters that are
excluded from weight decay, unless always_adapt == True, then always enable LR adaptation.
-name(str|None): For detailed information, please refer to
+name(str|None, optional): For detailed information, please refer to
:ref:`api_guide_Name` . Usually name is no need to set and None by default.
Examples:
.. code-block:: python
@@ -100,18 +111,18 @@ class Lamb(Optimizer):

def __init__(
self,
-learning_rate=0.001,
-lamb_weight_decay=0.01,
-beta1=0.9,
-beta2=0.999,
-epsilon=1e-6,
-parameters=None,
-grad_clip=None,
-exclude_from_weight_decay_fn=None,
-multi_precision=False,
-always_adapt=False,
-name=None,
-):
+learning_rate: float | Tensor = 0.001,
+lamb_weight_decay: float = 0.01,
+beta1: float = 0.9,
+beta2: float = 0.999,
+epsilon: float = 1e-6,
+parameters: Sequence[Tensor] | Sequence[_ParameterConfig] | None = None,
+grad_clip: GradientClipBase | None = None,
+exclude_from_weight_decay_fn: Callable[[Tensor], bool] | None = None,
+multi_precision: bool = False,
+always_adapt: bool = False,
+name: str | None = None,
+) -> None:
assert learning_rate is not None
assert beta1 is not None
assert beta2 is not None
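For reference, a short sketch of how the newly annotated constructor is typically called in dygraph mode; the layer, data, and hyperparameters below are illustrative and not taken from the (collapsed) docstring example:

import paddle

linear = paddle.nn.Linear(10, 10)
out = linear(paddle.rand([10, 10], dtype="float32"))
loss = paddle.mean(out)

# Arguments line up with the annotations above: a float learning_rate,
# a Sequence[Tensor] of parameters, and an optional Callable[[Tensor], bool].
# The bias-skipping predicate is an assumption for illustration, not part of the commit.
lamb = paddle.optimizer.Lamb(
    learning_rate=0.002,
    parameters=linear.parameters(),
    lamb_weight_decay=0.01,
    exclude_from_weight_decay_fn=lambda p: p.name.endswith(".b_0"),
)

loss.backward()
lamb.step()
lamb.clear_grad()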

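The parameters annotation also admits Sequence[_ParameterConfig], i.e. the list-of-dict parameter groups described in the docstring hunk above. A hedged sketch, assuming the conventional Paddle group keys "params" and "learning_rate" (they are not spelled out in this diff):

import paddle

linear_1 = paddle.nn.Linear(10, 10)
linear_2 = paddle.nn.Linear(10, 10)

# Each dict is one parameter group. Per the docstring note above, a group-level
# learning_rate acts as a scale on the base rate (0.002 * 0.1 for the second group).
lamb = paddle.optimizer.Lamb(
    learning_rate=0.002,
    lamb_weight_decay=0.01,
    parameters=[
        {"params": linear_1.parameters()},
        {"params": linear_2.parameters(), "learning_rate": 0.1},
    ],
)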