Skip to content

Commit

Permalink
Throws an error when params in optimizer are not the same as that of …
Browse files Browse the repository at this point in the history
…module's in `make_private` (pytorch#439)

Summary:
Pull Request resolved: pytorch#439

Compare nn.Module.parameters() with the list of parameters gathered from all param_groups of the optimizer. If they are not all equal, raise the error "Module parameters are different than optimizer Parameters".

Differential Revision: D37163873

fbshipit-source-id: 8e25fa1738f08c5aa52f856023f72948164d6f0e
  • Loading branch information
Deepak Agrawal authored and facebook-github-bot committed Jun 17, 2022
1 parent d079ffd commit bdaab58
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 0 deletions.
4 changes: 4 additions & 0 deletions opacus/privacy_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,10 @@ def make_private(
if noise_generator and self.secure_mode:
raise ValueError("Passing seed is prohibited in secure mode")

# compare module parameter with optimizer parameters
if not all(torch.eq(i,j).all() for i,j in zip(list(module.parameters()), sum([param_group['params'] for param_group in optimizer.param_groups], []))):
raise ValueError("Module parameters are different than optimizer Parameters")

distributed = isinstance(module, (DPDDP, DDP))

module = self._prepare_model(
Expand Down
33 changes: 33 additions & 0 deletions opacus/tests/privacy_engine_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import torch.nn.functional as F
from hypothesis import given, settings
from opacus import PrivacyEngine
from opacus.validators import ModuleValidator
from opacus.layers.dp_multihead_attention import DPMultiheadAttention
from opacus.optimizers.optimizer import _generate_noise
from opacus.scheduler import StepNoise, _NoiseScheduler
Expand Down Expand Up @@ -464,6 +465,38 @@ def test_deterministic_run(self):
"Model parameters after deterministic run must match",
)

def test_param_equal_module_optimizer(self):
    """Test that ``make_private`` validates module/optimizer parameter identity.

    ``ModuleValidator.fix`` replaces sub-modules (e.g. BatchNorm layers), so
    an optimizer built from the *original* model's parameters no longer
    matches the fixed model, and ``make_private`` must raise ``ValueError``.
    Rebuilding the optimizer from the fixed model must then succeed.
    """
    # Pretrained weights are irrelevant to this identity check; avoid the
    # network download that pretrained=True triggers in a unit test.
    model = models.densenet121(pretrained=False)
    num_ftrs = model.classifier.in_features
    model.classifier = nn.Sequential(nn.Linear(num_ftrs, 10), nn.Sigmoid())
    # Optimizer is deliberately created BEFORE ModuleValidator.fix(), so
    # its param_groups reference the soon-to-be-replaced parameters.
    optimizer = torch.optim.SGD(
        model.parameters(), lr=0.01, momentum=0, weight_decay=0
    )
    dl = self._init_data()
    model = ModuleValidator.fix(model)
    privacy_engine = PrivacyEngine()
    with self.assertRaisesRegex(
        ValueError, "Module parameters are different than optimizer Parameters"
    ):
        privacy_engine.make_private(
            module=model,
            optimizer=optimizer,
            data_loader=dl,
            noise_multiplier=1.1,
            max_grad_norm=1.0,
        )

    # An optimizer defined AFTER ModuleValidator.fix() sees the fixed
    # parameters, so make_private must not raise.
    optimizer = torch.optim.SGD(
        model.parameters(), lr=0.01, momentum=0, weight_decay=0
    )
    dp_model, dp_optimizer, dp_dl = privacy_engine.make_private(
        module=model,
        optimizer=optimizer,
        data_loader=dl,
        noise_multiplier=1.1,
        max_grad_norm=1.0,
    )
    # The original `self.assertTrue(1, 1)` was a no-op (the second argument
    # of assertTrue is the failure message); assert something real instead.
    self.assertIsNotNone(dp_model)
    self.assertIsNotNone(dp_optimizer)
    self.assertIsNotNone(dp_dl)




@given(noise_scheduler=st.sampled_from([None, StepNoise]))
@settings(deadline=None)
def test_checkpoints(self, noise_scheduler: Optional[_NoiseScheduler]):
Expand Down

0 comments on commit bdaab58

Please sign in to comment.