custom_vjp bug #31
@@ -1277,3 +1277,72 @@ def run_gradvmap(d2: "Batched", d1: "Autograd"):

run_gradvmap(d2, d1)


def square_fwd(d, x):
    intermediate = d.mul(x, label(torch.tensor(2.), "two"))
    return d.mul(x, x), intermediate

def square_bwd(d, gradOutputs, intermediate):
    (gO,) = gradOutputs
    return [d.mul(gO, intermediate)]

def square_with_custom_vjp(d, x):
    result, _ = d.custom_vjp(square_fwd, square_bwd, x)
    return result

def simple_square(d, x):
    return d.mul(x, x)


inner = Autograd(Logger(Torch(), name="Torch"), name="Autograd1", create_graph=False)
outer = Autograd(inner, name="Autograd2", create_graph=False)
x = label(torch.rand([]), "x")

def run_gradgrad(outer, inner, x):
    y = square_with_custom_vjp(outer, x)
    import pdb; pdb.set_trace()
    dx, = outer.grad(y, [x])
    # print(outer.gradient_tape)
    # Bug 1: outer.grad still extends outer.gradient_tape, even though create_graph is False.
    ddx, = inner.grad(dx, [x])
    # Bug 2: inner.gradient_tape:
    # The second TapeEntry doesn't use x at all! In fact, no tapes
    # capture the 2 * x behavior.
    # [TapeEntry(inputs=['x'], outputs=['v199'], propagate=<function Autograd.custom_vjp.<locals>
    # .propagate at 0x7f8310658a60>), TapeEntry(inputs=['v200', 'v198'], outputs=['v201'], propagate
    # =<function Autograd.mul.<locals>.propagate at 0x7f83106a3dc0>)]
Comment on lines +1309 to +1314:
Bug 2
    return ddx

# Bug 0:
# This is wrong: The result is None, but it should be 2.
# Somehow the gradients aren't getting recorded.
ddx = run_gradgrad(outer, inner, x)
import pdb; pdb.set_trace()
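As a cross-check of the expected numbers, here is a minimal sketch using stock PyTorch autograd (no dialects, no custom_vjp; x_ref/y_ref are illustration-only names): once the 2 * x computation is recorded, grad-of-grad of x * x comes out as 2, which is the value ddx should have here.

import torch

x_ref = torch.rand([], requires_grad=True)
y_ref = x_ref * x_ref
# create_graph=True records the backward computation so it can be differentiated again
(dx_ref,) = torch.autograd.grad(y_ref, x_ref, create_graph=True)
(ddx_ref,) = torch.autograd.grad(dx_ref, x_ref)
print(dx_ref)   # numerically 2 * x_ref
print(ddx_ref)  # tensor(2.)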
Comment on lines +1317 to +1321:
Bug 0: ddx should equal two, but it is actually None here. This behavior is in line with what autograd.Function does today, actually. autograd.Function requires the user to specify a gradient formula for the intermediate (two_x). However, we're trying to explore what it would take to not require the user to specify the gradient formula, so let's not end this discussion at "this is expected".
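To make the comparison concrete, here is a minimal sketch of the analogous situation with today's torch.autograd.Function (the class name Square and the x_ref/y_ref variables are just for this illustration). Because the forward of a custom Function is not recorded by autograd, the saved intermediate two_x carries no history back to x, so the first-order gradient is numerically right but cannot be differentiated a second time without hand-writing that formula.

import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        two_x = 2 * x            # computed with autograd recording disabled
        ctx.save_for_backward(two_x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (two_x,) = ctx.saved_tensors
        return grad_output * two_x

x_ref = torch.rand([], requires_grad=True)
y_ref = Square.apply(x_ref)
(dx_ref,) = torch.autograd.grad(y_ref, x_ref, create_graph=True)
print(dx_ref)                # numerically 2 * x_ref, as expected
print(dx_ref.requires_grad)  # expected False: no graph connects dx_ref back to
                             # x_ref, so the second-order value 2 is unreachable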

# Here's the equivalent JAX code. Uncomment it to see it work.
"""
import jax
import jax.numpy as jnp

x = jnp.array(0.123)

@jax.custom_vjp
def jax_square(x):
    return None  # placeholder body; under jax.grad the primal output comes from f_fwd

def f_fwd(x):
    two_x = 2 * x
    return x ** 2, (two_x,)

def f_bwd(saved, grad_output):
    two_x, = saved
    return (grad_output * two_x,)

jax_square.defvjp(f_fwd, f_bwd)

# This is 2.0, which is correct.
ddx = jax.grad(jax.grad(jax_square))(x)
print(ddx)
"""
Comment on lines +1324 to +1347:
This is the "reference". Note that JAX's f_fwd returns an intermediate, two_x, and uses it in the gradient computation. Computing the second-order gradient is correct.
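A quick way to sanity-check the reference is the standalone sketch below. It re-declares the same jax_square as in the commented-out block, with a real primal body so it can also be called directly, and uses x_jax instead of x to avoid clashing with the torch tensor above.

import jax
import jax.numpy as jnp

@jax.custom_vjp
def jax_square(x):
    return x ** 2

def f_fwd(x):
    two_x = 2 * x
    return x ** 2, (two_x,)

def f_bwd(saved, grad_output):
    two_x, = saved
    return (grad_output * two_x,)

jax_square.defvjp(f_fwd, f_bwd)

x_jax = jnp.array(0.123)
print(jax.grad(jax_square)(x_jax))            # 0.246, i.e. 2 * x, produced by f_bwd
print(jax.grad(jax.grad(jax_square))(x_jax))  # 2.0, the second-order value this PR wants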
Bug 1, which is probably related
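For Bug 1, the semantics I'd expect from create_graph=False, by analogy with stock torch.autograd.grad, are that the gradient computation itself is not recorded for further differentiation, so outer.grad shouldn't be leaving new entries on outer.gradient_tape. A minimal sketch of the stock behavior (x_ref/y_ref are illustration-only names):

import torch

x_ref = torch.rand([], requires_grad=True)
y_ref = x_ref * x_ref
(dx_ref,) = torch.autograd.grad(y_ref, x_ref, create_graph=False)
print(dx_ref.requires_grad)  # False: the backward pass was not recorded, so it
                             # cannot be differentiated again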