functorch mode computes the forward twice #521
My understanding is that, for arbitrary user modules, we don't know what the backward pass is, so we cannot vmap over it directly. functorch doesn't know what the backward pass is until the forward pass has been executed through functorch, so it needs to recompute the forward pass. I'm wondering if there's a way to avoid computing the forward pass twice (once in the original execution of the module and once in the grad_sample computation).
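To make the double forward concrete, here is a minimal sketch assuming PyTorch 2.x, where functorch's `vmap`/`grad` API is available as `torch.func`. `CountingLinear` and `loss_one_sample` are hypothetical names used only for illustration; the counter shows the layer's forward executing once in the regular pass and once more inside `vmap(grad(...))`.

```python
import torch
from torch import nn
from torch.func import functional_call, grad, vmap

torch.manual_seed(0)
calls = {"forward": 0}

class CountingLinear(nn.Linear):
    """nn.Linear that counts how many times its forward() runs (hypothetical helper)."""

    def forward(self, x):
        calls["forward"] += 1
        return super().forward(x)

layer = CountingLinear(3, 2)
x = torch.randn(4, 3)

# 1) The regular forward pass of the training step.
out = layer(x)

# 2) Per-sample gradients via vmap(grad(...)): the layer's forward runs again.
params = {k: v.detach() for k, v in layer.named_parameters()}

def loss_one_sample(params, xi):
    # xi has shape (3,) inside vmap; add a batch dim of 1 for the layer.
    yi = functional_call(layer, params, (xi.unsqueeze(0),))
    return yi.sum()

per_sample = vmap(grad(loss_one_sample), in_dims=(None, 0))(params, x)
print(calls["forward"])  # prints 2
```

`vmap` calls the Python function once with batched tensors rather than once per sample, so the count is 2, not 1 + N. The extra forward is per layer call, not per sample.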
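I believe something like the following works, though functorch does not officially support it (yet). The idea is to call `vjp` inside `vmap` during the forward pass and stash the resulting `vjp_fn`, then, in the backward pass, call the stashed `vjp_fn` under `vmap` on the incoming gradient instead of recomputing the forward: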
```python
import torch
import torch.nn.functional as F
from functorch import vmap, vjp, grad
import functools

torch.manual_seed(0)
x = torch.randn(2, 3)
w = torch.randn(3, 3)
fn = torch.matmul

# Stash each vjp_fn so the backward pass can reuse it instead of re-running the forward.
backdoor = []

def sample_grad_call(fn, x, w):
    def inner(x, w):
        res, vjp_fn = vjp(functools.partial(fn, x), w)
        backdoor.append(vjp_fn)
        return res
    res = vmap(inner, (0, None))(x, w)
    return res

def compute_grad_sample(grad_out):
    def inner(grad_out, dummy):
        return backdoor[-1](grad_out)
    grad_sample = vmap(inner)(grad_out, x)
    return grad_sample

# Somehow replace the module forward pass with the following
y = sample_grad_call(fn, x, w)
grad_y = torch.ones_like(y)

# And then replace the module backward pass with the following
result, = compute_grad_sample(grad_y)

# Here's a correctness check
w.requires_grad_()
expected0, = torch.autograd.grad(fn(x[0], w).sum(), w)
expected1, = torch.autograd.grad(fn(x[1], w).sum(), w)
expected = torch.stack([expected0, expected1])
assert torch.allclose(result, expected)
```

On the functorch side, we'll try to hack up a POC integrating this approach with Opacus.
Thanks @zou3519 for jumping in. I second this:
For linear modules, we use einsums (either directly or through ExpandedWeights). Note that it is also possible to make the entire model functional and pass the grad samples to Opacus (this is the "no_op" GradSampleModule).
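For `nn.Linear`, per-sample gradients follow in closed form from the saved layer inputs and the gradients with respect to the layer output, which is why an einsum needs no extra forward pass. A minimal sketch assuming 2-D inputs of shape (N, in_features); the variable names are hypothetical, and Opacus's own grad sampler may handle more cases (e.g. extra sequence dimensions):

```python
import torch
from torch import nn

torch.manual_seed(0)
layer = nn.Linear(3, 2)
activations = torch.randn(4, 3)  # layer inputs saved during the forward pass, (N, in)
backprops = torch.randn(4, 2)    # dL/d(output) for each sample, (N, out)

# Per-sample weight grad is the outer product of backprop and activation: (N, out, in).
grad_weight = torch.einsum("ni,nj->nij", backprops, activations)
# Per-sample bias grad is just the backprop: (N, out).
grad_bias = backprops

# Reference check: one autograd call per sample, using backprops as upstream gradients.
for n in range(activations.shape[0]):
    out = layer(activations[n : n + 1])
    gw, gb = torch.autograd.grad(
        out, (layer.weight, layer.bias), grad_outputs=backprops[n : n + 1]
    )
    assert torch.allclose(grad_weight[n], gw, atol=1e-6)
    assert torch.allclose(grad_bias[n], gb, atol=1e-6)
```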
In `functorch` mode, Opacus saves the activations from previous layers and uses them to compute per-sample gradients with functorch. However, `functorch.grad` ends up running a forward pass through the layer followed by a backward pass, which hurts performance in `functorch` mode. The code that uses `functorch.grad` is at lines 36 to 40 and line 55 of opacus/opacus/grad_sample/functorch.py (commit a6c2567).
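The general pattern can be sketched as below, assuming PyTorch 2.x's `torch.func`; `surrogate_loss` is a hypothetical name, and this illustrates the technique rather than reproducing Opacus's functorch.py. Because the gradient of `sum(out * backprop)` with respect to the parameters equals that sample's parameter gradient, `vmap(grad(...))` recovers the right values, but only by re-running the layer's forward and then its backward:

```python
import torch
from torch import nn
from torch.func import functional_call, grad, vmap

torch.manual_seed(0)
layer = nn.Linear(3, 2)
params = {k: v.detach() for k, v in layer.named_parameters()}
activations = torch.randn(4, 3)  # saved layer inputs, (N, in)
backprops = torch.randn(4, 2)    # saved grads w.r.t. the layer output, (N, out)

def surrogate_loss(params, a, b):
    # Re-runs the layer's forward on a single sample (the extra cost)...
    out = functional_call(layer, params, (a.unsqueeze(0),))
    # ...so that d(sum(out * b))/d(params) is that sample's parameter gradient.
    return (out * b.unsqueeze(0)).sum()

per_sample = vmap(grad(surrogate_loss), in_dims=(None, 0, 0))(
    params, activations, backprops
)
```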
We can apply the following patch (for Linear) to improve performance in `functorch` mode. The idea is similar to what the hooks approach does.
Before Patch
After Patch
Benchmark Script
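One way to picture a Linear-only fast path is a dispatch that uses the closed-form einsum when the layer is `nn.Linear` and falls back to `vmap(grad(...))` otherwise. This is a hypothetical helper (`per_sample_grads` and its `use_fast_path` flag are invented for illustration, assuming PyTorch 2.x's `torch.func`), not the actual patch from this issue:

```python
import torch
from torch import nn
from torch.func import functional_call, grad, vmap

def per_sample_grads(layer, activations, backprops, use_fast_path=True):
    """Hypothetical helper: per-sample parameter grads from saved layer I/O.

    activations: saved layer inputs, shape (N, ...)
    backprops:   grads w.r.t. the layer output, shape (N, ...)
    """
    if use_fast_path and isinstance(layer, nn.Linear) and activations.dim() == 2:
        # Linear fast path: closed-form einsum, no extra forward or backward.
        grads = {"weight": torch.einsum("ni,nj->nij", backprops, activations)}
        if layer.bias is not None:
            grads["bias"] = backprops.clone()
        return grads

    # Generic fallback: re-run the layer's forward and backward per sample.
    # Assumes the layer accepts a batch of one sample as input.
    params = {k: v.detach() for k, v in layer.named_parameters()}

    def surrogate_loss(p, a, b):
        out = functional_call(layer, p, (a.unsqueeze(0),))
        return (out * b.unsqueeze(0)).sum()

    return vmap(grad(surrogate_loss), in_dims=(None, 0, 0))(params, activations, backprops)

# Usage: both paths should agree for nn.Linear.
torch.manual_seed(0)
layer = nn.Linear(3, 2)
a, b = torch.randn(5, 3), torch.randn(5, 2)
fast = per_sample_grads(layer, a, b)
slow = per_sample_grads(layer, a, b, use_fast_path=False)
```

The fast path only needs tensors that are already saved, so it skips the re-executed forward entirely; the fallback keeps arbitrary layers working at the cost described above.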
cc: @zou3519