
add autograd and jax support #88

Merged 6 commits into dgasmith:master on Jul 3, 2019

Conversation

@jcmgray (Collaborator) commented Jun 28, 2019

Description

This adds support for using autograd and jax as backends. Both allow gradients of functions with scalar output to be computed automatically; jax additionally JIT-compiles the computation, giving faster runtimes and GPU support.

Essentially all that was required was adding the aliases {'jax': 'jax.numpy', 'autograd': 'autograd.numpy'} to the backend dispatch mechanism, but I've also added explicit tests and registered jax as a compiled-expression backend for numpy arrays (since it seems to have very good performance).
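
For illustration, here is a minimal sketch (not the actual opt_einsum internals) of how such an alias table lets a backend name resolve to a numpy-like module:

import importlib

# hypothetical helper, just to show the idea behind the aliases:
# 'jax' and 'autograd' are not numpy-compatible namespaces themselves,
# so they are mapped to their numpy-like submodules
_MODULE_ALIASES = {'jax': 'jax.numpy', 'autograd': 'autograd.numpy'}

def _get_backend_module(backend):
    """Return the numpy-like module that implements `backend`."""
    return importlib.import_module(_MODULE_ALIASES.get(backend, backend))

# _get_backend_module('autograd') -> <module 'autograd.numpy'>
# _get_backend_module('torch')    -> <module 'torch'>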

Finally, I've beefed up the backend tests a bit and fixed an issue with tensorflow on travis to make sure those tests are not skipped.
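
As a rough illustration of what such a backend test can look like (a sketch only, not the actual test added in this PR):

import numpy as np
import pytest
import opt_einsum as oe

@pytest.mark.parametrize("backend", ["autograd", "jax"])
def test_contract_matches_numpy(backend):
    # skip cleanly if the optional backend is not installed
    pytest.importorskip(backend)
    x, y = np.random.rand(3, 4), np.random.rand(4, 5)
    expected = oe.contract("ij,jk->ik", x, y)                  # plain numpy
    result = oe.contract("ij,jk->ik", x, y, backend=backend)   # dispatched backend
    # jax defaults to float32, so compare with a loose tolerance
    np.testing.assert_allclose(np.asarray(result), expected, rtol=1e-4)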

Examples

Compute the gradient of an arbitrary tensor contraction:

import numpy as np
import opt_einsum as oe
from autograd import grad

eq = 'ij,jk,ki->'
shapes = [(2, 3), (3, 4), (4, 2)]


def foo(arrays):
    return oe.contract(eq, *arrays)

# dfoo computes the gradient of foo with respect to each input array
dfoo = grad(foo)

>>> arrays = [np.random.rand(*s) for s in shapes]
>>> foo(arrays)
array(2.97556246)

>>> dfoo(arrays)
[array([[1.03736352, 1.69989611, 0.99103317],
        [0.61444428, 0.50011247, 0.33026913]]),
 array([[0.31924389, 0.48004359, 1.00855222, 0.55368964],
        [0.3843193 , 0.57097258, 0.74804434, 0.74309046],
        [0.32953833, 0.4940919 , 0.94472643, 0.58736502]]),
 array([[0.96326285, 0.61466171],
        [0.56535028, 0.12002171],
        [0.87289287, 0.53918698],
        [0.74630816, 0.21253244]])]

Compile both the function and gradient using jax:

import jax

jit_foo = jax.jit(foo)
jit_dfoo = jax.jit(jax.grad(foo))

>>> jit_foo(arrays)
DeviceArray(2.9755626, dtype=float32)

>>> jit_dfoo(arrays)
[DeviceArray([[1.03736353, 1.6998961 , 0.9910332 ],
              [0.61444432, 0.50011247, 0.33026916]], dtype=float32),
 DeviceArray([[0.31924388, 0.48004356, 1.00855219, 0.5536896 ],
              [0.38431931, 0.57097256, 0.74804437, 0.74309039],
              [0.32953838, 0.49409193, 0.94472641, 0.58736503]],
             dtype=float32),
 DeviceArray([[0.96326286, 0.61466169],
              [0.56535023, 0.12002171],
              [0.87289286, 0.53918701],
              [0.74630815, 0.21253243]], dtype=float32)]

Or, regardless of gradients, just compile a contraction using jax (it still accepts and returns numpy arrays):

>>> expr = oe.contract_expression('ij,jk,kl->li', *shapes)
>>> expr(*arrays, backend='jax')
array([[2.4497495, 1.1122954],
       [0.9125154, 0.525813 ]], dtype=float32)
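
As an aside on the float32 output above: one way to keep double precision with jax (worth mentioning in the docs) is to enable its 64-bit mode before any other jax work, along these lines:

# jax casts everything to single precision by default; flipping its
# x64 flag (before any jax arrays are created) keeps float64 throughout
from jax.config import config
config.update("jax_enable_x64", True)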

Todos

  • Write some docs up (note somewhere that jax by default converts everything to single precision)

Status

  • Ready to go

Also update travis dist to xenial to fix glibc errors relating to tensorflow and jax imports
@codecov-io

Codecov Report

Merging #88 into master will decrease coverage by 0.16%.
The diff coverage is 88.23%.

@dgasmith (Owner) left a comment

LGTM! I see no issues here; any hindrance to merging?

@jcmgray (Collaborator, Author) commented Jul 1, 2019

I might just add some docs when I get the chance. Maybe a bullet point on the readme (since automatic differentiation is very nice for optimization, etc.) and a little section in the 'backends' bit?

@dgasmith (Owner) commented Jul 1, 2019

Sounds good! I am trying to clear some time to get back to this and push a 3.0 out.

@jcmgray (Collaborator, Author) commented Jul 3, 2019

OK, I think it's all good to go from my end if the docs look good. A 3.0 release sounds great!

@dgasmith (Owner) left a comment

A few minor comments, feel free to merge if you are ok with them.

This is really awesome to add to the project. Kind of amazing the underlying architecture can handle something like autograd pretty seamlessly.

array(3.71251519)

>>> # wrap foo with autograd to compute gradients instead
>>> dfoo = autograd.grad(foo)
@dgasmith (Owner), commenting on the quoted snippet above:

That's pretty cool that all of this works.

>>> # generate a compiled version of the gradient function
>>> jit_dfoo = jax.jit(jax.grad(foo))
>>> jit_dfoo([x, y, z])
[DeviceArray([[1.1137383 , 1.14972878, 0.64056885],
@dgasmith (Owner), commenting on the quoted snippet above:

It might be good to use the same random seed above and here to show that they produce the same thing?
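
For instance (illustrative only), both examples could draw their inputs from one seeded generator so the autograd and jax outputs line up:

import numpy as np

shapes = [(2, 3), (3, 4), (4, 2)]

# one seeded RandomState shared by both examples gives identical inputs,
# so the autograd and jax gradients can be compared directly
rng = np.random.RandomState(0)
arrays = [rng.rand(*s) for s in shapes]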

@jcmgray (Collaborator, Author) replied:

Ah good catch, I thought I had updated them with the same arrays but obviously forgot. Fixed now.

@dgasmith added this to the v3.0 milestone on Jul 3, 2019
@jcmgray (Collaborator, Author) commented Jul 3, 2019

There were some new test errors coming from pytest v5.0 that I've also just fixed.

@jcmgray merged commit 573295e into dgasmith:master on Jul 3, 2019
@mattjj commented Aug 20, 2019

Woo awesome!

> These both allow automatically computing gradients of functions with scalar output.

Actually both Autograd and JAX do a whole lot more than that! Jacobians, Hessians, whatever you like! See for example the JAX Autodiff Cookbook (part 1).
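
For instance (a rough sketch in the spirit of the examples above, not code from this PR; bar and baz are just illustrative names), jax's jacobian and hessian transforms compose with opt_einsum contractions in the same way as grad:

import jax
import numpy as np
import opt_einsum as oe

x = np.random.rand(2, 3)

def bar(x):
    # vector-valued output, so the Jacobian has shape (2, 2, 3)
    return oe.contract('ij,kj->i', x, x)

def baz(x):
    # scalar-valued output, so the Hessian has shape (2, 3, 2, 3)
    return oe.contract('ij,ij->', x, x)

jac = jax.jacobian(bar)(x)
hess = jax.hessian(baz)(x)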

Thanks for adding this. The future of numerical computing is bright :)

@dgasmith (Owner) commented

Thanks for the nice comments! Happy to take a PR that expands on what JAX can do so that we can better highlight the library.
