Skip to content

Commit

Permalink
Add option to specify number of steps in UniformSampler (#550)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #550

Add option to specify number of steps in UniformSampler

Differential Revision: D42034268

fbshipit-source-id: c0c36b8f3bfe8903b717e2b2203ae7b68f493074
  • Loading branch information
Alex Sablayrolles authored and facebook-github-bot committed Dec 15, 2022
1 parent 0b20059 commit 68da684
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions opacus/utils/uniform_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,15 @@ class UniformWithReplacementSampler(Sampler[List[int]]):
Each sample is selected with a probability equal to ``sample_rate``.
"""

def __init__(self, *, num_samples: int, sample_rate: float, generator=None):
def __init__(
self, *, num_samples: int, sample_rate: float, generator=None, steps=None
):
r"""
Args:
num_samples: number of samples to draw.
sample_rate: probability used in sampling.
generator: Generator used in sampling.
steps: Number of steps (iterations of the Sampler)
"""
self.num_samples = num_samples
self.sample_rate = sample_rate
Expand All @@ -42,11 +45,16 @@ def __init__(self, *, num_samples: int, sample_rate: float, generator=None):
"value, but got num_samples={}".format(self.num_samples)
)

if steps is not None:
self.steps = steps
else:
self.steps = int(1 / self.sample_rate)

def __len__(self):
return int(1 / self.sample_rate)
return self.steps

def __iter__(self):
num_batches = int(1 / self.sample_rate)
num_batches = self.steps
while num_batches > 0:
mask = (
torch.rand(self.num_samples, generator=self.generator)
Expand Down Expand Up @@ -82,6 +90,7 @@ def __init__(
sample_rate: float,
shuffle: bool = True,
shuffle_seed: int = 0,
steps: int = None,
generator=None,
):
"""
Expand Down Expand Up @@ -117,7 +126,10 @@ def __init__(
self.num_samples += 1

# Number of batches: same as non-distributed Poisson sampling, but each batch is smaller
self.num_batches = int(1 / self.sample_rate)
if steps is not None:
self.num_batches = steps
else:
self.num_batches = int(1 / self.sample_rate)

def __iter__(self):
if self.shuffle:
Expand Down

0 comments on commit 68da684

Please sign in to comment.