Fix per-layer clipping in distributed #347

Closed
38 changes: 27 additions & 11 deletions opacus/optimizers/ddp_perlayeroptimizer.py
@@ -15,7 +15,7 @@
 from __future__ import annotations
 
 from functools import partial
-from typing import List, Optional
+from typing import Callable, List, Optional
 
 import torch
 from torch import nn
@@ -29,8 +29,7 @@ def _clip_and_accumulate_parameter(p: nn.Parameter, max_grad_norm: float):
     per_sample_clip_factor = (max_grad_norm / (per_sample_norms + 1e-6)).clamp(max=1.0)
 
     grad = torch.einsum("i,i...", per_sample_clip_factor, p.grad_sample)
-
-    if hasattr(p, "summed_grad"):
+    if p.summed_grad is not None:
         p.summed_grad += grad
     else:
         p.summed_grad = grad
@@ -47,14 +46,15 @@ def __init__(
         optimizer: Optimizer,
         *,
         noise_multiplier: float,
-        max_grad_norms: List[float],
+        max_grad_norm: List[float],
         expected_batch_size: Optional[int],
         loss_reduction: str = "mean",
         generator=None,
         secure_mode: bool = False,
     ):
         self.rank = torch.distributed.get_rank()
-        self.max_grad_norms = max_grad_norms
+        self.world_size = torch.distributed.get_world_size()
+        self.max_grad_norms = max_grad_norm
         max_grad_norm = torch.norm(torch.Tensor(self.max_grad_norms), p=2).item()
         super().__init__(
             optimizer,
@@ -79,12 +79,18 @@ def _add_noise_parameter(self, p: nn.Parameter):
         )
         p.grad = p.summed_grad + noise
 
+    @property
+    def accumulated_iterations(self) -> int:
+        return max([p.accumulated_iterations for p in self.params])
+
     def _scale_grad_parameter(self, p: nn.Parameter):
+        if not hasattr(p, "accumulated_iterations"):
+            p.accumulated_iterations = 0
+        p.accumulated_iterations += 1
         if self.loss_reduction == "mean":
-            p.grad /= self.expected_batch_size * p.accumulated_iterations
+            p.grad /= (
+                self.expected_batch_size * p.accumulated_iterations * self.world_size
+            )
 
     def clip_and_accumulate(self):
         raise NotImplementedError(
@@ -94,20 +100,30 @@ def clip_and_accumulate(self):
     def add_noise(self):
         raise NotImplementedError("Noise is added per layer in DPDDP Per Layer.")
 
-    def pre_step(self):
-        self.accumulated_iterations = max(
-            [p.accumulated_iterations for p in self.params]
-        )
+    def pre_step(
+        self, closure: Optional[Callable[[], float]] = None
+    ) -> Optional[float]:
+        if self._check_skip_next_step():
+            self._is_last_step_skipped = True
+            return False
 
         if self.step_hook:
             self.step_hook(self)
-        self.accumulated_iterations = 0
+
+        for p in self.params:
+            p.accumulated_iterations = 0
+
+        self._is_last_step_skipped = False
+        return True
 
     def _ddp_per_layer_hook(
         self, p: nn.Parameter, max_grad_norm: float, _: torch.Tensor
     ):
         _clip_and_accumulate_parameter(p, max_grad_norm)
+        # Equivalent to _check_skip_next_step but without popping because it has to be done for every parameter p
+        if self._check_skip_next_step(pop_next=False):
+            return
+
         if self.rank == 0:
             self._add_noise_parameter(p)
         else:
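For readers unfamiliar with per-layer clipping, the following is a minimal, self-contained sketch of what `_clip_and_accumulate_parameter` does for a single parameter: clip each per-sample gradient against that layer's own bound, then accumulate the result into a running `summed_grad`. The function name, tensor shapes, and the toy call at the bottom are illustrative, not the actual Opacus internals.

```python
from typing import Optional

import torch


def clip_and_accumulate_per_layer(
    grad_sample: torch.Tensor,
    summed_grad: Optional[torch.Tensor],
    max_grad_norm: float,
) -> torch.Tensor:
    # grad_sample has shape (batch_size, *param_shape): one gradient per sample.
    per_sample_norms = grad_sample.reshape(len(grad_sample), -1).norm(2, dim=-1)
    # Scale each sample's gradient so its norm for this layer is at most max_grad_norm.
    per_sample_clip_factor = (max_grad_norm / (per_sample_norms + 1e-6)).clamp(max=1.0)
    grad = torch.einsum("i,i...", per_sample_clip_factor, grad_sample)
    # Accumulate across physical batches, mirroring p.summed_grad in the optimizer.
    return grad if summed_grad is None else summed_grad + grad


# Toy usage: 8 samples, a 5x10 weight, per-layer bound of 1.0.
acc = clip_and_accumulate_per_layer(torch.randn(8, 5, 10), None, max_grad_norm=1.0)
```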
2 changes: 1 addition & 1 deletion opacus/optimizers/ddpoptimizer.py
@@ -57,7 +57,7 @@ def add_noise(self):
             super().add_noise()
         else:
             for p in self.params:
-                p.grad = p.summed_grad
+                p.grad = p.summed_grad.view_as(p.grad)
 
     def reduce_gradients(self):
         for p in self.params:
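The added `view_as` call matters because `p.grad` must have exactly the parameter's shape, while `summed_grad` can hold the same values in a different layout. A small illustration of the reshape (the flattened layout below is a made-up example, not how Opacus stores `summed_grad`):

```python
import torch

param = torch.nn.Parameter(torch.randn(5, 10))
summed_grad = torch.randn(50)  # same number of elements, different shape

# Reshape to the parameter's (5, 10) shape before assigning it as the gradient.
param.grad = summed_grad.view_as(param)
assert param.grad.shape == param.shape
```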
15 changes: 12 additions & 3 deletions opacus/optimizers/optimizer.py
@@ -311,9 +311,18 @@ def signal_skip_step(self, do_skip=True):
         """
         self._step_skip_queue.append(do_skip)
 
-    def _check_skip_next_step(self):
+    def _check_skip_next_step(self, pop_next=True):
+        """
+        Checks if next step should be skipped by the optimizer.
+        This is for large Poisson batches that get split into smaller physical batches
+        to fit on the device. Batches that do not correspond to the end of a Poisson
+        batch are thus `skipped` as their gradient gets accumulated for one big step.
+        """
         if self._step_skip_queue:
-            return self._step_skip_queue.pop(0)
+            if pop_next:
+                return self._step_skip_queue.pop(0)
+            else:
+                return self._step_skip_queue[0]
         else:
             return False

@@ -418,7 +427,7 @@ def add_noise(self):
                 generator=self.generator,
                 secure_mode=self.secure_mode,
             )
-            p.grad = p.summed_grad + noise
+            p.grad = (p.summed_grad + noise).view_as(p.grad)
 
             _mark_as_processed(p.summed_grad)
 
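The `pop_next` flag exists because the per-layer hooks need to read the skip decision once per parameter without consuming it, while `pre_step` consumes it exactly once. A stripped-down model of the queue semantics (the class below is an illustration, not the Opacus `DPOptimizer` itself):

```python
from typing import List


class SkipQueue:
    """Illustrative stand-in for the _step_skip_queue handling in DPOptimizer."""

    def __init__(self) -> None:
        self._step_skip_queue: List[bool] = []

    def signal_skip_step(self, do_skip: bool = True) -> None:
        # Called for every physical batch that is only a slice of a larger Poisson batch.
        self._step_skip_queue.append(do_skip)

    def check_skip_next_step(self, pop_next: bool = True) -> bool:
        if not self._step_skip_queue:
            return False
        # Peek when the decision must be read once per parameter hook,
        # pop when it is finally consumed in pre_step.
        return self._step_skip_queue.pop(0) if pop_next else self._step_skip_queue[0]


q = SkipQueue()
q.signal_skip_step(True)
assert q.check_skip_next_step(pop_next=False)   # peek: the decision stays queued
assert q.check_skip_next_step()                 # pop: the decision is consumed
assert q.check_skip_next_step() is False        # empty queue: take the step
```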
5 changes: 3 additions & 2 deletions opacus/privacy_engine.py
@@ -25,6 +25,7 @@
 from opacus.optimizers import DPOptimizer, get_optimizer_class
 from opacus.validators.module_validator import ModuleValidator
 from torch import nn, optim
+from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import DataLoader
 

@@ -340,7 +341,7 @@ def make_private(
         if noise_generator and self.secure_mode:
             raise ValueError("Passing seed is prohibited in secure mode")
 
-        distributed = type(module) is DPDDP
+        distributed = isinstance(module, (DPDDP, DDP))
 
         module = self._prepare_model(
             module, batch_first=batch_first, loss_reduction=loss_reduction
@@ -355,7 +356,7 @@
         sample_rate = 1 / len(data_loader)
         expected_batch_size = int(len(data_loader.dataset) * sample_rate)
 
-        # expected_batch_size should be the *total* batch size across workers
+        # expected_batch_size is the *per worker* batch size
         if distributed:
             world_size = torch.distributed.get_world_size()
             expected_batch_size /= world_size
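The rewritten comment reflects that `expected_batch_size` ends up being the per-worker batch size once it is divided by the world size. A quick illustration with made-up numbers (a 50,000-example dataset, logical batches of 500, 4 workers):

```python
# Illustrative numbers only.
dataset_size = 50_000
sample_rate = 500 / dataset_size   # 0.01
world_size = 4

expected_batch_size = int(dataset_size * sample_rate)  # 500 across all workers
expected_batch_size /= world_size                      # 125.0 for each worker

# Each worker's optimizer then normalizes "mean"-reduced gradients by this
# per-worker value (times world_size where gradients are averaged across
# workers, as in the per-layer optimizer above).
print(expected_batch_size)  # 125.0
```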
22 changes: 15 additions & 7 deletions opacus/tests/multigpu_gradcheck.py
@@ -16,6 +16,8 @@
 import os
 import sys
 import unittest
+from opacus.optimizers.ddp_perlayeroptimizer import DistributedPerLayerOptimizer
+from opacus.optimizers.ddpoptimizer import DistributedDPOptimizer
 
 import torch
 import torch.distributed as dist
@@ -26,6 +28,7 @@
 from opacus.distributed import DifferentiallyPrivateDistributedDataParallel as DPDDP
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import DataLoader, TensorDataset
+from torch.utils.data.distributed import DistributedSampler
 
 
 PRIVACY_ALPHAS = [1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64))
@@ -66,18 +69,17 @@ def forward(self, x):
 def demo_basic(rank, weight, world_size, dp, clipping):
     torch.manual_seed(world_size)
     batch_size = 32
-    withdp = "with" + ("out " if not dp else "")
-    print(f"Running basic DDP {withdp} differential privacy example on rank {rank}.")
     setup(rank, world_size)
 
     # create model and move it to GPU with id rank
     model = ToyModel().to(rank)
     model.net1.weight.data.zero_()
     optimizer = optim.SGD(model.parameters(), lr=1)
 
-    labels = torch.randn(batch_size, 5).to(rank)
-    data = torch.randn(batch_size, 10)
+    labels = torch.randn(2 * batch_size, 5).to(rank)
+    data = torch.randn(2 * batch_size, 10)
 
-    data_loader = DataLoader(TensorDataset(data, labels), batch_size=batch_size)
+    dataset = TensorDataset(data, labels)
+
     loss_fn = nn.MSELoss()
     if dp and clipping == "flat":
@@ -87,10 +89,12 @@ def demo_basic(rank, weight, world_size, dp, clipping):
 
     privacy_engine = PrivacyEngine()
 
+    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=False)
+    data_loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
     if dp:
         max_grad_norm = 1e8
Contributor:
Let's follow PEP 8 and name this constant with capital letters: MAX_GRAD_NORM. This will emphasize that this thing is pre-defined and hard-coded.

Contributor Author:
It's not a constant: if clipping is per-layer, it becomes a list, see line 80.

         if clipping == "per_layer":
-            max_grad_norm = [1e8 for p in model.parameters()]
+            max_grad_norm = [max_grad_norm for _ in model.parameters()]
         ddp_model, optimizer, data_loader = privacy_engine.make_private(
             module=ddp_model,
             optimizer=optimizer,
@@ -100,12 +104,16 @@
             poisson_sampling=False,
             clipping=clipping,
         )
+        if clipping == "per_layer":
+            assert isinstance(optimizer, DistributedPerLayerOptimizer)
+        else:
+            assert isinstance(optimizer, DistributedDPOptimizer)
 
+    optimizer.zero_grad()
 
     for x, y in data_loader:
         outputs = ddp_model(x.to(rank))
         loss = loss_fn(outputs, y)
-        optimizer.zero_grad()
         loss.backward()
         optimizer.step()
         break
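Tying the test back to the review thread above: under "per_layer" clipping, `max_grad_norm` becomes a list with one bound per parameter, and the per-layer optimizer calibrates noise against the L2 norm of those bounds (the `torch.norm(...)` line in the first file of this diff). A short sketch with an arbitrary parameter count:

```python
import torch

# In the test every per-layer bound is the same large value (1e8), so clipping
# is effectively disabled while the per-layer code path is still exercised.
num_params = 4  # stand-in for len(list(model.parameters()))
max_grad_norm = 1e8
per_layer_bounds = [max_grad_norm for _ in range(num_params)]

# Effective flat bound used for noise calibration: sqrt(sum of squared bounds).
effective_flat_bound = torch.norm(torch.Tensor(per_layer_bounds), p=2).item()
print(effective_flat_bound)  # 2e8 == 1e8 * sqrt(4)
```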