-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Add more idata attributes for JAX samplers #7360
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
Added trailing commas for formatting. |
@@ -44,6 +44,10 @@ def test_external_nuts_sampler(recwarn, nuts_sampler): | |||
idata1 = sample(**kwargs) | |||
idata2 = sample(**kwargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not introduced by you, but why are we sampling twice, is idata2 used anywhere?
@@ -64,8 +68,11 @@ def test_external_nuts_sampler(recwarn, nuts_sampler): | |||
assert "L" in idata1.observed_data | |||
assert idata1.posterior.chain.size == 2 | |||
assert idata1.posterior.draw.size == 500 | |||
assert idata1.posterior.tuning_steps == 500 | |||
np.testing.assert_array_equal(idata1.posterior.x, idata2.posterior.x) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ricardoV94 idata2
is used in this line only, to check if the sampling is deterministic I think?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah okay! Thanks for checking
Description
This PR provides a fix for #7262 by filling in missing
post_ attributes
keyword argument.Related Issue
Checklist
Type of change