Skip to content

Commit

Permalink
[Typing][A-9] Add type annotations for paddle/tensor/ops.py (Paddle…
Browse files Browse the repository at this point in the history
  • Loading branch information
gouzil authored Jun 19, 2024
1 parent d95a4cb commit fb7b961
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 50 deletions.
13 changes: 9 additions & 4 deletions python/paddle/tensor/layer_function_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import re
from typing import TYPE_CHECKING

from paddle import _C_ops, _legacy_C_ops

Expand All @@ -27,6 +29,9 @@
in_dynamic_or_pir_mode,
)

if TYPE_CHECKING:
from paddle import Tensor

__all__ = []


Expand All @@ -46,7 +51,7 @@ def _convert_(name):
return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()


def generate_layer_fn(op_type):
def generate_layer_fn(op_type: str):
"""Register the Python layer for an Operator.
Args:
Expand Down Expand Up @@ -124,7 +129,7 @@ def infer_and_check_dtype(op_proto, *args, **kwargs):
dtype = core.VarDesc.VarType.FP32
return dtype

def func(*args, **kwargs):
def func(*args, **kwargs) -> Tensor:
helper = LayerHelper(op_type, **kwargs)

dtype = infer_and_check_dtype(op_proto, *args, **kwargs)
Expand Down Expand Up @@ -160,7 +165,7 @@ def func(*args, **kwargs):
return func


def generate_activation_fn(op_type):
def generate_activation_fn(op_type: str):
"""Register the Python layer for an Operator without Attribute.
Args:
Expand All @@ -171,7 +176,7 @@ def generate_activation_fn(op_type):
"""

def func(x, name=None):
def func(x, name: str | None = None) -> Tensor:
if in_dynamic_or_pir_mode():
if hasattr(_C_ops, op_type):
op = getattr(_C_ops, op_type)
Expand Down
Loading

0 comments on commit fb7b961

Please sign in to comment.