
Commit 3f3cc88

Fix for floating point representation attack (#260)
* Fix for floating point representation attack
* Fix for floating point representation attack
* Revert "Fix for floating point representation attack" (this reverts commit acb7854)
* Fix small bug
* fix black lint formatting
* Address comments regarding explanation
* Add new test to ensure generate_noise is correct - fix test_noise_level
* fix lint issues
* Update privacy_engine_test.py
* Fix isort lint issue
* Update privacy_engine_test.py
1 parent 461b72f commit 3f3cc88
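
For context, the core of the change is the sampling scheme from section 5.1 of https://arxiv.org/abs/2107.10138: instead of exposing a single Gaussian draw, whose exact floating point representation is what the attack exploits, the secure path draws and discards one sample and then returns the sum of 2n independent draws divided by n, with n=2. Below is a minimal standalone sketch of that scheme; it is illustrative only, the helper name is made up, and the actual library code appears in opacus/optimizers/optimizer.py further down.

import torch


def secure_gaussian(std: float, size, n: int = 2) -> torch.Tensor:
    """Sum 2n independent Gaussian draws and divide by n (n=2 by default).

    Per coordinate the result has variance 2n * std^2 / n^2, which equals
    std^2 for n=2, so the noise distribution is unchanged while no single
    torch.normal output is returned directly.
    """
    # As in the commit, draw one sample first and throw it away
    torch.normal(mean=0.0, std=std, size=(1, 1))
    total = torch.zeros(size)
    for _ in range(2 * n):
        total += torch.normal(mean=0.0, std=std, size=size)
    return total / n


noise = secure_gaussian(std=1.0, size=(3,))
print(noise)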

File tree

7 files changed (+141, -38 lines)


opacus/optimizers/ddp_perlayeroptimizer.py

Lines changed: 6 additions & 1 deletion
@@ -32,6 +32,7 @@ def __init__(
         expected_batch_size: Optional[int],
         loss_reduction: str = "mean",
         generator=None,
+        secure_mode=False,
     ):
         self.rank = torch.distributed.get_rank()
         self.max_grad_norms = max_grad_norms
@@ -43,6 +44,7 @@ def __init__(
             expected_batch_size=expected_batch_size,
             loss_reduction=loss_reduction,
             generator=generator,
+            secure_mode=secure_mode,
         )
         self.register_hooks()

@@ -51,7 +53,10 @@ def _add_noise_parameter(self, p):
         The reason why we need self is because of generator for secure_mode
         """
         noise = _generate_noise(
-            self.noise_multiplier * self.max_grad_norm, p.summed_grad
+            std=self.noise_multiplier * self.max_grad_norm,
+            reference=p.summed_grad,
+            generator=None,
+            secure_mode=self.secure_mode,
         )
         p.grad = p.summed_grad + noise

opacus/optimizers/ddpoptimizer.py

Lines changed: 2 additions & 0 deletions
@@ -18,6 +18,7 @@ def __init__(
         expected_batch_size: Optional[int],
         loss_reduction: str = "mean",
         generator=None,
+        secure_mode=False,
     ):
         super().__init__(
             optimizer,
@@ -26,6 +27,7 @@ def __init__(
             expected_batch_size=expected_batch_size,
             loss_reduction=loss_reduction,
             generator=generator,
+            secure_mode=secure_mode,
         )
         self.rank = torch.distributed.get_rank()
         self.world_size = torch.distributed.get_world_size()

opacus/optimizers/optimizer.py

Lines changed: 57 additions & 5 deletions
@@ -9,19 +9,68 @@


 def _generate_noise(
-    std: float, reference: torch.Tensor, generator=None
+    std: float,
+    reference: torch.Tensor,
+    generator=None,
+    secure_mode: bool = False,
 ) -> torch.Tensor:
-    if std > 0:
-        # TODO: handle device transfers: generator and reference tensor
-        # could be on different devices
+    """
+    Generates noise according to a Gaussian distribution with mean 0.
+
+    Args:
+        std: Standard deviation of the noise
+        reference: The reference Tensor to get the appropriate shape and device
+            for generating the noise
+        generator: The PyTorch noise generator
+        secure_mode: boolean indicating whether "secure" noise needs to be
+            generated (see the notes)
+
+    Notes:
+        If `secure_mode` is enabled, the generated noise is also secure
+        against floating point representation attacks, such as the ones
+        in https://arxiv.org/abs/2107.10138. This is achieved by calling
+        the Gaussian noise function 2*n times, with n=2 (see section 5.1 in
+        https://arxiv.org/abs/2107.10138).
+
+        Reason for choosing n=2: n can be any number > 1. The bigger, the more
+        computation needs to be done (`2n` Gaussian samples will be generated).
+        The reason we chose `n=2` is that `n=1` could be easy to break and `n>2`
+        is not really necessary. The complexity of the attack is `2^(p*(2n-1))`.
+        In PyTorch, `p=53`, so the complexity is `2^(53*(2n-1))`. With `n=1` we
+        get `2^53` (easy to break) but with `n=2` we get `2^159`, which is hard
+        enough for an attacker to break.
+    """
+    zeros = torch.zeros(reference.shape, device=reference.device)
+    if std == 0:
+        return zeros
+    # TODO: handle device transfers: generator and reference tensor
+    # could be on different devices
+    if secure_mode:
+        torch.normal(
+            mean=0,
+            std=std,
+            size=(1, 1),
+            device=reference.device,
+            generator=generator,
+        )  # generate, but throw away the first generated Gaussian sample
+        sum = zeros
+        for i in range(4):
+            sum += torch.normal(
+                mean=0,
+                std=std,
+                size=reference.shape,
+                device=reference.device,
+                generator=generator,
+            )
+        return sum / 2
+    else:
         return torch.normal(
             mean=0,
             std=std,
             size=reference.shape,
             device=reference.device,
             generator=generator,
         )
-    return torch.zeros(reference.shape, device=reference.device)


 def _get_flat_grad_sample(p: torch.Tensor):
@@ -47,6 +96,7 @@ def __init__(
         expected_batch_size: Optional[int],
         loss_reduction: str = "mean",
         generator=None,
+        secure_mode=False,
     ):
         if loss_reduction not in ("mean", "sum"):
             raise ValueError(f"Unexpected value for loss_reduction: {loss_reduction}")
@@ -63,6 +113,7 @@ def __init__(
         self.expected_batch_size = expected_batch_size
         self.step_hook = None
         self.generator = generator
+        self.secure_mode = secure_mode

         self.param_groups = optimizer.param_groups
         self.state = optimizer.state
@@ -137,6 +188,7 @@ def add_noise(self):
                 std=self.noise_multiplier * self.max_grad_norm,
                 reference=p.summed_grad,
                 generator=self.generator,
+                secure_mode=self.secure_mode,
             )
             p.grad = p.summed_grad + noise
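
As a quick, illustrative sanity check of the new signature (not part of the commit): the secure path sums four draws of N(0, std^2) and divides by 2, which is again N(0, std^2) per coordinate, so both modes should produce noise with the same empirical standard deviation.

import torch

from opacus.optimizers.optimizer import _generate_noise

reference = torch.zeros(1_000_000)  # large reference so the empirical std is stable

plain = _generate_noise(std=1.0, reference=reference, secure_mode=False)
secure = _generate_noise(std=1.0, reference=reference, secure_mode=True)

# Both should print a value close to 1.0
print(plain.std().item(), secure.std().item())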

opacus/optimizers/perlayeroptimizer.py

Lines changed: 2 additions & 0 deletions
@@ -19,6 +19,7 @@ def __init__(
         expected_batch_size: Optional[int],
         loss_reduction: str = "mean",
         generator=None,
+        secure_mode=False,
     ):
         assert len(max_grad_norm) == len(params(optimizer))
         self.max_grad_norms = max_grad_norm
@@ -30,6 +31,7 @@ def __init__(
             expected_batch_size=expected_batch_size,
             loss_reduction=loss_reduction,
             generator=generator,
+            secure_mode=secure_mode,
         )

     def attach(self, optimizer):

opacus/privacy_engine.py

Lines changed: 2 additions & 1 deletion
@@ -106,6 +106,7 @@ def _prepare_optimizer(
             expected_batch_size=expected_batch_size,
             loss_reduction=loss_reduction,
             generator=generator,
+            secure_mode=self.secure_mode,
         )

     def _prepare_data_loader(
@@ -266,4 +267,4 @@ def make_private_with_epsilon(
         )

     def get_epsilon(self, delta):
-        return self.accountant.get_epsilon(delta)
+        return self.accountant.get_epsilon(delta)
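
A hedged usage sketch of the end-to-end wiring: `_prepare_optimizer` above passes `secure_mode=self.secure_mode` down to the DP optimizer, which suggests the flag is set on the engine itself. Assuming `secure_mode` is a `PrivacyEngine` constructor argument (that part is not shown in this diff), enabling it would look roughly like the following; the model, optimizer, and data loader are placeholders.

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

from opacus import PrivacyEngine

model = nn.Linear(16, 2)
optimizer = optim.SGD(model.parameters(), lr=0.05)
data_loader = DataLoader(
    TensorDataset(torch.randn(64, 16), torch.randint(0, 2, (64,))),
    batch_size=8,
)

# Assumption: secure_mode is accepted by the engine and threaded down to the
# optimizer via _prepare_optimizer(..., secure_mode=self.secure_mode).
privacy_engine = PrivacyEngine(secure_mode=True)
model, optimizer, data_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=data_loader,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
)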

opacus/tests/privacy_engine_test.py

Lines changed: 72 additions & 31 deletions
@@ -4,6 +4,7 @@
 import unittest
 from abc import ABC
 from typing import Optional, OrderedDict
+from unittest.mock import MagicMock, patch

 import hypothesis.strategies as st
 import torch
@@ -12,6 +13,7 @@
 from hypothesis import given, settings
 from opacus import PrivacyEngine
 from opacus.layers.dp_multihead_attention import DPMultiheadAttention
+from opacus.optimizers.optimizer import _generate_noise
 from opacus.utils.module_utils import are_state_dict_equal
 from opacus.validators.errors import UnsupportedModuleError
 from torch.utils.data import DataLoader, Dataset
@@ -395,42 +397,81 @@ def test_noise_level(self, noise_multiplier: float, max_steps: int):
         """
         Tests that the noise level is correctly set
         """
-        # Initialize models with parameters to zero
-        model, optimizer, dl, _ = self._init_private_training(
-            noise_multiplier=noise_multiplier
-        )
-        for p in model.parameters():
-            p.data.zero_()

-        # Do max_steps steps of DP-SGD
-        n_params = sum([p.numel() for p in model.parameters() if p.requires_grad])
-        steps = 0
-        for x, y in dl:
-            optimizer.zero_grad()
-            logits = model(x)
-            loss = logits.view(logits.size(0), -1).sum(dim=1)
-            # Gradient should be 0
-            loss.backward(torch.zeros(logits.size(0)))
+        def helper_test_noise_level(
+            noise_multiplier: float, max_steps: int, secure_mode: bool
+        ):
+            torch.manual_seed(100)
+            # Initialize models with parameters to zero
+            model, optimizer, dl, _ = self._init_private_training(
+                noise_multiplier=noise_multiplier,
+                secure_mode=secure_mode,
+            )
+            for p in model.parameters():
+                p.data.zero_()

-            optimizer.step()
-            steps += 1
+            # Do max_steps steps of DP-SGD
+            n_params = sum([p.numel() for p in model.parameters() if p.requires_grad])
+            steps = 0
+            for x, y in dl:
+                optimizer.zero_grad()
+                logits = model(x)
+                loss = logits.view(logits.size(0), -1).sum(dim=1)
+                # Gradient should be 0
+                loss.backward(torch.zeros(logits.size(0)))

-            if max_steps and steps >= max_steps:
-                break
+                optimizer.step()
+                steps += 1
+
+                if max_steps and steps >= max_steps:
+                    break
+
+            # Noise should be equal to lr*sigma*sqrt(n_params * steps) / batch_size
+            expected_norm = (
+                steps
+                * n_params
+                * optimizer.noise_multiplier ** 2
+                * self.LR ** 2
+                / (optimizer.expected_batch_size ** 2)
+            )
+            real_norm = sum(
+                [torch.sum(torch.pow(p.data, 2)) for p in model.parameters()]
+            ).item()

-        # Noise should be equal to lr*sigma*sqrt(n_params * steps) / batch_size
-        expected_norm = (
-            steps
-            * n_params
-            * optimizer.noise_multiplier ** 2
-            * self.LR ** 2
-            / (optimizer.expected_batch_size ** 2)
-        )
-        real_norm = sum(
-            [torch.sum(torch.pow(p.data, 2)) for p in model.parameters()]
-        ).item()
+            self.assertAlmostEqual(real_norm, expected_norm, delta=0.15 * expected_norm)
+
+        with self.subTest(secure_mode=False):
+            helper_test_noise_level(
+                noise_multiplier=noise_multiplier,
+                max_steps=max_steps,
+                secure_mode=False,
+            )
+        with self.subTest(secure_mode=True):
+            helper_test_noise_level(
+                noise_multiplier=noise_multiplier,
+                max_steps=max_steps,
+                secure_mode=True,
+            )

-        self.assertAlmostEqual(real_norm, expected_norm, delta=0.1 * expected_norm)
+    @patch("torch.normal", MagicMock(return_value=torch.Tensor([0.6])))
+    def test_generate_noise_in_secure_mode(self):
+        """
+        Tests that the noise is added correctly in secure_mode,
+        according to section 5.1 in https://arxiv.org/abs/2107.10138.
+        Since n=2, noise should be summed 4 times and divided by 2.
+
+        In this example, torch.normal returns a constant value of 0.6.
+        So, the overall noise would be (0.6 + 0.6 + 0.6 + 0.6)/2 = 1.2
+        """
+        noise = _generate_noise(
+            std=2.0,
+            reference=torch.Tensor([1, 2, 3]),  # arbitrary size = 3
+            secure_mode=True,
+        )
+        self.assertTrue(
+            torch.allclose(noise, torch.Tensor([1.2, 1.2, 1.2])),
            "Model parameters after deterministic run must match",
        )


 class SampleConvNet(nn.Module):
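
The `expected_norm` used in `test_noise_level` can be sanity-checked outside the test harness: with zero gradients and max_grad_norm = 1, each DP-SGD step shifts every parameter coordinate by an independent Gaussian with standard deviation lr * sigma / batch_size, so after `steps` steps the expected squared norm over `n_params` coordinates is steps * n_params * (lr * sigma / batch_size)^2. A small self-contained simulation of that identity, with made-up values for lr, sigma, and batch size:

import torch

torch.manual_seed(0)

lr, sigma, batch_size = 0.1, 1.5, 32
n_params, steps = 10_000, 8

params = torch.zeros(n_params)
for _ in range(steps):
    # With zero gradients, each step only subtracts scaled Gaussian noise
    params -= lr * torch.normal(mean=0.0, std=sigma, size=(n_params,)) / batch_size

expected_norm = steps * n_params * sigma ** 2 * lr ** 2 / batch_size ** 2
real_norm = torch.sum(params ** 2).item()
print(expected_norm, real_norm)  # the two should agree to within a few percent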

run_results_imdb_classification.pt

559 Bytes
Binary file not shown.
