
Support gradient clipping by norm with FSDP #19235

Open
@awaelchli

Description & Motivation

Our current implementation of gradient clipping for FSDP is limited to clipping by value only. Norm is not supported:

def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
    # see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_
    # section `Gradient Clipping`, using `torch.nn.utils.clip_grad_norm_` is incorrect with FSDP.
    # To overcome this we need to call root_sharded_module.clip_grad_norm(clip_val), but we don't have a reference
    # to the root module
    raise MisconfigurationException(
        f"`gradient_clip_algorithm='norm'` is currently not supported for `{self.__class__.__name__}`"
    )

The reason is that clipping by norm has to go through the FSDP API, and this hasn't been implemented in Lightning yet because it can't be done through the optimizer alone; a reference to the root FSDP module is required: https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_
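
For illustration, here is a minimal sketch of the difference (assuming an initialized distributed process group; the model and max_norm value are placeholders):

import torch
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = nn.Linear(16, 16)

# Non-FSDP case: gradients are local, so the generic utility is sufficient.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# FSDP case: parameters and gradients are sharded across ranks, so the total
# norm must be computed collectively. This only works through the method on
# the wrapped root module, which is why the precision plugin needs a
# reference to it.
fsdp_model = FSDP(model)  # requires torch.distributed to be initialized
fsdp_model.clip_grad_norm_(max_norm=1.0)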

Pitch

Support clipping by norm.

Change the API from

class Precision:
    ...
    def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
        ...

to

class Precision:
    ...
    def clip_grad_by_norm(self, module: Module, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
        ...

to take the module as input. The implementation in FSDPPrecision would then call module.clip_grad_norm_() instead of torch.nn.utils.clip_grad_norm_.
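
A minimal sketch of what the FSDPPrecision side could look like with the new signature (class and method names follow the proposal above; this is not the actual Lightning implementation):

from typing import Union

from torch.nn import Module
from torch.optim import Optimizer

class FSDPPrecision:
    ...
    def clip_grad_by_norm(self, module: Module, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
        # `module` is expected to be the root FullyShardedDataParallel instance;
        # its clip_grad_norm_() computes the total gradient norm across all
        # shards/ranks before clipping, which torch.nn.utils.clip_grad_norm_
        # cannot do correctly here.
        module.clip_grad_norm_(clip_val)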

The LightningModule.clip_gradients() method should then pass self.trainer.model to self.trainer.precision_plugin.clip_gradients().
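
On the Trainer side this would roughly amount to threading the module reference through the existing call chain, e.g. (simplified sketch; the argument names are illustrative and may not match the current signatures exactly):

class LightningModule:
    ...
    def clip_gradients(self, optimizer, gradient_clip_val=None, gradient_clip_algorithm=None):
        # Forward the strategy's root module (the FSDP-wrapped model) so that
        # FSDPPrecision.clip_grad_by_norm() can call clip_grad_norm_() on it.
        self.trainer.precision_plugin.clip_gradients(
            self.trainer.model,
            optimizer,
            clip_val=gradient_clip_val,
            gradient_clip_algorithm=gradient_clip_algorithm,
        )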

Alternatives

There is not much else we can do. I believe the proposal above will lead to the least amount of breaking changes (it only affects the signature of the precision plugin methods).

Additional context

In Fabric's precision plugins, this is already done. We would need to do this on the Trainer side sooner or later anyway if we want to unify the precision/strategy implementations.

cc @Borda @awaelchli @carmocca
