Commit 4510351

Delay Precision.convert_module until configure_model has run (#19061)
1 parent db2cc8a commit 4510351

11 files changed: +155, -41 lines

src/lightning/pytorch/CHANGELOG.md

Lines changed: 1 addition & 0 deletions

@@ -42,6 +42,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Fixed

 - Fixed issue where the `precision="transformer-engine"` argument would not replace layers by default ([#19082](https://github.com/Lightning-AI/lightning/pull/19082))
+- Fixed issue where layers created in `LightningModule.setup` or `LightningModule.configure_model` wouldn't get converted when using the Bitsandbytes or TransformerEngine plugins ([#19061](https://github.com/Lightning-AI/lightning/pull/19061))
 - Fixed the input validation logic in `FSDPStrategy` to accept a `device_mesh` ([#19392](https://github.com/Lightning-AI/lightning/pull/19392))

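The new `test_configure_model` test added to tests/tests_pytorch/plugins/precision/test_half.py (included below) exercises exactly this path. A condensed sketch of the now-fixed behavior, assuming a CPU-only barebones run:

import torch
from lightning.pytorch import LightningModule, Trainer


class MyModel(LightningModule):
    def configure_model(self):
        # created after `Strategy.connect`, so the old eager conversion in `connect` never saw this layer
        self.l = torch.nn.Linear(1, 3)

    def test_step(self, *_):
        ...


model = MyModel()
trainer = Trainer(barebones=True, precision="bf16-true")
trainer.test(model, [0])
# with the delayed conversion, the layer created in `configure_model` is converted as well
assert model.l.weight.dtype == torch.bfloat16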

src/lightning/pytorch/strategies/ddp.py

Lines changed: 8 additions & 11 deletions

@@ -158,34 +158,31 @@ def setup(self, trainer: "pl.Trainer") -> None:
         assert self.accelerator is not None
         self.accelerator.setup(trainer)

-        # move the model to the correct device
-        self.model_to_device()
-
-        # skip wrapping the model if we are not fitting as no gradients need to be exchanged
         trainer_fn = trainer.state.fn
-
+        assert self.model is not None
         if trainer_fn == TrainerFn.FITTING and self._layer_sync:
-            assert self.model is not None
             self.model = self._layer_sync.apply(self.model)

-        self.setup_precision_plugin()
+        self.precision_plugin.convert_module(self.model)
+        self.model_to_device()

         if trainer_fn == TrainerFn.FITTING:
             # do not wrap with DDP if not fitting as there's no gradients to reduce
             self.configure_ddp()

             # set up optimizers after the wrapped module has been moved to the device
             self.setup_optimizers(trainer)
+        else:
+            # we need to manually synchronize the module's states since we aren't using the DDP wrapper
+            _sync_module_states(self.model)
+        self.setup_precision_plugin()
+        if trainer_fn == TrainerFn.FITTING:
             _optimizers_to_device(self.optimizers, self.root_device)

             import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD

             if isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState):
                 self._enable_model_averaging()
-        else:
-            # we need to manually synchronize the module's states since we aren't using the DDP wrapper
-            assert self.model is not None
-            _sync_module_states(self.model)

     @override
     def _setup_model(self, model: Module) -> DistributedDataParallel:

src/lightning/pytorch/strategies/deepspeed.py

Lines changed: 14 additions & 4 deletions

@@ -335,12 +335,21 @@ def setup(self, trainer: "pl.Trainer") -> None:
         self._init_config_if_needed()
         assert self.accelerator is not None
         self.accelerator.setup(trainer)
+
         # we set the device so that optimizers can be created with distributed comms.
         assert self.lightning_module is not None
         self.lightning_module._device = self.root_device
-        self.setup_optimizers(trainer)
+
+        assert self.model is not None
+        self.model = self.precision_plugin.convert_module(self.model)
+        self.model = self._setup_model(self.model)
+
+        if trainer.state.fn == TrainerFn.FITTING:
+            self.setup_optimizers(trainer)
         self.setup_precision_plugin()
-        _optimizers_to_device(self.optimizers, self.root_device)
+        if trainer.state.fn == TrainerFn.FITTING:
+            _optimizers_to_device(self.optimizers, self.root_device)
+
         self.init_deepspeed()
         self.barrier()

@@ -579,15 +588,16 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None:
             trainer: the Trainer, these optimizers should be connected to

         """
-        if trainer.state.fn != TrainerFn.FITTING:
-            return
         # Skip initializing optimizers here as DeepSpeed handles optimizers via config.
         # User may have specified config options instead in configure_optimizers, but this is handled
         # via `_initialize_deepspeed_train`
         # empty optimizers, schedulers
         self.optimizers = []
         self.lr_scheduler_configs = []

+    def _setup_model(self, model: Module) -> Module:  # type: ignore[override]
+        return model
+
     @property
     @override
     def handles_gradient_accumulation(self) -> bool:

src/lightning/pytorch/strategies/fsdp.py

Lines changed: 8 additions & 4 deletions

@@ -314,16 +314,18 @@ def _setup_model(self, model: Module) -> Module:
     @override
     def setup(self, trainer: "pl.Trainer") -> None:
         assert self.accelerator is not None
-        assert self.model is not None
         self.accelerator.setup(trainer)

+        assert self.model is not None
         if trainer.state.fn == TrainerFn.FITTING and self._layer_sync:
             self.model = self._layer_sync.apply(self.model)

         # we set the device so that optimizers can be created with distributed comms.
         assert self.lightning_module is not None
         self.lightning_module._device = self.root_device

+        self.model = self.precision_plugin.convert_module(self.model)
+
         if is_overridden("configure_sharded_model", self.lightning_module):
             # legacy: we don't skip setup with the `configure_model` alternative
             rank_zero_info(

@@ -334,10 +336,11 @@ def setup(self, trainer: "pl.Trainer") -> None:
         self.model = self._setup_model(self.model)
         self.barrier()

-        self.setup_optimizers(trainer)
-        _optimizers_to_device(self.optimizers, self.root_device)
-
+        if trainer.state.fn == TrainerFn.FITTING:
+            self.setup_optimizers(trainer)
         self.setup_precision_plugin()
+        if trainer.state.fn == TrainerFn.FITTING:
+            _optimizers_to_device(self.optimizers, self.root_device)

     @override
     def setup_optimizers(self, trainer: "pl.Trainer") -> None:

@@ -370,6 +373,7 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None:

     @override
     def model_to_device(self) -> None:
+        # FSDP takes care of moving the model to device
         pass

     @contextmanager

src/lightning/pytorch/strategies/single_device.py

Lines changed: 0 additions & 5 deletions

@@ -78,11 +78,6 @@ def model_to_device(self) -> None:
         assert self.model is not None, "self.model must be set before self.model.to()"
         self.model.to(self.root_device)

-    @override
-    def setup(self, trainer: pl.Trainer) -> None:
-        self.model_to_device()
-        super().setup(trainer)
-
     @property
     @override
     def is_global_zero(self) -> bool:

src/lightning/pytorch/strategies/single_xla.py

Lines changed: 18 additions & 4 deletions

@@ -21,10 +21,12 @@
 from lightning.fabric.accelerators.xla import _XLA_AVAILABLE
 from lightning.fabric.plugins import XLACheckpointIO
 from lightning.fabric.strategies import _StrategyRegistry
+from lightning.fabric.utilities.optimizer import _optimizers_to_device
 from lightning.fabric.utilities.types import _DEVICE
 from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO
 from lightning.pytorch.plugins.precision.xla import XLAPrecision
 from lightning.pytorch.strategies.single_device import SingleDeviceStrategy
+from lightning.pytorch.trainer.states import TrainerFn
 from lightning.pytorch.utilities import find_shared_parameters, set_shared_parameters

@@ -88,14 +90,26 @@ def precision_plugin(self, precision_plugin: Optional[XLAPrecision]) -> None:

     @override
     def setup(self, trainer: "pl.Trainer") -> None:
-        assert self.model, "self.model must be set before find_shared_parameters(self.model)"
+        if self.debug:
+            os.environ["PT_XLA_DEBUG"] = str(1)
+
+        assert self.accelerator is not None
+        self.accelerator.setup(trainer)
+
+        assert self.model is not None
+        self.precision_plugin.convert_module(self.model)
+
         shared_params = find_shared_parameters(self.model)
         self.model_to_device()
         set_shared_parameters(self.model, shared_params)
-        super().setup(trainer)

-        if self.debug:
-            os.environ["PT_XLA_DEBUG"] = str(1)
+        self.model = self._setup_model(self.model)
+
+        if trainer.state.fn == TrainerFn.FITTING:
+            self.setup_optimizers(trainer)
+        self.setup_precision_plugin()
+        if trainer.state.fn == TrainerFn.FITTING:
+            _optimizers_to_device(self.optimizers, self.root_device)

     @classmethod
     @override

src/lightning/pytorch/strategies/strategy.py

Lines changed: 15 additions & 6 deletions

@@ -14,7 +14,7 @@
 import logging
 from abc import ABC, abstractmethod
 from contextlib import contextmanager, nullcontext
-from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple, TypeVar, Union, cast
+from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple, TypeVar, Union

 import torch
 from torch import Tensor

@@ -110,7 +110,8 @@ def optimizers(self, optimizers: List[Optimizer]) -> None:

     def connect(self, model: "pl.LightningModule") -> None:
         """Called by the Trainer to connect the strategy with the model."""
-        model = cast(pl.LightningModule, self.precision_plugin.convert_module(model))
+        # model conversions cannot be applied at this point because `LightningModule.{setup,configure_model}` haven't
+        # run yet
         self._lightning_module = model
         self.model = model

@@ -134,8 +135,6 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None:
             trainer: the Trainer, these optimizers should be connected to

         """
-        if trainer.state.fn != TrainerFn.FITTING:
-            return
         assert self.lightning_module is not None
         self.optimizers, self.lr_scheduler_configs = _init_optimizers_and_lr_schedulers(self.lightning_module)

@@ -148,9 +147,19 @@ def setup(self, trainer: "pl.Trainer") -> None:
         """
         assert self.accelerator is not None
         self.accelerator.setup(trainer)
-        self.setup_optimizers(trainer)
+
+        assert self.model is not None
+        # let the precision plugin convert the module here so that this strategy hook can decide the order
+        # of operations
+        self.model = self.precision_plugin.convert_module(self.model)
+        self.model_to_device()
+        self.model = self._setup_model(self.model)
+
+        if trainer.state.fn == TrainerFn.FITTING:
+            self.setup_optimizers(trainer)
         self.setup_precision_plugin()
-        _optimizers_to_device(self.optimizers, self.root_device)
+        if trainer.state.fn == TrainerFn.FITTING:
+            _optimizers_to_device(self.optimizers, self.root_device)

     def setup_precision_plugin(self) -> None:
         """Attaches the precision plugin to the strategy."""

src/lightning/pytorch/strategies/xla.py

Lines changed: 9 additions & 5 deletions

@@ -137,18 +137,20 @@ def _configure_launcher(self) -> None:

     @override
     def setup(self, trainer: "pl.Trainer") -> None:
-        assert self.accelerator
+        assert self.accelerator is not None
         self.accelerator.setup(trainer)

         if self.debug:
             os.environ["PT_XLA_DEBUG"] = "1"

-        assert self.lightning_module
-        shared_params = find_shared_parameters(self.lightning_module)
+        assert self.model is not None
+        self.precision_plugin.convert_module(self.model)
+
+        shared_params = find_shared_parameters(self.model)
         self.model_to_device()
+        set_shared_parameters(self.model, shared_params)

-        set_shared_parameters(self.lightning_module, shared_params)
-        self.setup_precision_plugin()
+        self.model = self._setup_model(self.model)

         if self._sync_module_states:
             if _XLA_GREATER_EQUAL_2_1:

@@ -160,6 +162,8 @@ def setup(self, trainer: "pl.Trainer") -> None:

         if trainer.state.fn == TrainerFn.FITTING:
             self.setup_optimizers(trainer)
+        self.setup_precision_plugin()
+        if trainer.state.fn == TrainerFn.FITTING:
             _optimizers_to_device(self.optimizers, self.root_device)

     @override
tests/tests_pytorch/plugins/precision/test_bitsandbytes.py

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
+# Copyright The Lightning AI team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+import sys
+from unittest.mock import Mock
+
+import lightning.fabric
+import pytest
+import torch
+import torch.distributed
+from lightning.fabric.plugins.precision.bitsandbytes import _BITSANDBYTES_AVAILABLE
+from lightning.pytorch import LightningModule, Trainer
+from lightning.pytorch.plugins.precision.bitsandbytes import BitsandbytesPrecision
+
+
+@pytest.mark.skipif(_BITSANDBYTES_AVAILABLE, reason="bitsandbytes needs to be unavailable")
+def test_bitsandbytes_plugin(monkeypatch):
+    module = lightning.fabric.plugins.precision.bitsandbytes
+    monkeypatch.setattr(module, "_BITSANDBYTES_AVAILABLE", lambda: True)
+    bitsandbytes_mock = Mock()
+    monkeypatch.setitem(sys.modules, "bitsandbytes", bitsandbytes_mock)
+
+    class ModuleMock(torch.nn.Linear):
+        def __init__(self, in_features, out_features, bias=True, *_, **__):
+            super().__init__(in_features, out_features, bias)
+
+    bitsandbytes_mock.nn.Linear8bitLt = ModuleMock
+    bitsandbytes_mock.nn.Linear4bit = ModuleMock
+    bitsandbytes_mock.nn.Params4bit = object
+
+    precision = BitsandbytesPrecision("nf4", dtype=torch.float16)
+    trainer = Trainer(barebones=True, plugins=precision)
+
+    _NF4Linear = vars(module)["_NF4Linear"]
+    quantize_mock = lambda self, p, w, d: p
+    _NF4Linear.quantize = quantize_mock
+
+    class MyModel(LightningModule):
+        def configure_model(self):
+            self.l = torch.nn.Linear(1, 3)
+
+        def test_step(self, *_):
+            ...
+
+    model = MyModel()
+    trainer.test(model, [0])
+    assert isinstance(model.l, _NF4Linear)

tests/tests_pytorch/plugins/precision/test_half.py

Lines changed: 24 additions & 0 deletions

@@ -14,6 +14,7 @@

 import pytest
 import torch
+from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.plugins import HalfPrecision

@@ -73,3 +74,26 @@ def test_convert_module(precision, expected_dtype):
     assert module.weight.dtype == module.bias.dtype == torch.float32
     module = precision.convert_module(module)
     assert module.weight.dtype == module.bias.dtype == expected_dtype
+
+
+@pytest.mark.parametrize(
+    ("precision", "expected_dtype"),
+    [
+        ("bf16-true", torch.bfloat16),
+        ("16-true", torch.half),
+    ],
+)
+def test_configure_model(precision, expected_dtype):
+    class MyModel(LightningModule):
+        def configure_model(self):
+            self.l = torch.nn.Linear(1, 3)
+            # this is under the `module_init_context`
+            assert self.l.weight.dtype == expected_dtype
+
+        def test_step(self, *_):
+            ...
+
+    model = MyModel()
+    trainer = Trainer(barebones=True, precision=precision)
+    trainer.test(model, [0])
+    assert model.l.weight.dtype == expected_dtype
