11
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
14
17
import functools
15
18
import warnings
19
+ from typing import TYPE_CHECKING , Callable , TypeVar
16
20
17
21
import paddle
18
22
from paddle .distribution .bernoulli import Bernoulli
35
39
from paddle .distribution .uniform import Uniform
36
40
from paddle .framework import in_dynamic_mode
37
41
42
+ if TYPE_CHECKING :
43
+ from paddle import Tensor
44
+
45
+ _T = TypeVar ('_T' )
46
+
38
47
__all__ = ["register_kl" , "kl_divergence" ]
39
48
40
49
_REGISTER_TABLE = {}
41
50
42
51
43
- def kl_divergence (p , q ) :
52
+ def kl_divergence (p : Distribution , q : Distribution ) -> Tensor :
44
53
r"""
45
54
Kullback-Leibler divergence between distribution p and q.
46
55
@@ -72,7 +81,9 @@ def kl_divergence(p, q):
72
81
return _dispatch (type (p ), type (q ))(p , q )
73
82
74
83
75
- def register_kl (cls_p , cls_q ):
84
+ def register_kl (
85
+ cls_p : type [Distribution ], cls_q : type [Distribution ]
86
+ ) -> Callable [[_T ], _T ]:
76
87
"""Decorator for register a KL divergence implementation function.
77
88
78
89
The ``kl_divergence(p, q)`` function will search concrete implementation
@@ -82,8 +93,8 @@ def register_kl(cls_p, cls_q):
82
93
implementation function by the decorator.
83
94
84
95
Args:
85
- cls_p (Distribution): The Distribution type of Instance p. Subclass derived from ``Distribution``.
86
- cls_q (Distribution): The Distribution type of Instance q. Subclass derived from ``Distribution``.
96
+ cls_p (type[ Distribution] ): The Distribution type of Instance p. Subclass derived from ``Distribution``.
97
+ cls_q (type[ Distribution] ): The Distribution type of Instance q. Subclass derived from ``Distribution``.
87
98
88
99
Examples:
89
100
.. code-block:: python
0 commit comments