
[Typing] Refine typing for jit.save and jit.to_static #65301


Merged
1 change: 1 addition & 0 deletions python/paddle/_typing/__init__.py
@@ -20,6 +20,7 @@
     IntSequence as IntSequence,
     NestedNumbericSequence as NestedNumbericSequence,
     NestedSequence as NestedSequence,
+    NestedStructure as NestedStructure,
     Numberic as Numberic,
     NumbericSequence as NumbericSequence,
     TensorIndex as TensorIndex,
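Reviewer note: the redundant `X as X` spelling is the explicit re-export idiom. When mypy's `implicit_reexport` is disabled (as under `--strict`), only the aliased form marks a name as part of the package's public surface. A minimal sketch, not from this PR:

```python
# Under mypy --strict (implicit_reexport = False):
from .basic import NestedStructure as NestedStructure  # explicit, re-exported
from .basic import NestedSequence  # implicit, treated as module-private
```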
4 changes: 4 additions & 0 deletions python/paddle/_typing/basic.py
@@ -16,6 +16,7 @@
 from typing import (
     TYPE_CHECKING,
     Any,
+    Dict,
     List,
     Literal,
     Sequence,
@@ -49,6 +50,9 @@
 _T = TypeVar("_T")

 NestedSequence = Union[_T, Sequence["NestedSequence[_T]"]]
+NestedStructure = Union[
+    _T, Dict[str, "NestedStructure[_T]"], Sequence["NestedStructure[_T]"]
+]
 IntSequence = Sequence[int]
 NumbericSequence = Sequence[Numberic]
 NestedNumbericSequence: TypeAlias = NestedSequence[Numberic]
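Reviewer note: a minimal sketch of what the new `NestedStructure` alias admits. The `flatten` helper below is hypothetical and only illustrates the recursion over dicts and sequences:

```python
from typing import Dict, List, Sequence, TypeVar, Union

_T = TypeVar("_T")
# Same shape as the alias above: a leaf, a str-keyed dict of structures,
# or a sequence of structures, nested to any depth.
NestedStructure = Union[
    _T, Dict[str, "NestedStructure[_T]"], Sequence["NestedStructure[_T]"]
]

def flatten(structure: "NestedStructure[int]") -> List[int]:
    """Collect the leaves of a nested structure, depth-first (hypothetical)."""
    if isinstance(structure, dict):
        return [leaf for value in structure.values() for leaf in flatten(value)]
    if isinstance(structure, (list, tuple)):
        return [leaf for item in structure for leaf in flatten(item)]
    return [structure]

assert flatten({"x": 1, "y": [2, {"z": 3}]}) == [1, 2, 3]
```

This matches the shapes `input_spec` takes in practice: a single spec, a list of specs, or a dict of specs keyed by argument name.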
62 changes: 34 additions & 28 deletions python/paddle/jit/api.py
@@ -26,7 +26,14 @@
 from collections.abc import Sequence
 from contextlib import contextmanager
 from types import ModuleType
-from typing import Any, Callable, Protocol, TypedDict, TypeVar, overload
+from typing import (
+    Any,
+    Callable,
+    Protocol,
+    TypedDict,
+    TypeVar,
+    overload,
+)

 from typing_extensions import (
     Literal,
@@ -37,7 +44,7 @@
 )

 import paddle
-from paddle._typing import NestedSequence
+from paddle._typing import NestedStructure
 from paddle.base import core, dygraph
 from paddle.base.compiler import (
     BuildStrategy,
@@ -84,7 +91,6 @@
 ENV_ENABLE_SOT = BooleanEnvironmentVariable("ENABLE_FALL_BACK", True)


-_F = TypeVar('_F', bound=Callable[..., Any])
 _LayerT = TypeVar("_LayerT", bound=Layer)
 _RetT = TypeVar("_RetT")
 _InputT = ParamSpec("_InputT")
@@ -152,12 +158,12 @@ def _check_and_set_backend(backend, build_strategy):
         build_strategy.build_cinn_pass = True


-class ToStaticOptions(TypedDict):
+class _ToStaticOptions(TypedDict):
     property: NotRequired[bool]
     full_graph: NotRequired[bool]


-class ToStaticDecorator(Protocol):
+class _ToStaticDecorator(Protocol):
     @overload
     def __call__(self, function: _LayerT) -> _LayerT:
         ...
@@ -172,44 +178,33 @@ def __call__(
 @overload
 def to_static(
     function: _LayerT,
-    input_spec: NestedSequence[InputSpec] | None = ...,
+    input_spec: NestedStructure[InputSpec] | None = ...,
     build_strategy: BuildStrategy | None = ...,
     backend: Backends | None = ...,
-    **kwargs: Unpack[ToStaticOptions],
+    **kwargs: Unpack[_ToStaticOptions],
 ) -> _LayerT:
     ...


 @overload
 def to_static(
     function: Callable[_InputT, _RetT],
-    input_spec: NestedSequence[InputSpec] | None = ...,
+    input_spec: NestedStructure[InputSpec] | None = ...,
     build_strategy: BuildStrategy | None = ...,
     backend: Backends | None = ...,
-    **kwargs: Unpack[ToStaticOptions],
+    **kwargs: Unpack[_ToStaticOptions],
 ) -> StaticFunction[_InputT, _RetT]:
     ...


-@overload
-def to_static(
-    function: Any,
-    input_spec: NestedSequence[InputSpec] | None = ...,
-    build_strategy: BuildStrategy | None = ...,
-    backend: Backends | None = ...,
-    **kwargs: Unpack[ToStaticOptions],
-) -> Any:
-    ...
-
-
 @overload
 def to_static(
     function: None = ...,
-    input_spec: NestedSequence[InputSpec] | None = ...,
+    input_spec: NestedStructure[InputSpec] | None = ...,
     build_strategy: BuildStrategy | None = ...,
     backend: Backends | None = ...,
-    **kwargs: Unpack[ToStaticOptions],
-) -> ToStaticDecorator:
+    **kwargs: Unpack[_ToStaticOptions],
+) -> _ToStaticDecorator:
     ...

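Reviewer note: the overload set above is what makes both call patterns below type-check cleanly. A hedged sketch assuming a working Paddle install (`Net` and `double` are made up for illustration):

```python
import paddle
from paddle.static import InputSpec

class Net(paddle.nn.Layer):  # hypothetical layer
    def forward(self, x):
        return paddle.nn.functional.relu(x)

# Direct form (first overload): Layer in, Layer out.
static_net = paddle.jit.to_static(
    Net(), input_spec=[InputSpec(shape=[None, 8], dtype='float32')]
)

# Decorator-factory form (last overload): function=None returns a
# _ToStaticDecorator, which then wraps `double`.
@paddle.jit.to_static(input_spec=[InputSpec(shape=[None, 8], dtype='float32')])
def double(x):
    return x * 2
```

With `input_spec` now typed as `NestedStructure[InputSpec]`, a dict form such as `{'x': InputSpec(shape=[None, 8])}` should also satisfy the checker, not just flat sequences.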
@@ -334,15 +329,15 @@ def decorated(python_func):
     return decorated


-class NotToStaticDecorator(Protocol):
+class _NotToStaticDecorator(Protocol):
     @overload
     def __call__(
         self, func: Callable[_InputT, _RetT]
     ) -> Callable[_InputT, _RetT]:
         ...

     @overload
-    def __call__(self, func: None = ...) -> NotToStaticDecorator:
+    def __call__(self, func: None = ...) -> _NotToStaticDecorator:
         ...

@@ -352,7 +347,7 @@ def not_to_static(func: Callable[_InputT, _RetT]) -> Callable[_InputT, _RetT]:


 @overload
-def not_to_static(func: None = ...) -> NotToStaticDecorator:
+def not_to_static(func: None = ...) -> _NotToStaticDecorator:
     ...

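Reviewer note: the two overloads model bare and parenthesized decoration. A brief sketch (assumes Paddle is importable; per the overloads above, calling with `func=None` yields a `_NotToStaticDecorator`):

```python
import paddle

@paddle.jit.not_to_static  # bare form: first overload, Callable in, Callable out
def keep_dynamic(x):
    return x

@paddle.jit.not_to_static()  # factory form: second overload returns a decorator
def also_dynamic(x):
    return x
```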
@@ -861,8 +856,19 @@ def _remove_save_pre_hook(hook):
     _save_pre_hooks_lock.release()


+class _SaveFunction(Protocol):
+    def __call__(
+        self,
+        layer: Layer | Callable[..., Any],
+        path: str,
+        input_spec: Sequence[InputSpec | paddle.Tensor | object] | None = ...,
+        **configs: Unpack[_SaveLoadOptions],
+    ) -> None:
+        ...
+
+
 @wrap_decorator
-def _run_save_pre_hooks(func: _F) -> _F:
+def _run_save_pre_hooks(func: _SaveFunction) -> _SaveFunction:
     def wrapper(
         layer: Layer | Callable[..., Any],
         path: str,
@@ -874,7 +880,7 @@ def wrapper(
             hook(layer, input_spec, configs)
         func(layer, path, input_spec, **configs)

-    return wrapper  # type: ignore
+    return wrapper


 def _save_property(filename: str, property_vals: list[tuple[Any, str]]):
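Reviewer note: replacing the deleted `_F = TypeVar(...)` with the `_SaveFunction` Protocol is what lets the `# type: ignore` on `return wrapper` go away: the wrapper is now checked structurally against an explicit signature instead of being cast back to an opaque `_F`. A generic sketch of the pattern with hypothetical names:

```python
from typing import Protocol

class _GreetFunction(Protocol):  # analogue of _SaveFunction
    def __call__(self, name: str, *, loud: bool = ...) -> None:
        ...

def with_logging(func: _GreetFunction) -> _GreetFunction:
    def wrapper(name: str, *, loud: bool = False) -> None:
        print(f"calling with name={name!r}")
        func(name, loud=loud)
    # wrapper's signature structurally matches the Protocol, so no cast
    # or ignore comment is needed on this return.
    return wrapper
```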
36 changes: 25 additions & 11 deletions python/paddle/static/input.py
@@ -15,11 +15,14 @@
 from __future__ import annotations

 import os
+from typing import TYPE_CHECKING, Any

 import numpy as np
+import numpy.typing as npt
+from typing_extensions import Self

 import paddle
-from paddle._typing import DTypeLike, ShapeLike
+from paddle._typing import DTypeLike, ShapeLike, Size1
 from paddle.base import Variable, core
 from paddle.base.data_feeder import check_type
 from paddle.base.framework import (
@@ -36,6 +39,9 @@

 from ..base.variable_index import _setitem_static

+if TYPE_CHECKING:
+    from paddle import Tensor
+
 __all__ = []

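Reviewer note: the `if TYPE_CHECKING:` guard keeps the `paddle` import out of module import time, avoiding a circular import while still giving the checker a real `Tensor` type. A hypothetical minimal illustration of the same pattern:

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from paddle import Tensor  # resolved by the type checker only

def first_row(t: Tensor) -> Tensor:
    # With postponed evaluation, the annotation stays a string at runtime,
    # so defining this function triggers no paddle import.
    return t[0]
```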
@@ -216,7 +222,13 @@ class InputSpec:
            InputSpec(shape=(-1, 1), dtype=paddle.int64, name=label, stop_gradient=False)
     """

-    def __init__(self, shape, dtype='float32', name=None, stop_gradient=False):
+    def __init__(
+        self,
+        shape: ShapeLike,
+        dtype: DTypeLike = 'float32',
+        name: str | None = None,
+        stop_gradient: bool = False,
+    ) -> None:
         # replace `None` in shape with -1
         self.shape = self._verify(shape)
         # convert dtype into united representation
@@ -231,11 +243,11 @@ def __init__(self, shape, dtype='float32', name=None, stop_gradient=False):
     def _create_feed_layer(self):
         return data(self.name, shape=self.shape, dtype=self.dtype)

-    def __repr__(self):
+    def __repr__(self) -> str:
         return f'{type(self).__name__}(shape={self.shape}, dtype={self.dtype}, name={self.name}, stop_gradient={self.stop_gradient})'

     @classmethod
-    def from_tensor(cls, tensor, name=None):
+    def from_tensor(cls, tensor: Tensor, name: str | None = None) -> Self:
         """
         Generates a InputSpec based on the description of input tensor.
@@ -267,7 +279,9 @@ def from_tensor(cls, tensor, name=None):
         )

     @classmethod
-    def from_numpy(cls, ndarray, name=None):
+    def from_numpy(
+        cls, ndarray: npt.NDArray[Any], name: str | None = None
+    ) -> Self:
         """
         Generates a InputSpec based on the description of input np.ndarray.
@@ -291,7 +305,7 @@ def from_numpy(cls, ndarray, name=None):
         """
         return cls(ndarray.shape, ndarray.dtype, name)

-    def batch(self, batch_size):
+    def batch(self, batch_size: int | Size1) -> Self:
         """
         Inserts `batch_size` in front of the `shape`.
@@ -317,7 +331,7 @@ def batch(self, batch_size):
                 raise ValueError(
                     f"Length of batch_size: {batch_size} shall be 1, but received {len(batch_size)}."
                 )
-            batch_size = batch_size[1]
+            batch_size = batch_size[0]
         elif not isinstance(batch_size, int):
             raise TypeError(
                 f"type(batch_size) shall be `int`, but received {type(batch_size).__name__}."
@@ -328,7 +342,7 @@

         return self

-    def unbatch(self):
+    def unbatch(self) -> Self:
         """
         Removes the first element of `shape`.
@@ -374,7 +388,7 @@ def _verify(self, shape):

         return tuple(shape)

-    def __hash__(self):
+    def __hash__(self) -> int:
         # Note(Aurelius84): `name` is not considered as a field to compute hashkey.
         # Because it's no need to generate a new program in following cases while using
         # @paddle.jit.to_static.
@@ -391,13 +405,13 @@ def __hash__(self):
        # x_var and x_np hold same shape and dtype, they should also share a same program.
         return hash((tuple(self.shape), self.dtype, self.stop_gradient))

-    def __eq__(self, other):
+    def __eq__(self, other: Self) -> bool:
         slots = ['shape', 'dtype', 'name', 'stop_gradient']
         return type(self) is type(other) and all(
             getattr(self, attr) == getattr(other, attr) for attr in slots
         )

-    def __ne__(self, other):
+    def __ne__(self, other) -> bool:
         return not self == other

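Reviewer note: taken together, the annotated helpers chain cleanly since `batch`/`unbatch` are typed as returning `Self`. A usage sketch assuming a working Paddle install:

```python
import numpy as np
from paddle.static import InputSpec

spec = InputSpec.from_numpy(np.ones((4, 8), dtype='float32'), name='x')
spec = spec.unbatch()   # drop the leading dimension -> shape (8,)
spec = spec.batch(16)   # prepend a batch dimension -> shape (16, 8)
print(spec)             # InputSpec(shape=(16, 8), dtype=paddle.float32, name=x, ...)
```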