custom_vjp bug #31
Conversation
| """ | ||
| import jax | ||
| import jax.numpy as jnp | ||
| x = jnp.array(0.123) | ||
| @jax.custom_vjp | ||
| def jax_square(x): | ||
| return None | ||
| def f_fwd(x): | ||
| two_x = 2 * x | ||
| return x ** 2, (two_x,) | ||
| def f_bwd(saved, grad_output): | ||
| two_x, = saved | ||
| return (grad_output * two_x,) | ||
| jax_square.defvjp(f_fwd, f_bwd) | ||
| # This is 2.0 which is correct | ||
| ddx = jax.grad(jax.grad(jax_square))(x) | ||
| print(ddx) | ||
| """ |
This is the "reference". Note that JAX's f_fwd returns an intermediate, two_x, and uses it in the gradient computation. The second-order gradient comes out correct (2.0).
# Bug 0:
# This is wrong: The result is None, but it should be 2.
# Somehow the gradients aren't getting recorded.
ddx = run_gradgrad(outer, inner, x)
import pdb; pdb.set_trace()
Bug 0: ddx should equal 2, but it is actually None here. This behavior is in line with what autograd.Function does today, actually: autograd.Function requires the user to specify a gradient formula for the intermediate (two_x).
However, we're trying to explore what it would take to not require the user to specify the gradient formula, so let's not end this discussion at "this is expected".
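For context, here is a minimal sketch of what autograd.Function asks of the user today (plain PyTorch; Square and SquareBackward are illustrative names, not code from this PR): to get a correct second-order gradient, the backward itself has to be written as a differentiable Function, i.e. the gradient formula for the intermediate (two_x) is spelled out by hand.

import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 2

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # To be twice-differentiable, the backward must itself be a
        # differentiable Function: the user supplies the gradient formula
        # for the intermediate (2 * x) by hand.
        return SquareBackward.apply(x, grad_output)

class SquareBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, grad_output):
        ctx.save_for_backward(x, grad_output)
        return grad_output * 2 * x

    @staticmethod
    def backward(ctx, gg):
        x, grad_output = ctx.saved_tensors
        # Gradients of (grad_output * 2 * x) w.r.t. (x, grad_output).
        return gg * 2 * grad_output, gg * 2 * x

x = torch.tensor(0.123, requires_grad=True)
dx, = torch.autograd.grad(Square.apply(x), x, create_graph=True)
ddx, = torch.autograd.grad(dx, x)
print(ddx)  # tensor(2.)

The exploration here is about getting the same result without making the user write SquareBackward.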
# print(outer.gradient_tape)
# Bug 1: outer.grad still extends outer.gradient_tape, even though create_graph is False.
Bug 1, which is probably related to Bug 0.
There's a bug in simple functorch's custom_vjp. In particular, we would like to learn what it would take to get simple functorch's custom_vjp to work like JAX's custom_vjp w.r.t. its behavior toward intermediate Tensors. TODO: we should try to fix the bugs mentioned.
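As a rough sketch of the invariant Bug 1 points at (a standalone toy, not simple functorch's actual code; Tape, record, and backward below are hypothetical names), running the backward pass with create_graph=False should leave the tape unchanged:

from dataclasses import dataclass, field
from typing import Callable, List

@dataclass
class TapeEntry:
    inputs: List[str]
    outputs: List[str]
    propagate: Callable

@dataclass
class Tape:
    entries: List[TapeEntry] = field(default_factory=list)
    recording: bool = True

    def record(self, entry):
        # Ops only append to the tape while recording is on.
        if self.recording:
            self.entries.append(entry)

def backward(tape, grad, create_graph=False):
    # With create_graph=False, the propagate closures may still execute ops
    # that would normally record; recording is switched off so the tape
    # does not grow -- the property Bug 1 observes being broken by outer.grad.
    old = tape.recording
    tape.recording = create_graph
    try:
        for entry in reversed(tape.entries):
            grad = entry.propagate(grad)
    finally:
        tape.recording = old
    return grad

t = Tape()
t.record(TapeEntry(inputs=["x"], outputs=["y"], propagate=lambda g: 2 * g))
print(backward(t, 1.0), len(t.entries))  # 2.0 1 -- the tape did not grow

The open question for custom_vjp is making this hold while still capturing enough of the fwd's residual computation (e.g. 2 * x) for a second backward pass when create_graph=True.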
# Bug 2: inner.gradient_tape
# The second TapeEntry doesn't use x at all! In fact, no tapes
# capture the 2 * x behavior.
# [TapeEntry(inputs=['x'], outputs=['v199'], propagate=<function Autograd.custom_vjp.<locals>
#  .propagate at 0x7f8310658a60>), TapeEntry(inputs=['v200', 'v198'], outputs=['v201'], propagate
#  =<function Autograd.mul.<locals>.propagate at 0x7f83106a3dc0>)]
Bug 2.
There's a bug in simple functorch's custom_vjp. In particular, we would like to learn what it would take to get simple functorch's custom_vjp to work like JAX's custom_vjp w.r.t. its behavior toward intermediate Tensors. TODO: we should try to fix the bugs mentioned.
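To make Bug 2 concrete: in the JAX reference, the entire second derivative comes from differentiating the residual computation two_x = 2 * x inside f_fwd. A minimal sketch (plain JAX, no custom_vjp, written only for illustration) of the first-gradient computation shows that if no tape captures that multiplication, there is nothing left for the outer gradient pass to propagate through:

import jax
import jax.numpy as jnp

def first_grad(x):
    # What f_fwd + f_bwd compute for a cotangent of 1.0:
    two_x = 2 * x          # the "2 * x behavior" some tape must capture
    return 1.0 * two_x     # f_bwd: grad_output * two_x

x = jnp.array(0.123)
print(jax.grad(first_grad)(x))  # 2.0 -- the expected value of ddx

So a fix in the spirit of JAX's behavior would need the custom_vjp fwd's residual computation to be recorded on (or at least visible to) the tape that the second gradient pass consumes.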