Skip to content

Commit 568209f

Browse files
authored
[Typing][B-17] Add type annotations for python/paddle/distribution/kl.py (#65776)
1 parent f6ad654 commit 568209f

File tree

1 file changed

+15
-4
lines changed
  • python/paddle/distribution

1 file changed

+15
-4
lines changed

python/paddle/distribution/kl.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,12 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
15+
from __future__ import annotations
16+
1417
import functools
1518
import warnings
19+
from typing import TYPE_CHECKING, Callable, TypeVar
1620

1721
import paddle
1822
from paddle.distribution.bernoulli import Bernoulli
@@ -35,12 +39,17 @@
3539
from paddle.distribution.uniform import Uniform
3640
from paddle.framework import in_dynamic_mode
3741

42+
if TYPE_CHECKING:
43+
from paddle import Tensor
44+
45+
_T = TypeVar('_T')
46+
3847
__all__ = ["register_kl", "kl_divergence"]
3948

4049
_REGISTER_TABLE = {}
4150

4251

43-
def kl_divergence(p, q):
52+
def kl_divergence(p: Distribution, q: Distribution) -> Tensor:
4453
r"""
4554
Kullback-Leibler divergence between distribution p and q.
4655
@@ -72,7 +81,9 @@ def kl_divergence(p, q):
7281
return _dispatch(type(p), type(q))(p, q)
7382

7483

75-
def register_kl(cls_p, cls_q):
84+
def register_kl(
85+
cls_p: type[Distribution], cls_q: type[Distribution]
86+
) -> Callable[[_T], _T]:
7687
"""Decorator for register a KL divergence implementation function.
7788
7889
The ``kl_divergence(p, q)`` function will search concrete implementation
@@ -82,8 +93,8 @@ def register_kl(cls_p, cls_q):
8293
implementation function by the decorator.
8394
8495
Args:
85-
cls_p (Distribution): The Distribution type of Instance p. Subclass derived from ``Distribution``.
86-
cls_q (Distribution): The Distribution type of Instance q. Subclass derived from ``Distribution``.
96+
cls_p (type[Distribution]): The Distribution type of Instance p. Subclass derived from ``Distribution``.
97+
cls_q (type[Distribution]): The Distribution type of Instance q. Subclass derived from ``Distribution``.
8798
8899
Examples:
89100
.. code-block:: python

0 commit comments

Comments
 (0)