add autograd and jax support #88
Conversation
Also update travis dist to xenial to fix glibc errors relating to tensorflow and jax imports
LGTM! I see no issues here, any hindrance to merging?
I might just add some docs when I get the chance. Maybe a bullet point on the readme (since automatic differentiation is very nice for optimization etc) and a little section in the 'backends' bit?
Sounds good! I am trying to clear some time to get back to this and push a 3.0 out.
OK, I think all good to go from my end if the docs look good. A 3.0 release sounds great!
A few minor comments, feel free to merge if you are ok with them.
This is really awesome to add to the project. Kind of amazing the underlying architecture can handle something like autograd pretty seamlessly.
array(3.71251519)

>>> # wrap foo with autograd to compute gradients instead
>>> dfoo = autograd.grad(foo)
That's pretty cool that all of this works.
docs/source/backends.rst (outdated)
>>> # generate a compiled version of the gradient function
>>> jit_dfoo = jax.jit(jax.grad(foo))
>>> jit_dfoo([x, y, z])
[DeviceArray([[1.1137383 , 1.14972878, 0.64056885],
It might be good to use the same random seed above and here to show that they produce the same thing?
Ah good catch, I thought I had updated them with the same arrays but obviously forgot. Fixed now.
There were some new test errors coming from pytest v5.0 that I've also just fixed.
Woo awesome!
Actually both Autograd and JAX do a whole lot more than that! Jacobians, Hessians, whatever you like! See for example the JAX Autodiff Cookbook (part 1). Thanks for adding this. The future of numerical computing is bright :)
Thanks for the nice comments! Happy to take a PR that expands on what JAX can do so that we can better highlight the library.
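For instance (a purely illustrative sketch, not taken from the PR or its docs), jax exposes higher-order derivatives through functions such as jax.jacfwd and jax.hessian:

>>> import jax
>>> import jax.numpy as jnp

>>> def g(v):
...     # toy scalar function of a vector, for illustration only
...     return jnp.sum(jnp.tanh(v) ** 2)

>>> v = jnp.arange(3.0)
>>> jax.jacfwd(g)(v)    # Jacobian of g (here just a length-3 gradient vector)
>>> jax.hessian(g)(v)   # full 3x3 Hessian matrix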
Description
This adds support for using autograd and jax. These both allow automatically computing gradients of functions with scalar output.
jax additionally compiles the computations for GPU and faster runtime.

Essentially all that was required was just adding the aliases
{'jax': 'jax.numpy', 'autograd': 'autograd.numpy'}
to the backend dispatch mechanism, but I've also added explicit tests and jax as a compiled expression backend for numpy arrays (since it seems to have very good performance).

Finally, I've beefed up the backend tests a bit and fixed an issue with tensorflow on travis to make sure those tests are not skipped.
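As a rough illustration of the dispatch idea (a hypothetical sketch, not the actual opt_einsum source; the helper name is made up), the alias simply maps a backend name to a numpy-compatible module whose functions are then used for the contraction:

>>> import importlib

>>> # hypothetical alias table and lookup helper
>>> _ALIASES = {'jax': 'jax.numpy', 'autograd': 'autograd.numpy'}

>>> def _get_backend_module(backend):
...     """Import the module providing numpy-style functions for ``backend``."""
...     return importlib.import_module(_ALIASES.get(backend, backend))

>>> _get_backend_module('autograd').tensordot   # i.e. autograd.numpy.tensordot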
Examples
Compute the gradient of an arbitrary tensor contraction:
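The example block rendered on GitHub isn't captured here; the following is a minimal sketch of the kind of usage described, assuming contract_expression and the backend='autograd' keyword (the equation, shapes and the helper foo are made up for illustration):

>>> import numpy as np
>>> import autograd
>>> import opt_einsum as oe

>>> # a reusable contraction expression with scalar output
>>> shapes = [(2, 3), (3, 4), (4, 2)]
>>> expr = oe.contract_expression("ab,bc,ca->", *shapes)

>>> def foo(arrays):
...     # evaluate the contraction, dispatching internally to autograd.numpy
...     return expr(*arrays, backend='autograd')

>>> x, y, z = [np.random.rand(*s) for s in shapes]
>>> foo([x, y, z])                 # ordinary forward evaluation, a scalar

>>> # wrap foo with autograd to compute gradients instead
>>> dfoo = autograd.grad(foo)
>>> dx, dy, dz = dfoo([x, y, z])   # one gradient array per input operand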
Compile both the function and gradient using jax:
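Again, the original block isn't reproduced here; a sketch reusing the setup above, with foo switched to dispatch to the jax backend (assuming the same backend keyword):

>>> import jax

>>> def foo(arrays):
...     # same contraction as above, but dispatched to jax.numpy
...     return expr(*arrays, backend='jax')

>>> # generate a compiled version of the gradient function
>>> jit_dfoo = jax.jit(jax.grad(foo))
>>> jit_dfoo([x, y, z])   # list of gradient arrays (DeviceArrays)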
Or, regardless of gradients, just compile a contraction using jax (that still accepts and gives out numpy arrays):
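A sketch of what this might look like (assuming the backend='jax' keyword on a contract_expression; the equation and shapes are illustrative). The expression takes plain numpy arrays and, as described above, hands back a numpy-compatible result:

>>> expr = oe.contract_expression("ab,bc,cd->ad", (2, 3), (3, 4), (4, 5))
>>> x, y, z = [np.random.rand(*s) for s in [(2, 3), (3, 4), (4, 5)]]

>>> # numpy arrays in, numpy array out; jax runs the compiled contraction in between
>>> out = expr(x, y, z, backend='jax')
>>> out.shape
(2, 5)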
Todos

jax by default converts everything to single precision (a possible workaround is sketched below)
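If double precision matters, one option (a sketch; the exact mechanism depends on the jax version) is to enable 64-bit mode before any arrays are created, either via the JAX_ENABLE_X64=1 environment variable or the config flag:

>>> import jax
>>> # flip the x64 flag at startup (alternatively: export JAX_ENABLE_X64=1)
>>> jax.config.update("jax_enable_x64", True)

>>> import jax.numpy as jnp
>>> jnp.zeros(1).dtype    # float64 rather than the default float32
dtype('float64')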
Status