Skip to content

Commit fb2bd26

Browse files
【Hackathon 6th No.5】Add chi2/LKJCholesky API to Paddle -part (#63883)
* update chi2 && lkj cholesky * update lkj * 【Hackathon 6th No.5】Chi2 / LKJCholesky API * 【Hackathon 6th No.5】Chi2 / LKJCholesky API * 【Hackathon 6th No.5】Chi2 / LKJCholesky API * 【Hackathon 6th No.5】Chi2 / LKJCholesky API * 【Hackathon 6th No.5】Chi2 / LKJCholesky API * 【Hackathon 6th No.5】Chi2 / LKJCholesky API * 【Hackathon 6th No.5】Chi2 / LKJCholesky API * 【Hackathon 6th No.5】Chi2 / LKJCholesky API * 【Hackathon 6th No.5】Chi2 / LKJCholesky API * 【Hackathon 6th No.5】Chi2 / LKJCholesky API * 【Hackathon 6th No.5】Chi2 / LKJCholesky API * 【Hackathon 6th No.5】Chi2 / LKJCholesky API * 【Hackathon 6th No.5】Chi2 / LKJCholesky API * fix * fix test * fix * fix * fix * fix * fix * fix doc * fix * fix * fix lkj * fix * fix test * fix * fix * fix * fix * fix * fix * fix * fix doc * add des & set timeout * fix * fix * add type check * fix * add bad params && shapes * fix chi2 * fix chi2 * fix * fix * modify chi2 * Update python/paddle/distribution/chi2.py Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com> * Update python/paddle/distribution/chi2.py Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com> --------- Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com>
1 parent c0ea68d commit fb2bd26

File tree

8 files changed

+1358
-0
lines changed

8 files changed

+1358
-0
lines changed

python/paddle/distribution/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from .binomial import Binomial
1919
from .categorical import Categorical
2020
from .cauchy import Cauchy
21+
from .chi2 import Chi2
2122
from .continuous_bernoulli import ContinuousBernoulli
2223
from .dirichlet import Dirichlet
2324
from .distribution import Distribution
@@ -29,6 +30,7 @@
2930
from .independent import Independent
3031
from .kl import kl_divergence, register_kl
3132
from .laplace import Laplace
33+
from .lkj_cholesky import LKJCholesky
3234
from .lognormal import LogNormal
3335
from .multinomial import Multinomial
3436
from .multivariate_normal import MultivariateNormal
@@ -58,6 +60,7 @@
5860
'Beta',
5961
'Categorical',
6062
'Cauchy',
63+
'Chi2',
6164
'ContinuousBernoulli',
6265
'Dirichlet',
6366
'Distribution',
@@ -73,6 +76,7 @@
7376
'TransformedDistribution',
7477
'Laplace',
7578
'LogNormal',
79+
'LKJCholesky',
7680
'Gamma',
7781
'Gumbel',
7882
'Geometric',

python/paddle/distribution/chi2.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import paddle
15+
from paddle.base.data_feeder import check_type, convert_dtype
16+
from paddle.base.framework import Variable
17+
from paddle.distribution.gamma import Gamma
18+
from paddle.framework import in_dynamic_mode
19+
20+
__all__ = ["Chi2"]
21+
22+
23+
class Chi2(Gamma):
24+
r"""
25+
Creates a Chi-squared distribution parameterized by shape parameter.
26+
This is exactly equivalent to Gamma(concentration=0.5*df, rate=0.5), :ref:`api_paddle_distribution_Gamma`.
27+
28+
Args:
29+
df (float or Tensor): shape parameter of the distribution
30+
31+
Example:
32+
.. code-block:: python
33+
34+
>>> import paddle
35+
>>> m = paddle.distribution.Chi2(paddle.to_tensor([1.0]))
36+
>>> sample = m.sample()
37+
>>> sample.shape
38+
[1]
39+
40+
"""
41+
42+
def __init__(self, df):
43+
if not in_dynamic_mode():
44+
check_type(
45+
df,
46+
'df',
47+
(float, Variable),
48+
'Chi2',
49+
)
50+
51+
# Get/convert concentration to tensor.
52+
if self._validate_args(df):
53+
self.df = df
54+
self.dtype = convert_dtype(df.dtype)
55+
else:
56+
[self.df] = self._to_tensor(df)
57+
self.dtype = paddle.get_default_dtype()
58+
59+
self.rate = paddle.full_like(self.df, 0.5)
60+
61+
if not paddle.all(self.df > 0):
62+
raise ValueError("The arg of `df` must be positive.")
63+
64+
super().__init__(self.df * 0.5, self.rate)

0 commit comments

Comments
 (0)