Functorch gradients: investigation and fix #510

Closed · wants to merge 8 commits

2 changes: 1 addition & 1 deletion opacus/grad_sample/functorch.py
@@ -48,7 +48,7 @@ def ft_compute_per_sample_gradient(layer, activations, backprops):
activations: the input to the layer
backprops: the gradient of the loss w.r.t. outputs of the layer
"""
parameters = list(layer.parameters())
parameters = list(layer.parameters(recurse=True))
if not hasattr(layer, "ft_compute_sample_grad"):
prepare_layer(layer)

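For context on this one-line change: `nn.Module.parameters(recurse=True)` also yields parameters registered on nested submodules, so functorch differentiates with respect to a custom layer's full parameter set rather than only its directly registered tensors. `recurse=True` is already the default for `nn.Module.parameters`, so the change mainly makes that intent explicit. A minimal standalone sketch of the difference (illustrative, not part of the diff):

```python
import torch.nn as nn


class Wrapper(nn.Module):
    def __init__(self):
        super().__init__()
        # No parameters registered directly on Wrapper; they all live on the child.
        self.inner = nn.Linear(4, 2)


w = Wrapper()
print(len(list(w.parameters(recurse=False))))  # 0 - direct parameters only
print(len(list(w.parameters(recurse=True))))   # 2 - includes inner.weight and inner.bias
```
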
20 changes: 18 additions & 2 deletions opacus/grad_sample/grad_sample_module.py
@@ -18,14 +18,15 @@
import logging
import warnings
from functools import partial
from typing import List, Tuple
from typing import Iterable, List, Tuple

import torch
import torch.nn as nn
from opacus.grad_sample.functorch import ft_compute_per_sample_gradient, prepare_layer
from opacus.grad_sample.gsm_base import AbstractGradSampleModule
from opacus.layers.dp_rnn import DPGRU, DPLSTM, DPRNN, RNNLinear
from opacus.utils.module_utils import (
has_trainable_params,
requires_grad,
trainable_modules,
trainable_parameters,
@@ -146,6 +147,21 @@ def __init__(
def forward(self, *args, **kwargs):
return self._module(*args, **kwargs)

def iterate_submodules(self, module: nn.Module) -> Iterable[nn.Module]:
if has_trainable_params(module):
yield module

# Don't recurse if module is handled by functorch
if (
has_trainable_params(module)
and type(module) not in self.GRAD_SAMPLERS
and type(module) not in [DPRNN, DPLSTM, DPGRU]
):
return

for m in module.children():
yield from self.iterate_submodules(m)

def add_hooks(
self,
*,
Expand Down Expand Up @@ -177,7 +193,7 @@ def add_hooks(
self._module.autograd_grad_sample_hooks = []
self.autograd_grad_sample_hooks = self._module.autograd_grad_sample_hooks

for _module_name, module in trainable_modules(self._module):
for module in self.iterate_submodules(self._module):
# Do not add hooks to DPRNN, DPLSTM or DPGRU as the hooks are handled by the `RNNLinear`
if type(module) in [DPRNN, DPLSTM, DPGRU]:
continue
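To illustrate what the new traversal does differently from `trainable_modules`: a module that owns trainable parameters but has no registered grad sampler is yielded once and its children are skipped, so functorch can handle the whole subtree as a single unit instead of visiting nested parameters twice. A standalone sketch, with a plain set `SUPPORTED` standing in for the keys of `GRAD_SAMPLERS` (illustrative only, not the library code):

```python
import torch
import torch.nn as nn

SUPPORTED = {nn.Linear, nn.Conv1d}  # hypothetical stand-in for GRAD_SAMPLERS keys


def has_trainable_params(module: nn.Module) -> bool:
    return any(p.requires_grad for p in module.parameters(recurse=False))


def iterate_submodules(module: nn.Module):
    if has_trainable_params(module):
        yield module
    # Stop recursing once the module itself will be handled by functorch.
    if has_trainable_params(module) and type(module) not in SUPPORTED:
        return
    for child in module.children():
        yield from iterate_submodules(child)


class LinearWithExtraParam(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 4)
        self.extra_param = nn.Parameter(torch.randn(4, 2))


# Yields only the parent module; its inner nn.Linear is not visited separately.
print([type(m).__name__ for m in iterate_submodules(LinearWithExtraParam())])
```
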
85 changes: 71 additions & 14 deletions opacus/tests/privacy_engine_test.py
@@ -40,6 +40,18 @@
from torchvision import models, transforms
from torchvision.datasets import FakeData

from .utils import CustomLinearModule, LinearWithExtraParam


def _is_functorch_available():
try:
# flake8: noqa F401
import functorch

return True
except ImportError:
return False


def get_grad_sample_aggregated(tensor: torch.Tensor, loss_type: str = "mean"):
if tensor.grad_sample is None:
@@ -246,7 +258,7 @@ def _compare_to_vanilla(
# vanilla gradient is nearly zero: will match even with clipping
continue

atol = 1e-7 if max_steps == 1 else 1e-5
atol = 1e-7 if max_steps == 1 else 1e-4
self.assertEqual(
torch.allclose(vp, pp, atol=atol, rtol=1e-3),
expected_match,
@@ -265,10 +277,6 @@ def _compare_to_vanilla(
do_noise=st.booleans(),
use_closure=st.booleans(),
max_steps=st.sampled_from([1, 4]),
# do_clip=st.just(False),
# do_noise=st.just(False),
# use_closure=st.just(False),
# max_steps=st.sampled_from([4]),
)
@settings(deadline=None)
def test_compare_to_vanilla(
@@ -799,9 +807,7 @@ def _init_data(self):
)
return DataLoader(ds, batch_size=self.BATCH_SIZE, drop_last=False)

def _init_model(
self, private=False, state_dict=None, model=None, **privacy_engine_kwargs
):
def _init_model(self):
return SampleConvNet()


@@ -817,16 +823,21 @@ def _init_data(self):
)
return DataLoader(ds, batch_size=self.BATCH_SIZE, drop_last=False)

def _init_model(
self, private=False, state_dict=None, model=None, **privacy_engine_kwargs
):
def _init_model(self):
m = SampleConvNet()
for p in itertools.chain(m.conv1.parameters(), m.gnorm1.parameters()):
p.requires_grad = False

return m


@unittest.skipIf(not _is_functorch_available(), "not supported in this torch version")
class PrivacyEngineConvNetFrozenTestFunctorch(PrivacyEngineConvNetFrozenTest):
def setUp(self):
super().setUp()
self.GRAD_SAMPLE_MODE = "functorch"


@unittest.skipIf(
torch.__version__ < API_CUTOFF_VERSION, "not supported in this torch version"
)
@@ -840,6 +851,13 @@ def test_sample_grad_aggregation(self):
pass


@unittest.skipIf(not _is_functorch_available(), "not supported in this torch version")
class PrivacyEngineConvNetTestFunctorch(PrivacyEngineConvNetTest):
def setUp(self):
super().setUp()
self.GRAD_SAMPLE_MODE = "functorch"


class SampleAttnNet(nn.Module):
def __init__(self):
super().__init__()
@@ -919,6 +937,13 @@ def _init_model(
return SampleAttnNet()


@unittest.skipIf(not _is_functorch_available(), "not supported in this torch version")
class PrivacyEngineTextTestFunctorch(PrivacyEngineTextTest):
def setUp(self):
super().setUp()
self.GRAD_SAMPLE_MODE = "functorch"


class SampleTiedWeights(nn.Module):
def __init__(self, tie=True):
super().__init__()
@@ -958,7 +983,39 @@ def _init_data(self):
)
return DataLoader(ds, batch_size=self.BATCH_SIZE, drop_last=False)

def _init_model(
self, private=False, state_dict=None, model=None, **privacy_engine_kwargs
):
def _init_model(self):
return SampleTiedWeights(tie=True)


@unittest.skipIf(not _is_functorch_available(), "not supported in this torch version")
class PrivacyEngineTiedWeightsTestFunctorch(PrivacyEngineTiedWeightsTest):
def setUp(self):
super().setUp()
self.GRAD_SAMPLE_MODE = "functorch"


class ModelWithCustomLinear(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = CustomLinearModule(4, 8)
self.fc2 = LinearWithExtraParam(8, 4)
self.extra_param = nn.Parameter(torch.randn(4, 4))

def forward(self, x):
x = self.fc1(x)
x = self.fc2(x)
x = x.matmul(self.extra_param)
return x


@unittest.skipIf(not _is_functorch_available(), "not supported in this torch version")
class PrivacyEngineCustomLayerTest(BasePrivacyEngineTest, unittest.TestCase):
def _init_data(self):
ds = TensorDataset(
torch.randn(self.DATA_SIZE, 4),
torch.randint(low=0, high=3, size=(self.DATA_SIZE,)),
)
return DataLoader(ds, batch_size=self.BATCH_SIZE, drop_last=False)

def _init_model(self):
return ModelWithCustomLinear()
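
For readers who want to try the custom-layer path end to end, a rough sketch of private training with the functorch grad-sample mode, using the `ModelWithCustomLinear` defined above (illustrative only; the hyperparameters are arbitrary, and the `grad_sample_mode` keyword is assumed to be the one accepted by `PrivacyEngine.make_private` in recent Opacus releases):

```python
import torch
from opacus import PrivacyEngine
from torch.utils.data import DataLoader, TensorDataset

model = ModelWithCustomLinear()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data_loader = DataLoader(
    TensorDataset(torch.randn(64, 4), torch.randint(low=0, high=3, size=(64,))),
    batch_size=8,
)

privacy_engine = PrivacyEngine()
model, optimizer, data_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=data_loader,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
    grad_sample_mode="functorch",  # assumed keyword; routes per-sample grads through functorch
)
# model now computes per-sample gradients via functorch during backward().
```
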
54 changes: 6 additions & 48 deletions opacus/tests/privacy_engine_validation_test.py
@@ -1,58 +1,16 @@
import unittest

import torch
import torch.nn as nn
import torch.nn.functional as F
from opacus import PrivacyEngine
from opacus.grad_sample.gsm_exp_weights import API_CUTOFF_VERSION
from torch.utils.data import DataLoader


class BasicSupportedModule(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv1d(in_channels=16, out_channels=8, kernel_size=2)
self.gn = nn.GroupNorm(num_groups=2, num_channels=8)
self.fc = nn.Linear(in_features=4, out_features=8)
self.ln = nn.LayerNorm([8, 8])

def forward(self, x):
x = self.conv(x)
x = self.gn(x)
x = self.fc(x)
x = self.ln(x)
return x


class CustomLinearModule(nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
self._weight = nn.Parameter(torch.randn(out_features, in_features))
self._bias = nn.Parameter(torch.randn(out_features))

def forward(self, x):
return F.linear(x, self._weight, self._bias)


class MatmulModule(nn.Module):
def __init__(self, input_features, output_features):
super().__init__()
self.weight = nn.Parameter(torch.randn(input_features, output_features))

def forward(self, x):
return torch.matmul(x, self.weight)


class LinearWithExtraParam(nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
self.fc = nn.Linear(in_features, out_features)
self.extra_param = nn.Parameter(torch.randn(out_features, 2))

def forward(self, x):
x = self.fc(x)
x = x.matmul(self.extra_param)
return x
from .utils import (
BasicSupportedModule,
CustomLinearModule,
LinearWithExtraParam,
MatmulModule,
)


class PrivacyEngineValidationTest(unittest.TestCase):
50 changes: 50 additions & 0 deletions opacus/tests/utils.py
@@ -0,0 +1,50 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicSupportedModule(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv1d(in_channels=16, out_channels=8, kernel_size=2)
self.gn = nn.GroupNorm(num_groups=2, num_channels=8)
self.fc = nn.Linear(in_features=4, out_features=8)
self.ln = nn.LayerNorm([8, 8])

def forward(self, x):
x = self.conv(x)
x = self.gn(x)
x = self.fc(x)
x = self.ln(x)
return x


class CustomLinearModule(nn.Module):
def __init__(self, in_features: int, out_features: int):
super().__init__()
self._weight = nn.Parameter(torch.randn(out_features, in_features))
self._bias = nn.Parameter(torch.randn(out_features))

def forward(self, x):
return F.linear(x, self._weight, self._bias)


class MatmulModule(nn.Module):
def __init__(self, input_features: int, output_features: int):
super().__init__()
self.weight = nn.Parameter(torch.randn(input_features, output_features))

def forward(self, x):
return torch.matmul(x, self.weight)


class LinearWithExtraParam(nn.Module):
def __init__(self, in_features: int, out_features: int, hidden_dim: int = 8):
super().__init__()
self.fc = nn.Linear(in_features, hidden_dim)
self.extra_param = nn.Parameter(torch.randn(hidden_dim, out_features))

def forward(self, x):
x = self.fc(x)
x = x.matmul(self.extra_param)
return x
6 changes: 5 additions & 1 deletion opacus/utils/module_utils.py
@@ -31,7 +31,11 @@
logger.setLevel(level=logging.INFO)


def parametrized_modules(module: nn.Module) -> Iterable[nn.Module]:
def has_trainable_params(module: nn.Module) -> bool:
return any(p.requires_grad for p in module.parameters(recurse=False))


def parametrized_modules(module: nn.Module) -> Iterable[Tuple[str, nn.Module]]:
"""
Recursively iterates over all submodules, returning those that
have parameters (as opposed to "wrapper modules" that just organize modules).
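
Finally, a quick illustration of the new non-recursive check (illustrative, not part of the diff): `has_trainable_params` only inspects a module's directly registered parameters, so pure container modules report `False` and the traversal above keeps recursing into them.

```python
import torch.nn as nn


def has_trainable_params(module: nn.Module) -> bool:
    return any(p.requires_grad for p in module.parameters(recurse=False))


container = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
print(has_trainable_params(container))     # False - no parameters registered directly
print(has_trainable_params(container[0]))  # True  - the Linear owns weight and bias
```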