
Commit

Throws an error when the optimizer's params are not the same as the module's in `make_private` (#439)

Summary:
Pull Request resolved: #439

Compare `nn.Module.parameters()` with the list of parameters gathered from all `param_groups` of the optimizer. If they are not all equal, raise the error "Module parameters are different than optimizer Parameters" (a standalone sketch of this check follows below).

Differential Revision: D37163873

fbshipit-source-id: 96cd2bc3ce9a136e450d5927b602b82f8072af9c
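
For illustration (not part of this commit), a minimal sketch of the comparison described in the summary. The `nn.Linear` model, SGD optimizer, and variable names are placeholders, assuming only a standard PyTorch install:

```python
import torch
from torch import nn

# Placeholder model and optimizer, only to show the comparison itself.
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Flatten every param_group into a single list, then compare element-wise
# against the module's parameters, mirroring the check added to make_private.
optimizer_params = sum(
    [group["params"] for group in optimizer.param_groups], []
)
params_match = all(
    torch.eq(p, q).all()
    for p, q in zip(list(model.parameters()), optimizer_params)
)
print(params_match)  # True here; replacing model layers afterwards would make it False
```

In the new test below, the mismatch is triggered by building the optimizer before `ModuleValidator.fix(model)` swaps out unsupported layers, so the optimizer's `param_groups` no longer reference the module's current parameters and `make_private` raises the error above.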
Deepak Agrawal authored and facebook-github-bot committed Jun 17, 2022
1 parent d079ffd commit 3ed0b5d
Showing 2 changed files with 53 additions and 2 deletions.
17 changes: 16 additions & 1 deletion opacus/privacy_engine.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import os
 import warnings
-from typing import IO, Any, BinaryIO, Dict, List, Optional, Tuple, Union
+from typing import Any, BinaryIO, Dict, IO, List, Optional, Tuple, Union
 
 import torch
 from opacus.accountants import create_accountant
@@ -360,6 +360,21 @@ def make_private(
         if noise_generator and self.secure_mode:
             raise ValueError("Passing seed is prohibited in secure mode")
 
+        # compare module parameter with optimizer parameters
+        if not all(
+            torch.eq(i, j).all()
+            for i, j in zip(
+                list(module.parameters()),
+                sum(
+                    [param_group["params"] for param_group in optimizer.param_groups],
+                    [],
+                ),
+            )
+        ):
+            raise ValueError(
+                "Module parameters are different than optimizer Parameters"
+            )
+
         distributed = isinstance(module, (DPDDP, DDP))
 
         module = self._prepare_model(
38 changes: 37 additions & 1 deletion opacus/tests/privacy_engine_test.py
@@ -29,8 +29,9 @@
 from opacus import PrivacyEngine
 from opacus.layers.dp_multihead_attention import DPMultiheadAttention
 from opacus.optimizers.optimizer import _generate_noise
-from opacus.scheduler import StepNoise, _NoiseScheduler
+from opacus.scheduler import _NoiseScheduler, StepNoise
 from opacus.utils.module_utils import are_state_dict_equal
+from opacus.validators import ModuleValidator
 from opacus.validators.errors import UnsupportedModuleError
 from torch.utils.data import DataLoader, Dataset, TensorDataset
 from torchvision import models, transforms
@@ -464,6 +465,41 @@ def test_deterministic_run(self):
             "Model parameters after deterministic run must match",
         )
 
+    def test_param_equal_module_optimizer(self):
+        """Test that the privacy engine raises error if nn.Module parameters are not equal to optimizer parameters"""
+        model = models.densenet121(pretrained=True)
+        num_ftrs = model.classifier.in_features
+        model.classifier = nn.Sequential(nn.Linear(num_ftrs, 10), nn.Sigmoid())
+        optimizer = torch.optim.SGD(
+            model.parameters(), lr=0.01, momentum=0, weight_decay=0
+        )
+        dl = self._init_data()
+        model = ModuleValidator.fix(model)
+        privacy_engine = PrivacyEngine()
+        with self.assertRaisesRegex(
+            ValueError, "Module parameters are different than optimizer Parameters"
+        ):
+            _, _, _ = privacy_engine.make_private(
+                module=model,
+                optimizer=optimizer,
+                data_loader=dl,
+                noise_multiplier=1.1,
+                max_grad_norm=1.0,
+            )
+
+        # if optimizer is defined after ModuleValidator.fix() then raise no error
+        optimizer = torch.optim.SGD(
+            model.parameters(), lr=0.01, momentum=0, weight_decay=0
+        )
+        _, _, _ = privacy_engine.make_private(
+            module=model,
+            optimizer=optimizer,
+            data_loader=dl,
+            noise_multiplier=1.1,
+            max_grad_norm=1.0,
+        )
+        self.assertTrue(1, 1)
+
     @given(noise_scheduler=st.sampled_from([None, StepNoise]))
     @settings(deadline=None)
     def test_checkpoints(self, noise_scheduler: Optional[_NoiseScheduler]):
