Migrate GradSampler to Tensor hooks #259

Open
@romovpa

Description

Problem & Motivation

The grad sampler currently relies on nn.Module backward hooks (register_backward_hook). This approach has a limitation: it only covers the parts of the loss expression where the module is called directly, i.e. x = module(x). If someone adds a term that uses the module's parameters directly to the loss, that term has no effect on the sampled gradients.

A common case where this problem occurs is adding a parameter regularizer to the loss (see #249), for example:

loss = criterion(y_pred, y_true)
loss += l2_loss(model.parameters())
loss += proximal_loss(model, another_model)   # e.g. encourage two models to have similar weights
loss.backward()

In this case, the grad hooks are simply not called for the extra terms: when running with PrivacyEngine, backward() silently drops the regularizer's contribution to the per-sample gradients. This is surprising and incorrect behaviour.

This is a fundamental limitation of nn.Module hooks, and migrating to register_full_backward_hook does not solve it: full backward hooks also fire only when the module's forward is called.
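
A minimal, self-contained demonstration (plain PyTorch, not Opacus code): a full backward hook never fires when the loss is built from the module's parameters alone, even though autograd still populates the gradients.

import torch
import torch.nn as nn

lin = nn.Linear(4, 2)
fired = []
lin.register_full_backward_hook(
    lambda module, grad_input, grad_output: fired.append(True)
)

# The loss uses only the module's parameters; lin(x) is never called.
loss = lin.weight.pow(2).sum()
loss.backward()

print(fired)            # [] -- the hook never ran
print(lin.weight.grad)  # ...but the gradient is there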

Pitch

Bare minimum:

  • Make the behaviour of the grad sampler correct: if we cannot handle the loss, raise an exception instead of silently producing wrong gradients (see the sketch after this list).
  • Document a workaround for this limitation. For example, we could suggest implementing a custom Module with a corresponding grad sampler.

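One possible shape for that exception, as a hedged sketch: the function name and tolerance below are hypothetical, and it assumes a sum-reduced loss so that the per-sample gradients Opacus stores in p.grad_sample should add up to the gradient autograd wrote to p.grad.

import torch

def assert_grad_sample_complete(model, atol=1e-6):
    # If part of the loss bypassed the module hooks, the per-sample
    # gradients summed over the batch will not match what autograd
    # accumulated in p.grad.
    for name, p in model.named_parameters():
        if p.grad is None:
            continue
        grad_sample = getattr(p, "grad_sample", None)
        if grad_sample is None or not torch.allclose(
            grad_sample.sum(dim=0), p.grad, atol=atol
        ):
            raise RuntimeError(
                f"Gradient of {name} was not fully captured by the grad "
                "sampler; the loss likely uses parameters outside a module call."
            )
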
Ideally:

  • Migrate the grad sampler to low-level Tensor hooks, which fire whenever autograd computes a gradient for a tensor, to guarantee that no part of the gradient is missed (see the sketch below).

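For illustration only (this is not a per-sample implementation; a real migration would still need the activations to reconstruct per-sample gradients): torch.Tensor.register_hook attaches to the parameter itself, so the hook observes the full gradient no matter how the parameter entered the loss expression.

import torch

model = torch.nn.Linear(4, 2)
captured = {}

def make_hook(name):
    def hook(grad):
        # Tensor hooks fire whenever autograd produces a gradient for
        # this tensor, regardless of how it entered the loss.
        # Returning None leaves the gradient unchanged.
        captured[name] = grad.detach().clone()
    return hook

for name, p in model.named_parameters():
    p.register_hook(make_hook(name))

x = torch.randn(8, 4)
loss = model(x).sum() + 0.01 * sum(p.pow(2).sum() for p in model.parameters())
loss.backward()

# `captured` now includes the regularizer's contribution, which a module
# backward hook would never have seen.
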
Alternatives

Suggestions are welcome.

Metadata

Labels

bug (Something isn't working), refactor (Refactor of the existing code)
