Commit fa1e21b

fix FP16 optimizer and adapted torch amp with tensor parallel
1 parent 0aa07e6 commit fa1e21b

12 files changed: +618 -18 lines changed


colossalai/engine/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
-from .amp_type import AMP_TYPE
 from ._base_engine import Engine
 from .gradient_handler import *
 from .schedule import *
+from .amp import *


 __all__ = ['Engine']

colossalai/engine/amp/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -0,0 +1,2 @@
+from .grad_scaler import GradScaler
+from .amp_type import AMP_TYPE
File renamed without changes.
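
With the new colossalai.engine.amp package, AMP_TYPE and GradScaler are imported from a single location instead of colossalai.engine.amp_type. A minimal usage sketch follows; the amp_cfg values are illustrative assumptions, since the schedule simply forwards whatever amp_config dict the user supplies:

# Sketch of the import path introduced by this commit.
from colossalai.engine.amp import AMP_TYPE, GradScaler

# Example values only; NoPipelineSchedule passes the user's amp_config
# straight through as GradScaler(**amp_cfg).
amp_cfg = {'init_scale': 2.**16, 'growth_interval': 2000}

amp_type = AMP_TYPE.TORCH        # select torch-style mixed precision
scaler = GradScaler(**amp_cfg)   # built inside NoPipelineSchedule.initialize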

colossalai/engine/amp/grad_scaler.py

Lines changed: 577 additions & 0 deletions
Large diffs are not rendered by default.
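
The new grad_scaler.py (577 added lines) is not rendered in this view. Judging from the commit title, it presumably adapts torch's AMP GradScaler to tensor parallelism; one known pitfall there is that the stock scaler's inf/NaN check is per-rank, so tensor-parallel ranks can disagree on whether to skip a step. The sketch below illustrates that idea only; the function name, group handle, and reduction are assumptions, not the actual file's contents:

# Hypothetical helper: make the overflow flag agree across tensor-parallel ranks.
# This is NOT the real colossalai/engine/amp/grad_scaler.py (577 added lines,
# not shown in this diff); it only illustrates the idea.
import torch
import torch.distributed as dist

def sync_found_inf(found_inf: torch.Tensor, tp_group=None) -> torch.Tensor:
    """All-reduce the per-rank overflow flag so every tensor-parallel rank
    makes the same step/skip decision."""
    if dist.is_initialized() and dist.get_world_size(tp_group) > 1:
        dist.all_reduce(found_inf, op=dist.ReduceOp.MAX, group=tp_group)
    return found_inf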

colossalai/engine/schedule/_no_pipeline.py

Lines changed: 11 additions & 2 deletions

@@ -12,11 +12,12 @@

 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.engine.amp_type import AMP_TYPE
 from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
                            ZeroRedundancyOptimizer_Level_3)
+from colossalai.nn.optimizer._utils import clip_grad_norm_fp32
 from ._utils import convert_to_fp16
 from ._base_schedule import BaseSchedule
+from ..amp import AMP_TYPE, GradScaler


 class NoPipelineSchedule(BaseSchedule):
@@ -30,6 +31,7 @@ class NoPipelineSchedule(BaseSchedule):
     :type amp_type: AMP_TYPE
     :type amp_config: dict
     """
+
     def __init__(
             self,
             amp_type: AMP_TYPE = None,
@@ -101,7 +103,7 @@ def initialize(self,

         if self.fp16:
             if self.amp_type == AMP_TYPE.TORCH:
-                self._torch_amp_scaler = torch_amp.GradScaler(**self.amp_cfg)
+                self._torch_amp_scaler = GradScaler(**self.amp_cfg)
             elif self.amp_type == AMP_TYPE.APEX:
                 self.model, self.optimizer = apex_amp.initialize(
                     self.model, self.optimizer, **self.amp_cfg)
@@ -175,9 +177,16 @@ def forward_backward_step(self, forward_only=False, return_loss=True):
     def step(self):
         # step optimizer
         if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
+            if getattr(gpc.config, 'clip_grad', 0.0) > 0.0:
+                self._torch_amp_scaler.unscale_(self.optimizer)
+                clip_grad_norm_fp32(self.model.parameters(),
+                                    gpc.config.clip_grad)
             self._torch_amp_scaler.step(self.optimizer)
             self._torch_amp_scaler.update()
         else:
+            if not self.fp16 and not self.use_zero_level_2_3 and getattr(gpc.config, 'clip_grad', 0.0) > 0.0:
+                clip_grad_norm_fp32(self.model.parameters(),
+                                    gpc.config.clip_grad)
             self.optimizer.step()

         # update lr scheduler
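
The new clipping branch follows PyTorch's unscale-then-clip recipe: gradients are unscaled before the norm is computed, so clip_grad applies to true FP32 gradient values, and scaler.step() can still skip the update if an overflow was found. A self-contained sketch of the same pattern with stock torch.cuda.amp, with torch's own clipper standing in for clip_grad_norm_fp32:

# Standalone sketch of the unscale -> clip -> step pattern used above,
# written against plain torch.cuda.amp rather than ColossalAI's wrappers.
import torch
from torch.cuda.amp import GradScaler, autocast

model = torch.nn.Linear(32, 32).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler()
clip_grad = 1.0  # plays the role of gpc.config.clip_grad

data = torch.randn(8, 32, device='cuda')
with autocast():
    loss = model(data).float().pow(2).mean()

scaler.scale(loss).backward()
if clip_grad > 0.0:
    scaler.unscale_(optimizer)                     # grads back to true FP32 scale
    torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
scaler.step(optimizer)                             # skipped automatically on inf/NaN
scaler.update()
optimizer.zero_grad()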

colossalai/engine/schedule/_pipeline.py

Lines changed: 3 additions & 3 deletions

@@ -15,7 +15,7 @@
 from colossalai.utils import get_current_device
 from ._base_schedule import BaseSchedule
 from ._utils import convert_to_fp16
-from ..amp_type import AMP_TYPE
+from ..amp import AMP_TYPE


 def squeeze(x: Union[Tensor, tuple, list]):
@@ -163,7 +163,7 @@ def forward_step(self, input_tensor, return_tensors, return_loss=True):
         if return_loss:
             input_tensor, label = self.load_micro_batch()
             loss_reduced = self.criterion(output_tensor, *
-                                          label) / self.num_microbatches
+                                          label) / self.num_microbatches
             return_tensors.append(
                 tuple((output_tensor, label[0], loss_reduced)))
         return loss_reduced
@@ -200,7 +200,7 @@ def backward_step(self, input_tensor, output_tensor, output_tensor_grad):
     def forward_backward_step(self, forward_only=True, return_loss=True):
         """Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
         Returns a tuple with losses if the last stage, an empty tuple otherwise.
-
+
         :return: (output, label, loss)
         """


colossalai/nn/optimizer/_utils.py

Lines changed: 1 addition & 1 deletion

@@ -106,7 +106,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
         tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type)
         no_tensor_parallel_grads = _calc_lp(
             no_tensor_parallel_grads, norm_type)
-        if gpc.is_initialized(ParallelMode.TENSOR):
+        if gpc.is_initialized(ParallelMode.TENSOR) and len(tensor_parallel_grads) > 0:
             # Sum across all model-parallel GPUs.
             torch.distributed.all_reduce(tensor_parallel_norm,
                                          op=torch.distributed.ReduceOp.SUM,
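
The added len(tensor_parallel_grads) > 0 check skips the cross-rank reduction when there are no tensor-parallel gradients to reduce. For context, a minimal sketch of how a tensor-parallel-aware gradient norm is typically assembled; the helper and group names are illustrative assumptions, not ColossalAI's internals:

# Illustrative sketch of a tensor-parallel-aware L-p gradient norm.
import torch
import torch.distributed as dist

def global_grad_norm(tp_grads, replicated_grads, tp_group=None, norm_type=2):
    """Norm over gradients sharded across tensor-parallel ranks plus gradients
    replicated on every rank (expects lists with at least one gradient)."""
    device = (tp_grads + replicated_grads)[0].device
    # Contribution of sharded gradients: sum |g|^p locally ...
    tp_norm = torch.zeros(1, device=device)
    for g in tp_grads:
        tp_norm += g.float().abs().pow(norm_type).sum()
    # ... then sum across the tensor-parallel group, but only when this rank
    # actually holds shards, mirroring the guard added in the diff above.
    if dist.is_initialized() and len(tp_grads) > 0:
        dist.all_reduce(tp_norm, op=dist.ReduceOp.SUM, group=tp_group)
    # Replicated gradients are identical on every rank, so count them once locally.
    rep_norm = torch.zeros(1, device=device)
    for g in replicated_grads:
        rep_norm += g.float().abs().pow(norm_type).sum()
    return (tp_norm + rep_norm).pow(1.0 / norm_type).item()
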
Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+colossalai.engine.amp.amp\_type
+===============================
+
+.. automodule:: colossalai.engine.amp.amp_type
+   :members:
Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+colossalai.engine.amp.grad\_scaler
+==================================
+
+.. automodule:: colossalai.engine.amp.grad_scaler
+   :members:
Lines changed: 12 additions & 0 deletions

@@ -0,0 +1,12 @@
+colossalai.engine.amp
+=====================
+
+.. automodule:: colossalai.engine.amp
+   :members:
+
+
+.. toctree::
+   :maxdepth: 2
+
+   colossalai.engine.amp.amp_type
+   colossalai.engine.amp.grad_scaler
