16 | 16 | import torch.nn as nn |
17 | 17 | from torch.optim import Optimizer |
18 | 18 |
| 19 | +<<<<<<< HEAD |
19 | 20 | from colossalai.nn import (ZeroRedundancyOptimizer_Level_2, |
20 | 21 | ZeroRedundancyOptimizer_Level_3) |
21 | 22 | from colossalai.nn.optimizer._utils import clip_grad_norm_fp32 |
22 | 23 | from ._base_schedule import BaseSchedule |
23 | 24 | from ._utils import convert_to_fp16, convert_to_fp32 |
| 25 | +======= |
| 26 | +from colossalai.context import ParallelMode |
| 27 | +from colossalai.core import global_context as gpc |
| 28 | +from colossalai.nn import (ZeroRedundancyOptimizer_Level_2, |
| 29 | + ZeroRedundancyOptimizer_Level_3) |
| 30 | +from colossalai.nn.optimizer._utils import clip_grad_norm_fp32 |
| 31 | +from ._utils import convert_to_fp16 |
| 32 | +from ._base_schedule import BaseSchedule |
| 33 | +>>>>>>> c8cb9f9... fix FP16 optimizer and adapted torch amp with tensor parallel (#18) |
24 | 34 | from ..amp import AMP_TYPE, GradScaler |
25 | 35 |
26 | 36 |
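Note: the hunk above still contains unresolved conflict markers. The HEAD side keeps `convert_to_fp32`, while the incoming side (c8cb9f9) adds the `ParallelMode` and `gpc` imports that the new `optimizer_step` code below relies on. A plausible resolution, assuming both sets of symbols remain in use elsewhere in the module (this is an illustrative sketch, not the committed result), would merge the two sides:

    # Illustrative conflict resolution (an assumption, not the committed code):
    # keep the ParallelMode/gpc imports required by the new optimizer_step logic
    # and retain convert_to_fp32 from HEAD in case other schedule code still uses it.
    from colossalai.context import ParallelMode
    from colossalai.core import global_context as gpc
    from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
                               ZeroRedundancyOptimizer_Level_3)
    from colossalai.nn.optimizer._utils import clip_grad_norm_fp32
    from ._base_schedule import BaseSchedule
    from ._utils import convert_to_fp16, convert_to_fp32
    from ..amp import AMP_TYPE, GradScaler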
@@ -191,10 +201,14 @@ def forward_backward_step(self, |
191 | 201 | def optimizer_step(self, model: nn.Module, optimizer: Optimizer, grad_clipping: float = 0.0): |
192 | 202 | # step optimizer |
193 | 203 | if self.fp16 and self.amp_type == AMP_TYPE.TORCH: |
194 | | - if grad_clipping > 0.0: |
195 | | - self._torch_amp_scaler.unscale_(optimizer) |
196 | | - clip_grad_norm_fp32(model.parameters(), grad_clipping) |
197 | | - self._torch_amp_scaler.step(optimizer) |
| 204 | + if getattr(gpc.config, 'clip_grad', 0.0) > 0.0: |
| 205 | + self._torch_amp_scaler.unscale_(self.optimizer) |
| 206 | + clip_grad_norm_fp32(self.model.parameters(), |
| 207 | + gpc.config.clip_grad) |
| 208 | + self._torch_amp_scaler.step(self.optimizer) |
198 | 209 | self._torch_amp_scaler.update() |
199 | 210 | else: |
| 211 | + if not self.fp16 and not self.use_zero_level_2_3 and getattr(gpc.config, 'clip_grad', 0.0) > 0.0: |
| 212 | + clip_grad_norm_fp32(self.model.parameters(), |
| 213 | + gpc.config.clip_grad) |
200 | 214 | self.optimizer.step() |
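The hunk above follows the standard torch AMP gradient-clipping order: gradients must be unscaled via the scaler before clipping, and the scaler's `step`/`update` then drive the optimizer. Below is a minimal self-contained sketch of that pattern using plain PyTorch APIs (`torch.cuda.amp.GradScaler` and `torch.nn.utils.clip_grad_norm_`) in place of ColossalAI's `clip_grad_norm_fp32` and `gpc.config.clip_grad`; the model, optimizer, and `clip_grad` value are placeholders for illustration only.

    # Minimal sketch of the unscale -> clip -> step -> update pattern shown in the diff,
    # using plain PyTorch instead of ColossalAI internals. Requires a CUDA device.
    import torch
    import torch.nn as nn

    model = nn.Linear(10, 10).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler()
    clip_grad = 1.0  # placeholder, analogous to gpc.config.clip_grad in the diff

    for _ in range(2):
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            loss = model(torch.randn(4, 10, device='cuda')).sum()
        scaler.scale(loss).backward()       # gradients are scaled here
        if clip_grad > 0.0:
            scaler.unscale_(optimizer)      # unscale before clipping
            nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
        scaler.step(optimizer)              # skips the step if gradients overflowed
        scaler.update()                     # adjust the loss scale for the next iteration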