Description & Motivation
Our current implementation of gradient clipping for FSDP is limited to clipping by value; clipping by norm is not supported:
pytorch-lightning/src/lightning/pytorch/plugins/precision/fsdp.py, lines 77 to 84 at f75f3bc
The reason is that clipping by norm has to go through the FSDP API, and this hasn't been implemented in Lightning yet: it cannot be done through the optimizer alone, because a reference to the FSDP module is required (https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_).
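To illustrate why the optimizer alone is not enough, here is a minimal sketch of the distinction; the helper name clip_gradients_by_norm is made up for this example and is not part of Lightning or PyTorch:

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def clip_gradients_by_norm(module: torch.nn.Module, max_norm: float) -> torch.Tensor:
    """Clip by total gradient norm, routing through the FSDP API when needed."""
    if isinstance(module, FSDP):
        # Each rank only holds a shard of the gradients, so FSDP's own method
        # reduces the local norms into the true global norm before clipping.
        return module.clip_grad_norm_(max_norm)
    # For plain (non-sharded) modules the standard utility is sufficient.
    return torch.nn.utils.clip_grad_norm_(module.parameters(), max_norm)
```

In both branches the function needs the module, not just the optimizer, which is exactly what the current precision plugin signature does not provide.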
Pitch
Support clipping by norm.
Change the API from

```python
class Precision:
    ...

    def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
        ...
```

to

```python
class Precision:
    ...

    def clip_grad_by_norm(self, module: Module, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
        ...
```
so that it takes the module as input. The implementation in FSDPPrecision would then call module.clip_grad_norm_() instead of torch.nn.utils.clip_grad_norm_(). The LightningModule.clip_gradients() method should then pass self.trainer.model to self.trainer.precision_plugin.clip_gradients().
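For concreteness, a rough sketch of how the FSDPPrecision override could look with the proposed signature. The Precision base class below is a reduced stand-in for the real plugin, so treat the surrounding structure as an assumption, not the actual implementation:

```python
from typing import Union

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn import Module
from torch.optim import Optimizer


class Precision:
    """Reduced stand-in for the base precision plugin."""

    def clip_grad_by_norm(self, module: Module, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
        # Status quo: clip the parameters owned by the optimizer with the
        # standard utility; only the signature changes for this path.
        parameters = [p for group in optimizer.param_groups for p in group["params"]]
        torch.nn.utils.clip_grad_norm_(parameters, clip_val)


class FSDPPrecision(Precision):
    def clip_grad_by_norm(self, module: Module, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
        # Gradients are sharded across ranks, so the total norm must be
        # computed and clipped through the FSDP wrapper itself.
        assert isinstance(module, FSDP)
        module.clip_grad_norm_(clip_val)
```

The isinstance check is only for illustration; in practice the FSDP strategy would guarantee that the wrapped module is the one being passed in.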
Alternatives
There is not much else we can do. I believe the proposal above leads to the fewest breaking changes, since it only affects the signature of the precision plugin methods.
Additional context
In Fabric's precision plugins, this is already done. We would need to do this on the Trainer side anyway sooner or later, if we want to unify the precision/strategy implementations.