diff --git a/opacus/grad_sample/grad_sample_module.py b/opacus/grad_sample/grad_sample_module.py
index f659f357..19b5ffa6 100644
--- a/opacus/grad_sample/grad_sample_module.py
+++ b/opacus/grad_sample/grad_sample_module.py
@@ -105,10 +105,9 @@ def __init__(
                 ``[K, batch_size, ...]``
             loss_reduction: Indicates if the loss reduction (for aggregating the gradients)
                 is a sum or a mean operation. Can take values "sum" or "mean"
-            strict: If set to ``True``, the input module will be validated to check that
-                ``GradSampleModule`` has grad sampler functions for all submodules of
-                the input module (i.e. if it knows how to calculate per sample gradients)
-                for all model parameters. If set to ``False``, per sample gradients will
+            strict: If set to ``True``, the input module will be validated to make sure
+                that none of its submodules include buffers, which are not currently
+                supported by Opacus. If set to ``False``, per sample gradients will
                 be computed on "best effort" basis - they will be available where
                 possible and set to None otherwise. This is not recommended, because
                 some unsupported modules (e.g. BatchNorm) affect other parameters and
@@ -120,7 +119,7 @@ def __init__(
         Raises:
             NotImplementedError
                 If ``strict`` is set to ``True`` and module ``m`` (or any of its
-                submodules) doesn't have a registered grad sampler function.
+                submodules) includes a buffer.
         """
         super().__init__(
             m,
diff --git a/opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py b/opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py
index 8e23b9b3..deaeb385 100644
--- a/opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py
+++ b/opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py
@@ -107,13 +107,15 @@ def __init__(
         Raises:
             NotImplementedError
                 If ``strict`` is set to ``True`` and module ``m`` (or any of its
-                submodules) doesn't have a registered grad sampler function.
+                submodules) includes a buffer.
         """
 
         super().__init__(
             m,
             batch_first=batch_first,
             loss_reduction=loss_reduction,
+            strict=strict,
+            force_functorch=force_functorch,
         )
         self.trainable_parameters = [p for _, p in trainable_parameters(self._module)]
         self.max_grad_norm = max_grad_norm
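
A minimal sketch of what this change means in practice. It is not part of the PR: the toy models and the assumption that the buffer check fires at construction time are illustrative, while the class names and keyword arguments are taken from the files above. The two effects shown are that ``strict=True`` now documents (and enforces) the rejection of submodules carrying buffers, and that ``GradSampleModuleFastGradientClipping`` forwards ``strict`` and ``force_functorch`` to ``GradSampleModule`` instead of silently falling back to the parent's defaults.

```python
# Illustrative sketch, not part of this PR. Class names and keyword arguments come
# from the diff above; the toy models and the construction-time behavior are assumptions.
import torch.nn as nn

from opacus.grad_sample import GradSampleModule
from opacus.grad_sample.grad_sample_module_fast_gradient_clipping import (
    GradSampleModuleFastGradientClipping,
)

# BatchNorm keeps running statistics in buffers, which strict mode now rejects.
model_with_buffers = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
model_without_buffers = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))

try:
    GradSampleModule(model_with_buffers, strict=True)
except NotImplementedError as err:
    print(f"strict=True rejected a module with buffers: {err}")

# Per the updated docstring, strict=False falls back to "best effort" per sample gradients.
GradSampleModule(model_with_buffers, strict=False)

# After the second hunk, these flags reach GradSampleModule rather than being dropped.
GradSampleModuleFastGradientClipping(
    model_without_buffers,
    max_grad_norm=1.0,
    strict=False,
    force_functorch=False,
)
```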