[Typing][A-30,A-31] Add type annotations for `paddle/nn/initializer/dirac.py` (PaddlePaddle#65087)



---------

Co-authored-by: Nyakku Shigure <sigure.qaq@gmail.com>
gouzil and SigureMo authored Jun 13, 2024
1 parent 3768840 commit f2746e1
Showing 2 changed files with 31 additions and 13 deletions.
14 changes: 9 additions & 5 deletions python/paddle/nn/initializer/dirac.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
 import paddle
 from paddle import _C_ops, in_dynamic_mode, pir
 from paddle.utils import unique_name
@@ -42,9 +44,9 @@ class Dirac(Initializer):
     where, ``N`` is the minimum value of ``in_channels`` and ``out_channels``
 
     Args:
-        groups(int, optional): 0-dimension of the Tensor will be divided by groups,
+        groups(int|None, optional): 0-dimension of the Tensor will be divided by groups,
             each group has the same value. Default: 1.
-        name(str, optional): The default value is None. Normally there is no need for user to set this
+        name(str|None, optional): The default value is None. Normally there is no need for user to set this
             property. For more information, please refer to :ref:`api_guide_Name`.
 
     Returns:
@@ -88,19 +90,21 @@ class Dirac(Initializer):
             [0., 0., 0., 0.]]])
     """
 
-    def __init__(self, groups=1, name=None):
+    def __init__(self, groups: int = 1, name: str | None = None) -> None:
         assert groups > 0 and isinstance(
             groups, int
         ), " 'groups' must be a positive integer. "
         super().__init__()
         self._groups = groups
 
-    def __call__(self, var, block=None):
+    def __call__(
+        self, var: paddle.Tensor, block: pir.Block | None = None
+    ) -> paddle.Tensor:
         """Initialize the input tensor with dirac initializer.
 
         Args:
             var(Tensor): Tensor that needs to be initialized.
-            block(Block, optional): The block in which initialization ops
+            block(Block|None, optional): The block in which initialization ops
                 should be added. Used in static graph only, default None.
 
         Returns:
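For context (not part of this commit), a minimal usage sketch of the annotated `Dirac` API: the initializer fills a conv kernel so that the convolution initially (near-)preserves its input. The layer and shapes below are illustrative, assuming Paddle's default NCHW `Conv2D`:

import paddle

# Dirac-initialize a Conv2D kernel via ParamAttr; groups=1 and name=None
# match the annotated defaults `groups: int = 1, name: str | None = None`.
init = paddle.nn.initializer.Dirac(groups=1)
conv = paddle.nn.Conv2D(
    in_channels=3,
    out_channels=3,
    kernel_size=3,
    padding=1,
    weight_attr=paddle.ParamAttr(initializer=init),
)
x = paddle.rand([1, 3, 8, 8])
y = conv(x)
print(y.shape)  # [1, 3, 8, 8]; a Dirac kernel makes the conv start as a (near-)identity map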
30 changes: 22 additions & 8 deletions python/paddle/nn/initializer/initializer.py
@@ -12,11 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
 import functools
 import math
 
 import numpy as np
 
+import paddle
+
 from ...base.framework import (
     EagerParamBase,
     default_main_program,
@@ -39,23 +43,31 @@ class Initializer:
     def __init__(self):
         pass
 
-    def __call__(self, param, block=None):
+    def __call__(
+        self, param: paddle.Tensor, block: paddle.pir.Block | None = None
+    ):
         if not lazy_init_helper().state:
             return self.forward(param, block)
 
         return self._lazy_init(param, block)
 
-    def forward(self, param, block=None):
+    def forward(
+        self, param: paddle.Tensor, block: paddle.pir.Block | None = None
+    ):
         """Add corresponding initialization operations to the network."""
         raise NotImplementedError()
 
-    def _lazy_init(self, param, block=None):
+    def _lazy_init(
+        self, param: paddle.Tensor, block: paddle.pir.Block | None = None
+    ):
         """
         Apply lazy initialization
         """
         assert in_dygraph_mode()
 
-        def init_op_creator(forward, param, block):
+        def init_op_creator(
+            forward, param: paddle.Tensor, block: paddle.pir.Block | None
+        ):
             new_var = param._to_static_var(True, block=block)
             # Record initializer operator
             with lazy_init_helper():
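Concrete initializers reach `forward` through this `__call__` when a layer builds its parameters in eager mode (the lazy branch only triggers under `lazy_init_helper()`). A minimal sketch of that path using the public `Constant` initializer:

import paddle

# Constant.forward runs via Initializer.__call__ while Linear creates
# its weight parameter; block=None, so the default block path is taken.
linear = paddle.nn.Linear(
    4,
    2,
    weight_attr=paddle.ParamAttr(
        initializer=paddle.nn.initializer.Constant(value=0.5)
    ),
)
print(linear.weight)  # every entry is 0.5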
@@ -69,13 +81,13 @@ def init_op_creator(forward, param, block):
 
         return param
 
-    def _check_block(self, block):
+    def _check_block(self, block: paddle.pir.Block | None) -> paddle.pir.Block:
         if block is None:
             block = default_main_program().global_block()
 
         return block
 
-    def _compute_fans(self, var):
+    def _compute_fans(self, var: paddle.Tensor) -> tuple[int, int]:
         """Compute the fan_in and the fan_out for layers
 
         This method computes the fan_in and the fan_out
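The body of `_compute_fans` is elided by the diff; as a hedged sketch of what the new `tuple[int, int]` return annotation describes (hypothetical helper name, conv-weight layout `[out_channels, in_channels, *kernel_dims]` assumed, not Paddle's exact implementation):

def compute_fans_sketch(shape: list[int]) -> tuple[int, int]:
    # Hypothetical stand-in for Initializer._compute_fans.
    receptive_field_size = 1
    for dim in shape[2:]:  # spatial kernel dims; empty for a 2-D weight
        receptive_field_size *= dim
    fan_in = shape[1] * receptive_field_size   # connections feeding one output unit
    fan_out = shape[0] * receptive_field_size  # connections leaving one input unit
    return fan_in, fan_out

print(compute_fans_sketch([64, 3, 3, 3]))  # (27, 576)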
@@ -115,15 +127,17 @@ def _compute_fans(self, var):
         return (fan_in, fan_out)
 
 
-def calculate_gain(nonlinearity, param=None):
+def calculate_gain(
+    nonlinearity: str, param: bool | float | None = None
+) -> float:
     """
     Get the recommended ``gain`` value of some nonlinearity function. ``gain`` value can be used in some
     ``paddle.nn.initializer`` api to adjust the initialization value.
 
     Args:
         nonlinearity(str): name of nonlinearity activation function. If it is a linear function, such as:
             `linear/conv1d/conv2d/conv3d/conv1d_transpose/conv2d_transpose/conv3d_transpose` , 1.0 will be returned.
-        param(bool|int|float, optional): optional parameter for some nonlinearity function. Now, it only applies to
+        param(bool|int|float|None, optional): optional parameter for some nonlinearity function. Now, it only applies to
             'leaky_relu'. Default: None, it will be calculated as 0.01 in the formula.
 
     Returns:
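To round off, a usage sketch of the annotated `calculate_gain` (public as `paddle.nn.initializer.calculate_gain`); the expected values follow the documented formulas, e.g. sqrt(2 / (1 + param**2)) for 'leaky_relu':

import math

import paddle

print(paddle.nn.initializer.calculate_gain('tanh'))  # 5/3
print(paddle.nn.initializer.calculate_gain('relu'))  # sqrt(2)
# For 'leaky_relu', `param` is the negative slope; None falls back to 0.01.
gain = paddle.nn.initializer.calculate_gain('leaky_relu', 0.2)
assert math.isclose(gain, math.sqrt(2 / (1 + 0.2**2)))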
