-
Notifications
You must be signed in to change notification settings - Fork 3.4k
/
fsdp.py
177 lines (148 loc) · 7.61 KB
/
fsdp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, Optional
import torch
from lightning_utilities import apply_to_collection
from torch import Tensor
from typing_extensions import get_args, override
import lightning.pytorch as pl
from lightning.fabric.plugins.precision.amp import _optimizer_handles_unscaling
from lightning.fabric.plugins.precision.fsdp import _PRECISION_INPUT
from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager
from lightning.fabric.utilities.types import Optimizable
from lightning.pytorch.plugins.precision.precision import Precision
from lightning.pytorch.utilities.exceptions import MisconfigurationException
if TYPE_CHECKING:
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
class FSDPPrecision(Precision):
"""Precision plugin for training with Fully Sharded Data Parallel (FSDP).
.. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.
Args:
precision: Full precision (32-true), half precision (16-true, bf16-true) or
mixed precision (16-mixed, bf16-mixed).
scaler: An optional :class:`torch.distributed.fsdp.sharded_grad_scaler.ShardedGradScaler` to use.
Raises:
ValueError:
If unsupported ``precision`` is provided.
"""
def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None) -> None:
supported_precision = get_args(_PRECISION_INPUT)
if precision not in supported_precision:
raise ValueError(
f"`precision={precision!r})` is not supported in FSDP."
f" `precision` must be one of: {supported_precision}."
)
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
if scaler is not None and self.precision != "16-mixed":
raise ValueError(f"`precision={precision!r}` does not use a scaler, found {scaler}.")
self.scaler = ShardedGradScaler() if scaler is None and precision == "16-mixed" else None
self.precision = precision
precision_to_type = {
"bf16-mixed": torch.float32,
"16-mixed": torch.float32,
"bf16-true": torch.bfloat16,
"16-true": torch.float16,
"32-true": torch.float32,
}
self._desired_input_dtype = precision_to_type[self.precision]
@override
def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
# see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_
# section `Gradient Clipping`, using `torch.nn.utils.clip_grad_norm_` is incorrect with FSDP.
# To overcome this we need to call root_sharded_module.clip_grad_norm(clip_val), but we don't have a reference
# to the root module
raise MisconfigurationException(
f"`gradient_clip_algorithm='norm'` is currently not supported for `{self.__class__.__name__}`"
)
@property
def mixed_precision_config(self) -> "TorchMixedPrecision":
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision
if self.precision == "16-mixed":
param_dtype = torch.float32
reduce_dtype = buffer_dtype = torch.float16
elif self.precision == "bf16-mixed":
param_dtype = torch.float32
reduce_dtype = buffer_dtype = torch.bfloat16
elif self.precision == "16-true":
param_dtype = reduce_dtype = buffer_dtype = torch.float16
elif self.precision == "bf16-true":
param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
elif self.precision == "32-true":
param_dtype = torch.float32
reduce_dtype = buffer_dtype = torch.float32
else:
raise MisconfigurationException(f"Was unable to infer precision type, received {self.precision!r}.")
return TorchMixedPrecision(
param_dtype=param_dtype,
reduce_dtype=reduce_dtype,
buffer_dtype=buffer_dtype,
)
@override
def tensor_init_context(self) -> ContextManager:
return _DtypeContextManager(self._desired_input_dtype)
@override
def module_init_context(self) -> ContextManager:
return _DtypeContextManager(self.mixed_precision_config.param_dtype or torch.float32)
@override
def forward_context(self) -> ContextManager:
if "mixed" in self.precision:
return torch.autocast("cuda", dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16))
return _DtypeContextManager(self._desired_input_dtype)
@override
def convert_input(self, data: Any) -> Any:
return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_input_dtype)
@override
def convert_output(self, data: Any) -> Any:
return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())
@override
def pre_backward(self, tensor: Tensor, module: "pl.LightningModule") -> Tensor: # type: ignore[override]
if self.scaler is not None:
tensor = self.scaler.scale(tensor)
return super().pre_backward(tensor, module)
@override
def optimizer_step( # type: ignore[override]
self,
optimizer: Optimizable,
model: "pl.LightningModule",
closure: Callable[[], Any],
**kwargs: Any,
) -> Any:
if self.scaler is None:
# skip scaler logic, as bfloat16 does not require scaler
return super().optimizer_step(optimizer, model=model, closure=closure, **kwargs)
closure_result = closure()
if not _optimizer_handles_unscaling(optimizer):
# Unscaling needs to be performed here in case we are going to apply gradient clipping.
# Optimizers that perform unscaling in their `.step()` method are not supported (e.g., fused Adam).
# Note: `unscale` happens after the closure is executed, but before the `on_before_optimizer_step` hook.
self.scaler.unscale_(optimizer) # type: ignore[arg-type]
self._after_closure(model, optimizer)
skipped_backward = closure_result is None
# in manual optimization, the closure does not return a value
if not model.automatic_optimization or not skipped_backward:
# note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
step_output = self.scaler.step(optimizer, **kwargs) # type: ignore[arg-type]
self.scaler.update()
return step_output
return closure_result
@override
def state_dict(self) -> Dict[str, Any]:
if self.scaler is not None:
return self.scaler.state_dict()
return {}
@override
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
if self.scaler is not None:
self.scaler.load_state_dict(state_dict)