-
Notifications
You must be signed in to change notification settings - Fork 350
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: This PR implements a new accountant `PRVAccountant` based on the paper [Numerical Composition of Differential Privacy](https://arxiv.org/abs/2106.02848). Code inspired heavily by the code that accompanied the paper: https://github.com/microsoft/prv_accountant ## Types of changes - [ ] Bug fix (non-breaking change which fixes an issue) - [x] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Docs change / refactoring / dependency upgrade ## Motivation and Context / Related issue See #378 ## How Has This Been Tested (if it applies) I have tested these changes with the following scripts, but would welcome suggestions on how to test further or write unit tests to cover these changes: - [validate_gaussian.py](https://gist.github.com/tcbegley/13483859eb07488d711368c982af5ded) recreates [this notebook](https://github.com/microsoft/prv_accountant/blob/main/notebooks/validate-gaussian.ipynb), which checks that we can recover upper and lower bounds on the privacy curve of a Gaussian mechanism correctly. - [prv_accountant_cifar10.py](https://gist.github.com/tcbegley/91afccc8f702a61617f7ec6da250effe) runs [this tutorial from the Opacus docs](https://opacus.ai/tutorials/building_image_classifier) with the `PRVAccountant` instead of `RDPAccountant`. ## Checklist I have not yet written docstrings or tests for these changes both as it was slightly unclear to me how best to proceed, but also because I would like to validate the approach taken in this initial implementation before polishing. - [x] The documentation is up-to-date with the changes I made. - [x] I have read the **CONTRIBUTING** document and completed the CLA (see **CONTRIBUTING**). - [x] All tests passed, and additional code has been covered with new tests. Pull Request resolved: #493 Reviewed By: ffuuugor Differential Revision: D39208724 Pulled By: alexandresablayrolles fbshipit-source-id: c8949c2a61a0a6ed24628a5f53e597f2108a5b91
- Loading branch information
1 parent
c5562a7
commit c89a3bf
Showing
21 changed files
with
685 additions
and
58 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from .compose import compose_heterogeneous | ||
from .domain import Domain, compute_safe_domain_size | ||
from .prvs import ( | ||
DiscretePRV, | ||
PoissonSubsampledGaussianPRV, | ||
TruncatedPrivacyRandomVariable, | ||
discretize, | ||
) | ||
|
||
|
||
__all__ = [ | ||
"DiscretePRV", | ||
"Domain", | ||
"PoissonSubsampledGaussianPRV", | ||
"TruncatedPrivacyRandomVariable", | ||
"compose_heterogeneous", | ||
"compute_safe_domain_size", | ||
"discretize", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
from typing import List | ||
|
||
import numpy as np | ||
from scipy.fft import irfft, rfft | ||
from scipy.signal import convolve | ||
|
||
from .prvs import DiscretePRV | ||
|
||
|
||
def _compose_fourier(dprv: DiscretePRV, num_self_composition: int) -> DiscretePRV: | ||
if len(dprv) % 2 != 0: | ||
raise ValueError("Can only compose evenly sized discrete PRVs") | ||
|
||
composed_pmf = irfft(rfft(dprv.pmf) ** num_self_composition) | ||
|
||
m = num_self_composition - 1 | ||
if num_self_composition % 2 == 0: | ||
m += len(composed_pmf) // 2 | ||
composed_pmf = np.roll(composed_pmf, m) | ||
|
||
domain = dprv.domain.shift_right(dprv.domain.shifts * (num_self_composition - 1)) | ||
|
||
return DiscretePRV(pmf=composed_pmf, domain=domain) | ||
|
||
|
||
def _compose_two(dprv_left: DiscretePRV, dprv_right: DiscretePRV) -> DiscretePRV: | ||
pmf = convolve(dprv_left.pmf, dprv_right.pmf, mode="same") | ||
domain = dprv_left.domain.shift_right(dprv_right.domain.shifts) | ||
return DiscretePRV(pmf=pmf, domain=domain) | ||
|
||
|
||
def _compose_convolution_tree(dprvs: List[DiscretePRV]) -> DiscretePRV: | ||
# repeatedly convolve neighbouring PRVs until we only have one left | ||
while len(dprvs) > 1: | ||
dprvs_conv = [] | ||
if len(dprvs) % 2 == 1: | ||
dprvs_conv.append(dprvs.pop()) | ||
|
||
for dprv_left, dprv_right in zip(dprvs[:-1:2], dprvs[1::2]): | ||
dprvs_conv.append(_compose_two(dprv_left, dprv_right)) | ||
|
||
dprvs = dprvs_conv | ||
return dprvs[0] | ||
|
||
|
||
def compose_heterogeneous( | ||
dprvs: List[DiscretePRV], num_self_compositions: List[int] | ||
) -> DiscretePRV: | ||
r""" | ||
Compose a heterogenous list of PRVs with multiplicity. We use FFT to compose | ||
identical PRVs with themselves first, then pairwise convolve the remaining PRVs. | ||
This is the approach taken in https://github.com/microsoft/prv_accountant | ||
""" | ||
if len(dprvs) != len(num_self_compositions): | ||
raise ValueError("dprvs and num_self_compositions must have the same length") | ||
|
||
dprvs = [ | ||
_compose_fourier(dprv, num_self_composition) | ||
for dprv, num_self_composition in zip(dprvs, num_self_compositions) | ||
] | ||
return _compose_convolution_tree(dprvs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
from dataclasses import dataclass | ||
from typing import Sequence | ||
|
||
import numpy as np | ||
|
||
from ...rdp import RDPAccountant | ||
|
||
|
||
@dataclass | ||
class Domain: | ||
r""" | ||
Stores relevant information about the domain on which PRVs are discretized, and | ||
includes a few convenience methods for manipulating it. | ||
""" | ||
t_min: float | ||
t_max: float | ||
size: int | ||
shifts: float = 0.0 | ||
|
||
def __post_init__(self): | ||
if not isinstance(self.size, int): | ||
raise TypeError("`size` must be an integer") | ||
if self.size % 2 != 0: | ||
raise ValueError("`size` must be even") | ||
|
||
@classmethod | ||
def create_aligned(cls, t_min: float, t_max: float, dt: float) -> "Domain": | ||
t_min = np.floor(t_min / dt) * dt | ||
t_max = np.ceil(t_max / dt) * dt | ||
|
||
size = int(np.round((t_max - t_min) / dt)) + 1 | ||
|
||
if size % 2 == 1: | ||
size += 1 | ||
t_max += dt | ||
|
||
domain = cls(t_min, t_max, size) | ||
|
||
if np.abs(domain.dt - dt) / dt >= 1e-8: | ||
raise RuntimeError | ||
|
||
return domain | ||
|
||
def shift_right(self, dt: float) -> "Domain": | ||
return Domain( | ||
t_min=self.t_min + dt, | ||
t_max=self.t_max + dt, | ||
size=self.size, | ||
shifts=self.shifts + dt, | ||
) | ||
|
||
@property | ||
def dt(self): | ||
return (self.t_max - self.t_min) / (self.size - 1) | ||
|
||
@property | ||
def ts(self): | ||
return np.linspace(self.t_min, self.t_max, self.size) | ||
|
||
def __getitem__(self, i: int) -> float: | ||
return self.t_min + i * self.dt | ||
|
||
|
||
def compute_safe_domain_size( | ||
prvs, | ||
max_self_compositions: Sequence[int], | ||
eps_error: float, | ||
delta_error: float, | ||
) -> float: | ||
""" | ||
Compute safe domain size for the discretization of the PRVs. | ||
For details about this algorithm, see remark 5.6 in | ||
https://www.microsoft.com/en-us/research/publication/numerical-composition-of-differential-privacy/ | ||
""" | ||
total_compositions = sum(max_self_compositions) | ||
|
||
rdp_accountant = RDPAccountant() | ||
for prv, max_self_composition in zip(prvs, max_self_compositions): | ||
rdp_accountant.history.append( | ||
(prv.noise_multiplier, prv.sample_rate, max_self_composition) | ||
) | ||
|
||
L_max = rdp_accountant.get_epsilon(delta_error / 4) | ||
|
||
for prv, max_self_composition in zip(prvs, max_self_compositions): | ||
rdp_accountant = RDPAccountant() | ||
rdp_accountant.history = [(prv.noise_multiplier, prv.sample_rate, 1)] | ||
L_max = max( | ||
L_max, | ||
rdp_accountant.get_epsilon(delta=delta_error / (8 * total_compositions)), | ||
) | ||
|
||
# FIXME: this implementation is adapted from the code accompanying the paper, but | ||
# disagrees subtly with the theory from remark 5.6. It's not immediately clear this | ||
# gives the right guarantees in all cases, though it's fine for eps_error < 1 and | ||
# hence generic cases. | ||
# cf. https://github.com/microsoft/prv_accountant/discussions/34 | ||
return max(L_max, eps_error) + 3 |
Oops, something went wrong.