
Add learning rate scheduling support for DeepSpeedStrategy #20320

Open
wants to merge 32 commits into base: master
Changes from 1 commit
Commits (32)
188a45f  Update fabric.py (amorehead, Oct 5, 2024)
baf5988  Update deepspeed.py (amorehead, Oct 5, 2024)
1f4c18e  Update deepspeed.py (amorehead, Oct 5, 2024)
585e302  Update fabric.py (amorehead, Oct 5, 2024)
0451761  Update fsdp.py (amorehead, Oct 5, 2024)
a912aab  Update strategy.py (amorehead, Oct 5, 2024)
d27d4a3  Update strategy.py (amorehead, Oct 5, 2024)
67089a1  Update xla_fsdp.py (amorehead, Oct 5, 2024)
1025875  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Oct 5, 2024)
9b45b99  Update fsdp.py (amorehead, Oct 5, 2024)
a7a5835  Update strategy.py (amorehead, Oct 5, 2024)
3ece31c  Update xla_fsdp.py (amorehead, Oct 5, 2024)
e48acd2  Update deepspeed.py (amorehead, Oct 5, 2024)
f13516d  Update seed.py (amorehead, Oct 28, 2024)
80b4a6d  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Oct 28, 2024)
2cab7e2  Update seed.py (amorehead, Oct 28, 2024)
e9127f4  Update seed.py (amorehead, Oct 28, 2024)
c127458  Update seed.py (amorehead, Oct 28, 2024)
31a1fce  Merge branch 'master' into patch-2 (lantiga, Nov 12, 2024)
f215626  Merge branch 'master' into patch-2 (amorehead, Nov 15, 2024)
dfce07e  Merge branch 'master' into patch-2 (lantiga, Nov 25, 2024)
25e8d48  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 25, 2024)
737162d  Update src/lightning/fabric/strategies/deepspeed.py (lantiga, Nov 25, 2024)
2d347d0  Update src/lightning/fabric/strategies/fsdp.py (lantiga, Nov 25, 2024)
5d227ff  Update src/lightning/fabric/strategies/strategy.py (lantiga, Nov 25, 2024)
f94efa7  Update src/lightning/fabric/strategies/xla_fsdp.py (lantiga, Nov 25, 2024)
c2613ec  Merge branch 'Lightning-AI:master' into patch-2 (amorehead, Jan 9, 2025)
56464ed  Update deepspeed.py (amorehead, Jan 9, 2025)
e09941c  Update fabric.py (amorehead, Jan 9, 2025)
3709f1d  Update fabric_methods.rst (amorehead, Jan 10, 2025)
13195a2  Update wrappers.rst (amorehead, Jan 10, 2025)
3791c1b  Merge branch 'Lightning-AI:master' into patch-2 (amorehead, Feb 4, 2025)
Update fabric.py
amorehead authored Oct 5, 2024
commit 188a45f9b8d35faef4952848e0cad85b05bb1c78
42 changes: 19 additions & 23 deletions src/lightning/fabric/fabric.py
@@ -38,6 +38,7 @@
from lightning_utilities.core.overrides import is_overridden
from torch import Tensor
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, RandomSampler, SequentialSampler

import lightning.fabric
@@ -208,71 +209,66 @@ def run(self, *args: Any, **kwargs: Any) -> Any:

"""

def setup(
self,
module: nn.Module,
*optimizers: Optimizer,
move_to_device: bool = True,
_reapply_compile: bool = True,
) -> Any: # no specific return because the way we want our API to look does not play well with mypy
def setup(self, module: nn.Module, *optimizers: Optimizer, scheduler: Optional[LRScheduler] = None, move_to_device: bool = True, _reapply_compile: bool = True,) -> Any: # no specific return because the way we want our API to look does not play well with mypy
r"""Set up a model and its optimizers for accelerated training.

Args:
module: A :class:`torch.nn.Module` to set up
*optimizers: The optimizer(s) to set up (no optimizers is also possible)
scheduler: The learning rate scheduler to set up (no learning rate scheduler is also possible)
move_to_device: If set ``True`` (default), moves the model to the correct device. Set this to ``False``
and alternatively use :meth:`to_device` manually.
_reapply_compile: If ``True`` (default), and the model was ``torch.compile``d before, the
corresponding :class:`~torch._dynamo.OptimizedModule` wrapper will be removed and reapplied with the
same settings after the model was set up by the strategy (e.g., after the model was wrapped by DDP,
FSDP etc.). Set it to ``False`` if compiling DDP/FSDP is causing issues.

Returns:
The tuple containing wrapped module and the optimizers, in the same order they were passed in.

"""
self._validate_setup(module, optimizers)
module, compile_kwargs = _unwrap_compiled(module) if _reapply_compile else (module, None)
original_module = module

module = self._precision.convert_module(module)

if move_to_device:
module = self._move_model_to_device(model=module, optimizers=list(optimizers))

# Let accelerator/plugin wrap and connect the models and optimizers
if optimizers:
module, optimizers = self._strategy.setup_module_and_optimizers( # type: ignore[assignment]
module, list(optimizers)
module, optimizers, scheduler = self._strategy.setup_module_and_optimizers( # type: ignore[assignment]
module, list(optimizers), scheduler
)
else:
module = self._strategy.setup_module(module)

if compile_kwargs is not None:
module = _to_compiled(module, compile_kwargs)
module = _FabricModule(module, self._strategy, original_module=original_module)

# Update the _DeviceDtypeModuleMixin's device parameter
# NOTE: for sharded strategies or manual device placement, there's no single root device
_update_properties(
module, device=self.device if move_to_device else next(module.parameters(), torch.tensor(0)).device
)

optimizers = [_FabricOptimizer(optimizer, self._strategy, self._callbacks) for optimizer in optimizers]

self._models_setup += 1

if hasattr(original_module, "_fabric"): # this is probably a LightningModule
original_module._fabric = self
original_module._fabric_optimizers = optimizers
if original_module not in self._callbacks:
self._callbacks.append(original_module)

self.call("on_after_setup", fabric=self, module=module)

if optimizers:
# join both types in a tuple for API convenience
return (module, *optimizers)
return (module, *optimizers, scheduler)
Collaborator
This is a breaking change: it will cause existing user code to fail because the scheduler is returned unconditionally.

Since scheduler is Optional in the signature, I suggest we only return it when it was passed as a non-None argument, so we won't break anyone's code.

Contributor Author
Agreed. Addressed in this commit.
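
To make the suggestion concrete, here is a minimal sketch of the conditional return described in the comment above, reusing the names from this diff; it is an illustration of the idea, not the PR's final code:

    if optimizers:
        # Keep the pre-existing return shape when no scheduler was passed, so calls
        # like `model, optimizer = fabric.setup(model, optimizer)` keep working.
        if scheduler is not None:
            return (module, *optimizers, scheduler)
        return (module, *optimizers)
    return module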

return module

def setup_module(
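
For context, a hedged sketch of how the API proposed in this PR might be called from user code. The model, optimizer, scheduler, and the Fabric arguments below are placeholders; only the keyword-only `scheduler` argument follows the `setup` signature added in this diff.

    import torch
    from lightning.fabric import Fabric

    # Illustrative configuration; the strategy and device choices are placeholders.
    fabric = Fabric(strategy="deepspeed_stage_2", devices=2)
    fabric.launch()

    model = torch.nn.Linear(32, 2)  # placeholder model
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

    # With this PR, the scheduler is passed through `setup` and returned
    # alongside the wrapped module and optimizer.
    model, optimizer, scheduler = fabric.setup(model, optimizer, scheduler=scheduler)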
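On the strategy side, the diff forwards the scheduler through `self._strategy.setup_module_and_optimizers(module, list(optimizers), scheduler)`. Below is a minimal sketch of how a DeepSpeed-backed strategy could hand that scheduler to the engine; the class name, the `self.config` attribute, and the return shape are illustrative assumptions, not Lightning's actual DeepSpeedStrategy implementation.

    from typing import Any, Dict, List, Optional, Tuple

    import deepspeed
    import torch
    from torch.optim import Optimizer
    from torch.optim.lr_scheduler import LRScheduler


    class _DeepSpeedLikeStrategy:
        """Illustrative stand-in for a strategy; not the Lightning implementation."""

        def __init__(self, config: Dict[str, Any]) -> None:
            self.config = config  # assumed: a DeepSpeed config dict

        def setup_module_and_optimizers(
            self,
            module: torch.nn.Module,
            optimizers: List[Optimizer],
            scheduler: Optional[LRScheduler] = None,
        ) -> Tuple[Any, List[Optimizer], Optional[LRScheduler]]:
            # DeepSpeed expects a single optimizer; the scheduler is handed to
            # deepspeed.initialize so the engine can drive its stepping.
            engine, optimizer, _, ds_scheduler = deepspeed.initialize(
                model=module,
                optimizer=optimizers[0] if optimizers else None,
                lr_scheduler=scheduler,
                config=self.config,
            )
            return engine, [optimizer], ds_scheduler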