
Conversation


@zou3519 zou3519 commented May 2, 2022

There's a bug in simple functorch's custom_vjp. In particular, we would
like to learn what it would take to get simple functorch's custom_vjp to
work like JAX's custom_vjp with respect to how intermediate Tensors are
handled.

TODO: we should try to fix the bugs mentioned.

Comment on lines +1323 to +1347
"""
import jax
import jax.numpy as jnp
x = jnp.array(0.123)
@jax.custom_vjp
def jax_square(x):
return None
def f_fwd(x):
two_x = 2 * x
return x ** 2, (two_x,)
def f_bwd(saved, grad_output):
two_x, = saved
return (grad_output * two_x,)
jax_square.defvjp(f_fwd, f_bwd)
# This is 2.0 which is correct
ddx = jax.grad(jax.grad(jax_square))(x)
print(ddx)
"""
Collaborator Author

This is the "reference". Note that JAX's f_fwd returns an intermediate, two_x, and uses it in the gradient computation. Computing the second-order gradient gives the correct result (2.0).
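To spell out why 2.0 is correct: under jax.grad, jax_square's primal body (the return None above) is never run; f_fwd defines the forward pass and f_bwd the backward pass, and the second application of jax.grad differentiates through those two functions. A minimal sketch of that composition (my own illustration with an invented helper name, not code from this PR):

import jax
import jax.numpy as jnp

def first_grad(x):
    # What jax.grad(jax_square) effectively computes: run f_fwd to get the
    # residuals, then feed them to f_bwd with grad_output = 1.0.
    _, (two_x,) = (x ** 2, (2 * x,))   # f_fwd(x)
    return 1.0 * two_x                 # f_bwd((two_x,), 1.0)[0]

# Differentiating this composition gives d/dx (2 * x) = 2, matching the
# reference output above.
print(jax.grad(first_grad)(jnp.array(0.123)))  # 2.0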

Comment on lines +1316 to +1321
# Bug 0:
# This is wrong: The result is None, but it should be 2.
# Somehow the gradients aren't getting recorded.
ddx = run_gradgrad(outer, inner, x)
import pdb; pdb.set_trace()
Collaborator Author
@zou3519 zou3519 May 2, 2022

Bug 0: ddx should equal 2, but it is actually None here. This behavior is actually in line with what autograd.Function does today: autograd.Function requires the user to specify a gradient formula for the intermediate (two_x).

However, we're trying to explore what it would take to not require the user to specify the gradient formula, so let's not end this discussion at "this is expected".
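For reference, here is a sketch (my own illustration, not code from this PR) of the analogous situation with torch.autograd.Function today: two_x is computed inside forward, where autograd recording is disabled, so nothing links the first gradient back to x and a second grad has nothing to propagate through.

import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        two_x = 2 * x                  # intermediate; not recorded by autograd
        ctx.save_for_backward(two_x)
        return x ** 2

    @staticmethod
    def backward(ctx, grad_output):
        two_x, = ctx.saved_tensors
        return grad_output * two_x

x = torch.tensor(0.123, requires_grad=True)
dx, = torch.autograd.grad(Square.apply(x), x, create_graph=True)
print(dx)                # tensor(0.2460)
print(dx.requires_grad)  # False: the 2 * x computation was never recorded,
                         # so there is no graph for a second grad w.r.t. x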

Comment on lines +1306 to +1307
# print(outer.gradient_tape)
# Bug 1: outer.grad still extends outer.gradient_tape, even though create_graph is False.
Collaborator Author

Bug 1, which is probably related to Bug 0.
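To make Bug 1 concrete: calling outer.grad with create_graph=False should not leave the operations executed during the backward pass recorded on outer.gradient_tape, but it does. A hypothetical sketch of the kind of guard that seems to be missing (invented names, not the PR's actual code):

def run_backward(tape, backward_fn, create_graph=False):
    start = len(tape)
    grads = backward_fn()      # backward ops may append entries to `tape`
    if not create_graph:
        del tape[start:]       # drop them: the caller did not ask for a
    return grads               # differentiable backward

tape = []
grads = run_backward(tape, lambda: tape.append("mul") or "grads", create_graph=False)
print(grads, tape)             # grads []  -- nothing leaks onto the tape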

Comment on lines +1309 to +1314
# Bug 2: inner.gradient_tape
# The second TapeEntry doesn't use x at all! In fact, no tapes
# capture the 2 * x behavior.
# [TapeEntry(inputs=['x'], outputs=['v199'], propagate=<function Autograd.custom_vjp.<locals>
# .propagate at 0x7f8310658a60>), TapeEntry(inputs=['v200', 'v198'], outputs=['v201'], propagate
# =<function Autograd.mul.<locals>.propagate at 0x7f83106a3dc0>)]
Collaborator Author

Bug 2
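For the JAX-like behavior we want, the operations inside f_fwd (in particular 2 * x) would need to be recorded, so that the saved intermediate stays connected back to x. Roughly the tape shape one would expect, with invented variable names and a simplified TapeEntry for illustration only (not output from this PR):

from collections import namedtuple

TapeEntry = namedtuple("TapeEntry", ["inputs", "outputs"])
expected_inner_tape = [
    TapeEntry(inputs=["x"], outputs=["two_x"]),                       # the missing 2 * x
    TapeEntry(inputs=["x"], outputs=["x_squared"]),                   # custom_vjp forward
    TapeEntry(inputs=["two_x", "grad_output"], outputs=["grad_x"]),   # f_bwd's multiply
]
print(expected_inner_tape)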

@albanD albanD self-requested a review May 4, 2022 20:32
@samdow samdow mentioned this pull request Jun 14, 2022