-
Notifications
You must be signed in to change notification settings - Fork 5.9k
【Hackathon 6th No.5】Add chi2/LKJCholesky API to Paddle -part #63883
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
57 commits
Select commit
Hold shift + click to select a range
d4747dd
update chi2 && lkj cholesky
cmcamdy 8009776
update lkj
cmcamdy c298011
Merge remote-tracking branch 'upstream/develop' into student_t
cmcamdy 190d0d8
【Hackathon 6th No.5】Chi2 / LKJCholesky API
cmcamdy c131dae
Merge remote-tracking branch 'upstream/develop' into student_t
cmcamdy 4c2ddfd
【Hackathon 6th No.5】Chi2 / LKJCholesky API
cmcamdy e818d4b
【Hackathon 6th No.5】Chi2 / LKJCholesky API
cmcamdy 9746325
【Hackathon 6th No.5】Chi2 / LKJCholesky API
cmcamdy 2770dd8
【Hackathon 6th No.5】Chi2 / LKJCholesky API
cmcamdy 7761097
【Hackathon 6th No.5】Chi2 / LKJCholesky API
cmcamdy 791592f
【Hackathon 6th No.5】Chi2 / LKJCholesky API
cmcamdy 49f220e
【Hackathon 6th No.5】Chi2 / LKJCholesky API
cmcamdy d77527e
【Hackathon 6th No.5】Chi2 / LKJCholesky API
cmcamdy 8d4d74d
【Hackathon 6th No.5】Chi2 / LKJCholesky API
cmcamdy aa1ebbc
【Hackathon 6th No.5】Chi2 / LKJCholesky API
cmcamdy 9caf814
【Hackathon 6th No.5】Chi2 / LKJCholesky API
cmcamdy 1269701
【Hackathon 6th No.5】Chi2 / LKJCholesky API
cmcamdy f1200d8
Merge remote-tracking branch 'upstream/develop' into student_t
cmcamdy 4b93065
fix
cmcamdy 2416887
fix test
cmcamdy a53b5b7
fix
cmcamdy 17468a1
fix
cmcamdy 7f5e0a4
fix
cmcamdy bb3d200
fix
cmcamdy fa1a604
fix
cmcamdy 3114f6a
fix doc
cmcamdy 4c64aac
fix
cmcamdy 344548f
fix
cmcamdy bb4413f
fix lkj
cmcamdy 1c8ad85
fix
cmcamdy a876596
fix test
cmcamdy 1a9836b
fix
cmcamdy d21c83d
Merge remote-tracking branch 'upstream/develop' into student_t
cmcamdy f55797a
Merge branch 'student_t' of github.com:cmcamdy/Paddle into student_t
cmcamdy f5bb90e
fix
cmcamdy 9277759
fix
cmcamdy f2a894f
fix
cmcamdy 6a46e32
fix
cmcamdy 9493476
fix
cmcamdy 3e4a35f
fix
cmcamdy 610adb0
Merge remote-tracking branch 'upstream/develop' into student_t
cmcamdy 65f6cdb
fix doc
cmcamdy 85de0b3
add des & set timeout
cmcamdy 95d221c
fix
cmcamdy a28babe
fix
cmcamdy 84dec94
Merge remote-tracking branch 'upstream/develop' into student_t
cmcamdy d64ed4d
add type check
cmcamdy 8a6b071
fix
cmcamdy aae4350
add bad params && shapes
cmcamdy 12b487a
fix chi2
cmcamdy 790a834
fix chi2
cmcamdy 3c4ef82
fix
cmcamdy 5c23f33
fix
cmcamdy fe3edf2
modify chi2
cmcamdy 386519d
Update python/paddle/distribution/chi2.py
cmcamdy b7cdcc6
Update python/paddle/distribution/chi2.py
cmcamdy 23c87e7
merge develop
cmcamdy File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,64 @@ | ||
| # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| import paddle | ||
| from paddle.base.data_feeder import check_type, convert_dtype | ||
| from paddle.base.framework import Variable | ||
| from paddle.distribution.gamma import Gamma | ||
| from paddle.framework import in_dynamic_mode | ||
|
|
||
| __all__ = ["Chi2"] | ||
|
|
||
|
|
||
| class Chi2(Gamma): | ||
| r""" | ||
| Creates a Chi-squared distribution parameterized by shape parameter. | ||
| This is exactly equivalent to Gamma(concentration=0.5*df, rate=0.5), :ref:`api_paddle_distribution_Gamma`. | ||
|
|
||
| Args: | ||
| df (float or Tensor): shape parameter of the distribution | ||
|
|
||
| Example: | ||
| .. code-block:: python | ||
|
|
||
| >>> import paddle | ||
| >>> m = paddle.distribution.Chi2(paddle.to_tensor([1.0])) | ||
| >>> sample = m.sample() | ||
| >>> sample.shape | ||
| [1] | ||
|
|
||
| """ | ||
|
|
||
| def __init__(self, df): | ||
| if not in_dynamic_mode(): | ||
| check_type( | ||
| df, | ||
| 'df', | ||
| (float, Variable), | ||
| 'Chi2', | ||
| ) | ||
|
|
||
| # Get/convert concentration to tensor. | ||
| if self._validate_args(df): | ||
| self.df = df | ||
| self.dtype = convert_dtype(df.dtype) | ||
| else: | ||
| [self.df] = self._to_tensor(df) | ||
| self.dtype = paddle.get_default_dtype() | ||
|
|
||
| self.rate = paddle.full_like(self.df, 0.5) | ||
|
|
||
| if not paddle.all(self.df > 0): | ||
| raise ValueError("The arg of `df` must be positive.") | ||
|
|
||
| super().__init__(self.df * 0.5, self.rate) | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
缺少对输入的类型检查,API说明可以支持float或Tensor,但是如果输入为一个int,当前逻辑会被直接转换成一个Tensor,这也是正确的?修复完成后也补上对应单测(包括lkj)