Transform jax samples #4427
Conversation
Probably a dumb question, but can't we convert all transformed variables at the end of sampling with vectorized operations?
Hey Ricardo, in principle I think that's a good idea, but at least with this approach, calling the function returned by `fastfn` handles one sample at a time, so we'd still be looping in Python. As mentioned in the pull request, if we could get a JAX version of the function, we could just use `jax.vmap`.
+1. We should jaxify the theano function, and then vmap it.
To be more specific, we should have only one for loop, over all the RVs: if an RV is transformed, grab the forward function, compile it into a JAX function, and then call `jax.vmap` on it.
I agree that this would be the best solution. Unfortunately, naively trying `jax_funcify(model.fastfn(var_names))` raises an error, and so does `jax_funcify(model.fastfn)`.
Yep, sounds good. Only issue: would this cover deterministic variables too? If not, it might be better to try to use `jax.vmap(jax.vmap(fun))(samples)` over the dict of samples, and it should all work just fine (but maybe I'm wrong about that).
Oh, you are right.
Sounds great, thanks, I'll look into that!
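As a minimal sketch of the double-`vmap` idea discussed above (with `jnp.exp` standing in for a real forward transform, and made-up array shapes):

```python
import jax
import jax.numpy as jnp

# Dummy draws for one log-transformed RV, shape (n_chains, n_draws):
log_sigma_draws = jnp.zeros((4, 500))

# The outer vmap maps over chains, the inner one over draws:
sigma_draws = jax.vmap(jax.vmap(jnp.exp))(log_sigma_draws)

assert sigma_draws.shape == (4, 500)
```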
Hi all, please take a look at the latest version, which uses `jax.vmap` as discussed. As a side note (it's not an issue with the code here): I did run into slightly odd behaviour which I thought I'd mention. When exploring how `jax_funcify` worked on the LKJ example:

```python
graphs = {x.name: theano.graph.fg.FunctionGraph(model.free_RVs, [x]) for x in model.unobserved_RVs}
jax_fns = {x: jax_funcify(y) for x, y in graphs.items()}
```

the jaxified functions for the first two entries came back empty. What I thought was a little odd here is that these two are the inputs: I'm guessing the empty result is related to the fact that no computation is required for them, since they are already in the inputs, but I would have expected some kind of identity function or something. For the code in the pull request, I avoid this issue by first working out which RVs actually have to be computed:

```python
free_rv_names = {x.name for x in model.free_RVs}
unobserved_names = {x.name for x in model.unobserved_RVs}
names_to_compute = unobserved_names - free_rv_names
ops_to_compute = [x for x in model.unobserved_RVs if x.name in names_to_compute]
```

and then only computing those with the jaxified graph. Looking forward to hearing your thoughts and whether anything could be improved.
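Putting these pieces together, here is a sketch of how the transformation step might look; this is a sketch under stated assumptions, not necessarily the exact code in the PR. It assumes `jax_funcify` on a `FunctionGraph` returns one JAX function per graph output, and that `samples` maps free-RV names to `(chains, draws, ...)` arrays:

```python
import jax
import theano
from theano.link.jax.jax_dispatch import jax_funcify  # Theano-PyMC JAX backend

def _transform_samples(samples, model):
    # Only compute RVs that are not already among the inputs (see above):
    free_rv_names = {x.name for x in model.free_RVs}
    names_to_compute = {x.name for x in model.unobserved_RVs} - free_rv_names
    ops_to_compute = [x for x in model.unobserved_RVs if x.name in names_to_compute]

    # One graph from the free RVs to everything we need to compute; jax_funcify
    # is assumed to return one JAX function per output of the graph:
    fgraph = theano.graph.fg.FunctionGraph(model.free_RVs, ops_to_compute)
    jax_fns = jax_funcify(fgraph)

    # Inputs in the same order as the graph inputs:
    inputs = [samples[x.name] for x in model.free_RVs]

    for op, jax_fn in zip(ops_to_compute, jax_fns):
        # vmap twice: the outer map runs over chains, the inner over draws.
        samples[op.name] = jax.vmap(jax.vmap(jax_fn))(*inputs)

    return samples
```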
Hi all, I've made one further change, and I think I may have cluttered everything by rebasing with upstream; I'm really sorry about that, I'm very new to all of this. I hope I haven't made things too horrible; please let me know if there's anything I should do to neaten up this pull request again. The one additional change I have made is this commit here. What it does is use:

```python
pm.util.get_default_varnames(list(samples.keys()), include_transformed=keep_untransformed)
```

where `keep_untransformed` is a new argument to `sample_numpyro_nuts` (defaulting to `False`), so that transformed (e.g. log-space) variables are only returned when the user asks for them. Once again, sorry about the clutter; let me know if there's something I should do, and of course if there is anything else to improve with this pull request.
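For illustration, here is how that filtering behaves on some made-up variable names:

```python
import pymc3 as pm

names = ["sigma_log__", "sigma", "mu"]

# include_transformed=False drops the log-space (transformed) name:
print(pm.util.get_default_varnames(names, include_transformed=False))
# ['sigma', 'mu']

# include_transformed=True keeps everything:
print(pm.util.get_default_varnames(names, include_transformed=True))
# ['sigma_log__', 'sigma', 'mu']
```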
Hi @martiningram, when you consider it done, make sure to add a line to the release notes.
* Add `pymc3.sampling_jax._transform_samples` function which transforms draws
* Modify `pymc3.sampling_jax.sample_numpyro_nuts` function to use this function to return transformed samples
* Add release note
Thanks a lot for your help @michaelosthege! I think I've managed to follow your xkcd strategy, and I've added a line to the release notes, too. Let me know what you think!
Could you add a small test? A Normal likelihood with a log-transformed scale parameter would be enough.
Sure, will do! I'm on holiday right now but should have something by early next week at the latest.
Hi @junpenglao, I've tried my hand at adding a small test. The test checks that the transformation from the log scale works correctly. Please let me know if I should change anything, or whether this is roughly what you had in mind! Note that to make this work, I had to add an argument `keep_untransformed` to `sample_numpyro_nuts`.
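For reference, a sketch of what such a test might look like; it is not necessarily the exact test in the PR, and it assumes `sample_numpyro_nuts` accepts `keep_untransformed` and returns an ArviZ `InferenceData`:

```python
import numpy as np
import pymc3 as pm
from pymc3.sampling_jax import sample_numpyro_nuts

def test_transform_samples():
    with pm.Model():
        # HalfNormal sigma is log-transformed internally (sigma_log__):
        sigma = pm.HalfNormal("sigma")
        b = pm.Normal("b", mu=0.0, sigma=sigma, observed=np.random.normal(size=10))
        trace = sample_numpyro_nuts(keep_untransformed=True)

    log_vals = trace.posterior["sigma_log__"].values
    trans_vals = trace.posterior["sigma"].values

    # The constrained draws should be the exp of the log-space draws:
    assert np.allclose(trans_vals, np.exp(log_vals))
```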
It looks like the test in the CI fails due to the lack of `jax`. @twiecki @junpenglao how should we proceed? Should we install `jax` into the CI environment, or leave the test out of the CI?
I think it's fine to install it for CI only. We'll probably add more jax tests anyway.
CC @MarcoGorelli on how to install JAX only for CI.
I think in the CI environment file would be the place to add it.
@twiecki @MarcoGorelli as you can see above, I tried to add the jax dependency, but it didn't work. I would really like to get this merged.
In blackjax, jax is pinned very narrowly (I think Remi said it changes very frequently); does that need doing here too?
Which job are you referring to? I just had a look at the top one and it shows ✅.
I looked at the Windows job, because that one was ❌, and adding `jax` there didn't seem to help.
@michaelosthege I get your point and I do agree. However, just by the nature of how important the JAX stuff is to our future, I'd like us to not put any boundaries here. I heard from some users who could really use this as a replacement for the existing samplers. Until the blackjax samplers are more mature or we find a better solution, I would recommend we bite the bullet and add it as a dependency here.
Alternative idea: we split the Jax tests into their own CI job and install NumPyro just in that CI job. Then potential failures are at least separate from the rest.
This is the way.
Thanks @martiningram and everyone else!
Hi there,

This pull request addresses the following issue: #4415

It uses the model's `fastfn` function to compute the values of all the variables listed in `model.unobserved_RVs`. This should include the transformed variables as well as any deterministic variables of interest. It's also what I believe is done in the pymc3 samplers currently -- it's a bit hidden but I believe it is done here.

A note is that this implementation loops over chains and samples and is thus not particularly efficient. I have added a timing print statement to the code to easily see this. I ran it on the LKJ example and sampling took 20s (500 warmup, 500 sampling, 4 chains), while transforming took 7s. A cool improvement would be to somehow turn the theano `fastfn` into a JAX function, which could then be evaluated much more efficiently using `jax.vmap` across the samples, but I didn't see an easy way to `jax_funcify` this function (just calling `jax_funcify` doesn't work). If someone knows how, I am happy to update the code.

Interested to hear your thoughts! I'm also planning to add an example notebook soon to show what this does.
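For concreteness, here is a sketch of the loop-based approach described above; it is illustrative rather than the PR's exact code, and it assumes `samples` maps free-RV names to arrays of shape `(chains, draws, ...)`:

```python
import numpy as np

def _transform_samples_loop(samples, model):
    # Compile a theano function from free-RV values to all unobserved RVs:
    fn = model.fastfn(model.unobserved_RVs)
    names = [rv.name for rv in model.unobserved_RVs]

    n_chains, n_draws = next(iter(samples.values())).shape[:2]
    out = {name: [] for name in names}

    # The slow part: one compiled-function call per (chain, draw) pair.
    for c in range(n_chains):
        for d in range(n_draws):
            point = {name: vals[c, d] for name, vals in samples.items()}
            for name, value in zip(names, fn(point)):
                out[name].append(np.asarray(value))

    # Reshape each variable back to (chains, draws, ...):
    return {
        name: np.stack(vals).reshape((n_chains, n_draws) + vals[0].shape)
        for name, vals in out.items()
    }
```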