69 changes: 69 additions & 0 deletions simple_functorch.py
@@ -1277,3 +1277,72 @@ def run_gradvmap(d2: "Batched", d1: "Autograd"):


run_gradvmap(d2, d1)


def square_fwd(d, x):
    # Forward rule: compute x * x and also return the intermediate 2 * x,
    # which the backward rule needs.
    intermediate = d.mul(x, label(torch.tensor(2.), "two"))
    return d.mul(x, x), intermediate

def square_bwd(d, gradOutputs, intermediate):
    # Backward rule: the vjp of x * x is grad_output * (2 * x), using the
    # saved intermediate.
    (gO,) = gradOutputs
    return [d.mul(gO, intermediate)]

def square_with_custom_vjp(d, x):
    # Square routed through the dispatcher's custom_vjp, pairing square_fwd
    # with square_bwd; only the primal result is used here.
    result, _ = d.custom_vjp(square_fwd, square_bwd, x)
    return result

def simple_square(d, x):
    # Plain square with no custom vjp.
    return d.mul(x, x)
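
As a point of reference (an aside, not part of the diff): the rule above encodes d/dx (x * x) = 2 * x, so the second derivative should be 2. Here is a minimal sanity check in plain PyTorch, independent of the simple_functorch dispatchers; the underscore-prefixed names are local to this aside.

import torch

_x = torch.rand([], requires_grad=True)
(_dx,) = torch.autograd.grad(_x * _x, _x, create_graph=True)
(_ddx,) = torch.autograd.grad(_dx, _x)
print(_dx.item(), _ddx.item())  # 2 * x and 2.0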


inner = Autograd(Logger(Torch(), name="Torch"), name="Autograd1", create_graph=False)
outer = Autograd(inner, name="Autograd2", create_graph=False)
x = label(torch.rand([]), "x")

def run_gradgrad(outer, inner, x):
    y = square_with_custom_vjp(outer, x)
    import pdb; pdb.set_trace()
    dx, = outer.grad(y, [x])
    # print(outer.gradient_tape)
    # Bug 1: outer.grad still extends outer.gradient_tape, even though create_graph is False.
Comment on lines +1306 to +1307 (Collaborator Author):

Bug 1, which is probably related.

    ddx, = inner.grad(dx, [x])
    # Bug 2: inspect inner.gradient_tape.
    # The second TapeEntry doesn't use x at all! In fact, no tape entry
    # captures the 2 * x behavior.
    # [TapeEntry(inputs=['x'], outputs=['v199'], propagate=<function Autograd.custom_vjp.<locals>
    # .propagate at 0x7f8310658a60>), TapeEntry(inputs=['v200', 'v198'], outputs=['v201'], propagate
    # =<function Autograd.mul.<locals>.propagate at 0x7f83106a3dc0>)]
Comment on lines +1309 to +1314 (Collaborator Author):

Bug 2.

    return ddx

# Bug 0:
# This is wrong: The result is None, but it should be 2.
# Somehow the gradients aren't getting recorded.
ddx = run_gradgrad(outer, inner, x)
import pdb; pdb.set_trace()
Comment on lines +1317 to +1321 (@zou3519, Collaborator Author, May 2, 2022):

Bug 0: ddx should equal two, but it is actually None here. This behavior is in line with what autograd.Function does today, actually. autograd.Function requires the user to specify a gradient formula for the intermediate (two_x).

However, we're trying to explore what it would take to not require the user to specify the gradient formula, so let's not end this discussion at "this is expected".
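
To make the comparison concrete (an aside, not part of the diff), here is a hedged sketch of the torch.autograd.Function behavior referenced above. It assumes that ops run inside a custom Function's forward are not recorded by autograd, so an intermediate saved there is a constant during double backward; the Square class and underscore-prefixed names are just for illustration.

import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        two_x = 2 * x               # computed inside forward, so not tracked by autograd
        ctx.save_for_backward(two_x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (two_x,) = ctx.saved_tensors
        return grad_output * two_x  # two_x carries no history back to x

_x = torch.rand([], requires_grad=True)
_y = Square.apply(_x)
(_dx,) = torch.autograd.grad(_y, _x, create_graph=True)
# _dx has the right value (2 * x) but is disconnected from _x, so no
# second-order gradient can be taken through it, analogous to Bug 0 above.
print(_dx.requires_grad)  # False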


# Here's the equivalent JAX code. Uncomment it to see it work.
"""
import jax
import jax.numpy as jnp

x = jnp.array(0.123)

@jax.custom_vjp
def jax_square(x):
    return None

def f_fwd(x):
    two_x = 2 * x
    return x ** 2, (two_x,)

def f_bwd(saved, grad_output):
    two_x, = saved
    return (grad_output * two_x,)

jax_square.defvjp(f_fwd, f_bwd)

# This is 2.0 which is correct
ddx = jax.grad(jax.grad(jax_square))(x)
print(ddx)
"""
Comment on lines +1324 to +1347 (Collaborator Author):

This is the "reference". Note that JAX's f_fwd returns an intermediate, two_x, and uses it in the gradient computation. Computing the second-order gradient here gives the correct result.
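
One way to read why the JAX reference supports grad-of-grad (an aside with a hypothetical helper, not part of the diff): when the outer grad traces the inner grad, f_fwd and f_bwd are ordinary traced code, so the residual two_x keeps its dependence on x, and differentiating f_bwd's output again yields 2. A sketch of the inner gradient written out by hand:

import jax
import jax.numpy as jnp

def inner_grad_by_hand(x):
    # Roughly what the inner jax.grad evaluates for jax_square:
    _primal, (two_x,) = (x ** 2, (2 * x,))   # f_fwd
    (x_bar,) = (1.0 * two_x,)                # f_bwd with grad_output = 1.0
    return x_bar

# Differentiating the hand-written inner gradient recovers the second derivative.
print(jax.grad(inner_grad_by_hand)(jnp.array(0.123)))  # 2.0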