Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
d4747dd
update chi2 && lkj cholesky
cmcamdy Apr 16, 2024
8009776
update lkj
cmcamdy Apr 18, 2024
c298011
Merge remote-tracking branch 'upstream/develop' into student_t
cmcamdy Apr 25, 2024
190d0d8
【Hackathon 6th No.5】Chi2 / LKJCholesky API
cmcamdy Apr 25, 2024
c131dae
Merge remote-tracking branch 'upstream/develop' into student_t
cmcamdy Apr 25, 2024
4c2ddfd
【Hackathon 6th No.5】Chi2 / LKJCholesky API
cmcamdy Apr 26, 2024
e818d4b
【Hackathon 6th No.5】Chi2 / LKJCholesky API
cmcamdy Apr 26, 2024
9746325
【Hackathon 6th No.5】Chi2 / LKJCholesky API
cmcamdy Apr 26, 2024
2770dd8
【Hackathon 6th No.5】Chi2 / LKJCholesky API
cmcamdy Apr 26, 2024
7761097
【Hackathon 6th No.5】Chi2 / LKJCholesky API
cmcamdy Apr 26, 2024
791592f
【Hackathon 6th No.5】Chi2 / LKJCholesky API
cmcamdy Apr 26, 2024
49f220e
【Hackathon 6th No.5】Chi2 / LKJCholesky API
cmcamdy Apr 26, 2024
d77527e
【Hackathon 6th No.5】Chi2 / LKJCholesky API
cmcamdy Apr 26, 2024
8d4d74d
【Hackathon 6th No.5】Chi2 / LKJCholesky API
cmcamdy Apr 26, 2024
aa1ebbc
【Hackathon 6th No.5】Chi2 / LKJCholesky API
cmcamdy Apr 26, 2024
9caf814
【Hackathon 6th No.5】Chi2 / LKJCholesky API
cmcamdy Apr 26, 2024
1269701
【Hackathon 6th No.5】Chi2 / LKJCholesky API
cmcamdy Apr 26, 2024
f1200d8
Merge remote-tracking branch 'upstream/develop' into student_t
cmcamdy May 8, 2024
4b93065
fix
cmcamdy May 8, 2024
2416887
fix test
cmcamdy May 9, 2024
a53b5b7
fix
cmcamdy May 11, 2024
17468a1
fix
cmcamdy May 11, 2024
7f5e0a4
fix
cmcamdy May 11, 2024
bb3d200
fix
cmcamdy May 12, 2024
fa1a604
fix
cmcamdy May 12, 2024
3114f6a
fix doc
cmcamdy May 12, 2024
4c64aac
fix
cmcamdy May 12, 2024
344548f
fix
cmcamdy May 12, 2024
bb4413f
fix lkj
cmcamdy May 12, 2024
1c8ad85
fix
cmcamdy May 12, 2024
a876596
fix test
cmcamdy May 12, 2024
1a9836b
fix
cmcamdy May 13, 2024
d21c83d
Merge remote-tracking branch 'upstream/develop' into student_t
cmcamdy May 13, 2024
f55797a
Merge branch 'student_t' of github.com:cmcamdy/Paddle into student_t
cmcamdy May 13, 2024
f5bb90e
fix
cmcamdy May 13, 2024
9277759
fix
cmcamdy May 13, 2024
f2a894f
fix
cmcamdy May 13, 2024
6a46e32
fix
cmcamdy May 13, 2024
9493476
fix
cmcamdy May 13, 2024
3e4a35f
fix
cmcamdy May 14, 2024
610adb0
Merge remote-tracking branch 'upstream/develop' into student_t
cmcamdy May 14, 2024
65f6cdb
fix doc
cmcamdy May 15, 2024
85de0b3
add des & set timeout
cmcamdy May 23, 2024
95d221c
fix
cmcamdy May 23, 2024
a28babe
fix
cmcamdy May 23, 2024
84dec94
Merge remote-tracking branch 'upstream/develop' into student_t
cmcamdy May 23, 2024
d64ed4d
add type check
cmcamdy May 24, 2024
8a6b071
fix
cmcamdy May 24, 2024
aae4350
add bad params && shapes
cmcamdy May 27, 2024
12b487a
fix chi2
cmcamdy May 28, 2024
790a834
fix chi2
cmcamdy May 28, 2024
3c4ef82
fix
cmcamdy May 28, 2024
5c23f33
fix
cmcamdy May 28, 2024
fe3edf2
modify chi2
cmcamdy Jun 1, 2024
386519d
Update python/paddle/distribution/chi2.py
cmcamdy Jun 11, 2024
b7cdcc6
Update python/paddle/distribution/chi2.py
cmcamdy Jun 11, 2024
23c87e7
merge develop
cmcamdy Jun 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions python/paddle/distribution/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .binomial import Binomial
from .categorical import Categorical
from .cauchy import Cauchy
from .chi2 import Chi2
from .continuous_bernoulli import ContinuousBernoulli
from .dirichlet import Dirichlet
from .distribution import Distribution
Expand All @@ -29,6 +30,7 @@
from .independent import Independent
from .kl import kl_divergence, register_kl
from .laplace import Laplace
from .lkj_cholesky import LKJCholesky
from .lognormal import LogNormal
from .multinomial import Multinomial
from .multivariate_normal import MultivariateNormal
Expand Down Expand Up @@ -58,6 +60,7 @@
'Beta',
'Categorical',
'Cauchy',
'Chi2',
'ContinuousBernoulli',
'Dirichlet',
'Distribution',
Expand All @@ -73,6 +76,7 @@
'TransformedDistribution',
'Laplace',
'LogNormal',
'LKJCholesky',
'Gamma',
'Gumbel',
'Geometric',
Expand Down
64 changes: 64 additions & 0 deletions python/paddle/distribution/chi2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle.base.data_feeder import check_type, convert_dtype
from paddle.base.framework import Variable
from paddle.distribution.gamma import Gamma
from paddle.framework import in_dynamic_mode

__all__ = ["Chi2"]


class Chi2(Gamma):
r"""
Creates a Chi-squared distribution parameterized by shape parameter.
This is exactly equivalent to Gamma(concentration=0.5*df, rate=0.5), :ref:`api_paddle_distribution_Gamma`.

Args:
df (float or Tensor): shape parameter of the distribution

Example:
.. code-block:: python

>>> import paddle
>>> m = paddle.distribution.Chi2(paddle.to_tensor([1.0]))
>>> sample = m.sample()
>>> sample.shape
[1]

"""

def __init__(self, df):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

缺少对输入的类型检查,API说明可以支持float或Tensor,但是如果输入为一个int,当前逻辑会被直接转换成一个Tensor,这也是正确的?修复完成后也补上对应单测(包括lkj)

if not in_dynamic_mode():
check_type(
df,
'df',
(float, Variable),
'Chi2',
)

# Get/convert concentration to tensor.
if self._validate_args(df):
self.df = df
self.dtype = convert_dtype(df.dtype)
else:
[self.df] = self._to_tensor(df)
self.dtype = paddle.get_default_dtype()

self.rate = paddle.full_like(self.df, 0.5)

if not paddle.all(self.df > 0):
raise ValueError("The arg of `df` must be positive.")

super().__init__(self.df * 0.5, self.rate)
Loading