
Support gradient clipping by norm with FSDP #19235

Open
awaelchli opened this issue Jan 4, 2024 · 2 comments
Labels: feature, strategy: fsdp
Milestone: future


awaelchli commented Jan 4, 2024

Description & Motivation

Our current implementation of gradient clipping for FSDP is limited to clipping by value only. Norm is not supported:

def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
    # see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_
    # section `Gradient Clipping`, using `torch.nn.utils.clip_grad_norm_` is incorrect with FSDP.
    # To overcome this we need to call root_sharded_module.clip_grad_norm(clip_val), but we don't have a reference
    # to the root module
    raise MisconfigurationException(
        f"`gradient_clip_algorithm='norm'` is currently not supported for `{self.__class__.__name__}`"
    )

The reason is that clipping by norm needs to go through the FSDP API, which Lightning doesn't do yet: it can't be done through the optimizer alone, because a reference to the root FSDP module is required: https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_
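
For context, here is a minimal sketch of the distinction the FSDP docs draw (the helper name is illustrative, this is not Lightning code):

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def clip_fsdp_gradients(fsdp_model: FSDP, max_norm: float) -> torch.Tensor:
    # Correct with FSDP: the wrapped module reduces the total gradient norm
    # across ranks/shards before clipping.
    return fsdp_model.clip_grad_norm_(max_norm)
    # Incorrect with FSDP: torch.nn.utils.clip_grad_norm_(fsdp_model.parameters(), max_norm)
    # would only see the gradients of the shard local to each rank.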

Pitch

Support clipping by norm.

Change the API from

class Precision:
    ...
    def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
        ...

to

class Precision:
    ...
    def clip_grad_by_norm(self, module: Module, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
        ...

to take the module as input. The implementation in FSDPPrecision would then call module.clip_grad_norm_() instead of torch.nn.utils.clip_grad_norm_.

The LightningModule.clip_gradients() method should then pass self.trainer.model to self.trainer.precision_plugin.clip_gradients().
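
A rough sketch of what the FSDPPrecision side could look like under the proposed signature (class layout and names are illustrative, not the final API):

from typing import Union

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn import Module
from torch.optim import Optimizer

class FSDPPrecision:  # stand-in for the Lightning FSDP precision plugin
    def clip_grad_by_norm(self, module: Module, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
        # Delegate to FSDP so the gradient norm is computed over all shards,
        # not just the parameters local to this rank.
        assert isinstance(module, FSDP)
        module.clip_grad_norm_(clip_val)

Keeping the optimizer in the signature means the non-FSDP precision plugins, which clip through the optimizer's parameters, should keep working unchanged.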

Alternatives

There is not much else we can do. I believe the proposal above leads to the fewest breaking changes (it only affects the signature of the precision plugin methods).

Additional context

In Fabric's precision plugins, this is already done. We would need to do this on the Trainer side sooner or later anyway if we want to unify the precision/strategy implementations.

cc @Borda @awaelchli @carmocca

@awaelchli added the feature and strategy: fsdp labels and removed the needs triage label on Jan 4, 2024
@awaelchli added this to the 2.2 milestone on Jan 4, 2024
@awaelchli modified the milestones: 2.2 → 2.3 on Feb 3, 2024
@xin-w8023

Any updates about this?

@awaelchli modified the milestones: 2.3 → future on Jun 2, 2024
@amorehead
Contributor

Agreed, any updates?
