Fix for floating point representation attack #260
@@ -9,19 +9,68 @@
 def _generate_noise(
-    std: float, reference: torch.Tensor, generator=None
+    std: float,
+    reference: torch.Tensor,
+    generator=None,
+    secure_mode: bool = False,
 ) -> torch.Tensor:
-    if std > 0:
-        # TODO: handle device transfers: generator and reference tensor
-        # could be on different devices
+    """
+    Generates noise according to a Gaussian distribution with mean 0.
+
+    Args:
+        std: Standard deviation of the noise
+        reference: The reference Tensor to get the appropriate shape and device
+            for generating the noise
+        generator: The PyTorch noise generator
+        secure_mode: boolean indicating whether "secure" noise needs to be
+            generated (see the notes)
+
+    Notes:
+        If `secure_mode` is enabled, the generated noise is also secure
+        against floating point representation attacks, such as the ones
+        in https://arxiv.org/abs/2107.10138. This is achieved by calling
+        the Gaussian noise function 2*n times, with n=2 (see section 5.1 in
+        https://arxiv.org/abs/2107.10138).
+
+        Reason for choosing n=2: n can be any number > 1. The bigger n is,
+        the more computation needs to be done (`2n` Gaussian samples will be
+        generated). We chose `n=2` because `n=1` would be easy to break and
+        `n>2` is not really necessary. The complexity of the attack is
+        `2^(p(2n-1))`. In PyTorch, `p=53`, so the complexity is
+        `2^(53(2n-1))`. With `n=1`, we get `2^53` (easy to break), but with
+        `n=2`, we get `2^159`, which is hard enough for an attacker to break.
+    """
+    zeros = torch.zeros(reference.shape, device=reference.device)
+    if std == 0:
+        return zeros
+    # TODO: handle device transfers: generator and reference tensor
+    # could be on different devices
+
+    if secure_mode:
+        torch.normal(
+            mean=0,
+            std=std,
+            size=(1, 1),
+            device=reference.device,
+            generator=generator,
+        )  # generate, but throw away the first generated Gaussian sample
+        sum = zeros
+        for i in range(4):
+            sum += torch.normal(
+                mean=0,
+                std=std,
+                size=reference.shape,
+                device=reference.device,
+                generator=generator,
+            )
+        return sum / 2
+    else:
         return torch.normal(
             mean=0,
             std=std,
             size=reference.shape,
             device=reference.device,
             generator=generator,
         )
-    return torch.zeros(reference.shape, device=reference.device)


 def _get_flat_grad_sample(p: torch.Tensor):

Inline review thread on the sampling loop (`for i in range(4):`):

[Reviewer] I trust you and Ilya that it solves the problem, but could you please give an ELI5 explanation of why this works?

[Reviewer] From what I understand, the sum of the 4 samples gives you a Gaussian with variance 4·std², so sum/2 is a Gaussian with variance std². @ashkan-software: shouldn't you loop over only 2 samples, as per the docstring?

[ashkan-software] Great question @ffuuugor. This approach is actually not any of those 3 options we had. I found this approach in a recent paper, and Ilya and I think it is an intelligent way of fixing the problem. The idea of the attack is to invert the Gaussian mechanism and guess what values were used as its input. This is possible if the Gaussian mechanism is used once, but if we use it more than once (in this fix, we call it 4 times), it becomes exponentially harder to guess those values. That is the idea in very simple words; the fix is a bit more involved and is explained in the paper I listed on the PR.

[ashkan-software] The reason for having the numbers 4 and 2 in the code is that with n=2, 2n = 4 Gaussian samples are generated, and dividing their sum by 2 (the square root of 4) brings the variance back down from 4·std² to std².
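To make the variance argument in the thread concrete, here is a minimal, self-contained sketch of the same 2n-sample trick outside Opacus. `secure_gaussian` and its parameters are illustrative names invented for this example, not part of the PR:

```python
import math

import torch


def secure_gaussian(std, shape, n=2, generator=None):
    # The sum of 2n i.i.d. Gaussians with standard deviation `std` has
    # variance (2n) * std**2, so dividing by sqrt(2n) restores variance
    # std**2. With n=2 this is "sum of 4 samples, divide by 2" -- exactly
    # the constants used in the diff above.
    total = torch.zeros(shape)
    for _ in range(2 * n):
        total += torch.normal(
            mean=0.0, std=std, size=shape, generator=generator
        )
    return total / math.sqrt(2 * n)


# Quick empirical check that the rescaling preserves the target std:
noise = secure_gaussian(std=1.5, shape=(100_000,))
print(float(noise.std()))  # approximately 1.5
```

Dividing by `math.sqrt(2 * n)` rather than hard-coding 2 keeps the sketch correct for any n > 1.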
@@ -47,6 +96,7 @@ def __init__(
         expected_batch_size: Optional[int],
         loss_reduction: str = "mean",
         generator=None,
+        secure_mode=False,
     ):
         if loss_reduction not in ("mean", "sum"):
             raise ValueError(f"Unexpected value for loss_reduction: {loss_reduction}")
@@ -63,6 +113,7 @@ def __init__(
         self.expected_batch_size = expected_batch_size
         self.step_hook = None
         self.generator = generator
+        self.secure_mode = secure_mode

         self.param_groups = optimizer.param_groups
         self.state = optimizer.state
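For context, a hypothetical usage sketch of the extended constructor. It assumes the surrounding class is Opacus's `DPOptimizer` imported from `opacus.optimizers`; both names come from the Opacus codebase rather than this excerpt, and `noise_multiplier`/`max_grad_norm` are inferred from the `add_noise` hunk below:

```python
import torch
from opacus.optimizers import DPOptimizer  # assumed import path

model = torch.nn.Linear(16, 2)
base_optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Wrap the vanilla optimizer; secure_mode=True switches _generate_noise
# to the attack-resistant 4-sample path.
optimizer = DPOptimizer(
    base_optimizer,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
    expected_batch_size=64,
    loss_reduction="mean",
    secure_mode=True,
)
```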
@@ -137,6 +188,7 @@ def add_noise(self):
                 std=self.noise_multiplier * self.max_grad_norm,
                 reference=p.summed_grad,
                 generator=self.generator,
+                secure_mode=self.secure_mode,
             )
             p.grad = p.summed_grad + noise

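Putting the pieces together, a simplified sketch of what `add_noise` now does per parameter. The standalone function and `params` are hypothetical simplifications of the optimizer's internals; `_generate_noise` is the function patched above, and the noise scale is the usual DP-SGD one, `noise_multiplier * max_grad_norm`:

```python
def add_noise_sketch(params, noise_multiplier, max_grad_norm,
                     generator=None, secure_mode=False):
    # DP-SGD noise scale: noise multiplier times the clipping norm.
    std = noise_multiplier * max_grad_norm
    for p in params:
        # p.summed_grad holds the clipped, accumulated per-sample gradients.
        noise = _generate_noise(
            std=std,
            reference=p.summed_grad,
            generator=generator,
            secure_mode=secure_mode,
        )
        p.grad = p.summed_grad + noise
```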