2 changes: 1 addition & 1 deletion colossalai/engine/__init__.py
@@ -1,7 +1,7 @@
from .amp_type import AMP_TYPE
from ._base_engine import Engine
from .gradient_handler import *
from .schedule import *
from .amp import *


__all__ = ['Engine']
2 changes: 2 additions & 0 deletions colossalai/engine/amp/__init__.py
@@ -0,0 +1,2 @@
from .grad_scaler import GradScaler
from .amp_type import AMP_TYPE
File renamed without changes.
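Note on import paths: with the new colossalai/engine/amp package above (and amp_type.py moved under it), AMP_TYPE and GradScaler are now imported from the package rather than from the removed colossalai.engine.amp_type module. A one-line sketch matching the __init__ exports shown above:

from colossalai.engine.amp import AMP_TYPE, GradScaler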
577 changes: 577 additions & 0 deletions colossalai/engine/amp/grad_scaler.py

Large diffs are not rendered by default.
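The 577-line grad_scaler.py is not rendered here. From the way _no_pipeline.py uses it below (constructed from the amp config, then unscale_, step, update), it appears to follow the torch.cuda.amp.GradScaler interface, so here is a minimal, self-contained sketch of that loss-scaling loop using the stock PyTorch scaler; the model, optimizer, and data are illustrative placeholders, not part of this PR.

import torch
from torch.cuda.amp import GradScaler, autocast

model = torch.nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler()

data = torch.randn(8, 16, device='cuda')
target = torch.randint(0, 4, (8,), device='cuda')

optimizer.zero_grad()
with autocast():                                     # forward pass in mixed precision
    loss = torch.nn.functional.cross_entropy(model(data), target)
scaler.scale(loss).backward()                        # scale the loss so fp16 gradients do not underflow
scaler.unscale_(optimizer)                           # unscale in place before any clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)                               # skips the update if inf/NaN gradients were found
scaler.update()                                      # adjusts the loss scale for the next iteration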

13 changes: 11 additions & 2 deletions colossalai/engine/schedule/_no_pipeline.py
@@ -12,11 +12,12 @@

from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.engine.amp_type import AMP_TYPE
from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
                           ZeroRedundancyOptimizer_Level_3)
from colossalai.nn.optimizer._utils import clip_grad_norm_fp32
from ._utils import convert_to_fp16
from ._base_schedule import BaseSchedule
from ..amp import AMP_TYPE, GradScaler


class NoPipelineSchedule(BaseSchedule):
@@ -30,6 +31,7 @@ class NoPipelineSchedule(BaseSchedule):
    :type amp_type: AMP_TYPE
    :type amp_config: dict
    """

    def __init__(
            self,
            amp_type: AMP_TYPE = None,
@@ -102,7 +104,7 @@ def initialize(self,

        if self.fp16:
            if self.amp_type == AMP_TYPE.TORCH:
                self._torch_amp_scaler = torch_amp.GradScaler(**self.amp_cfg)
                self._torch_amp_scaler = GradScaler(**self.amp_cfg)
            elif self.amp_type == AMP_TYPE.APEX:
                self.model, self.optimizer = apex_amp.initialize(
                    self.model, self.optimizer, **self.amp_cfg)
@@ -177,7 +179,14 @@ def forward_backward_step(self, forward_only=False, return_loss=True):
    def step(self):
        # step optimizer
        if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
            if getattr(gpc.config, 'clip_grad', 0.0) > 0.0:
                self._torch_amp_scaler.unscale_(self.optimizer)
                clip_grad_norm_fp32(self.model.parameters(),
                                    gpc.config.clip_grad)
            self._torch_amp_scaler.step(self.optimizer)
            self._torch_amp_scaler.update()
        else:
            if not self.fp16 and not self.use_zero_level_2_3 and getattr(gpc.config, 'clip_grad', 0.0) > 0.0:
                clip_grad_norm_fp32(self.model.parameters(),
                                    gpc.config.clip_grad)
            self.optimizer.step()
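The new step() clips only when the clip_grad config entry, read via getattr, is positive; in the non-torch-AMP branch it clips only for plain fp32 training, so fp16 under APEX and the ZeRO level 2/3 optimizers are not clipped here. In the torch-AMP branch the clip sits between unscale_ and scaler.step, so the norm is computed on the true, unscaled gradients. A hypothetical standalone sketch of that decision; the function name and arguments are illustrative, not part of the PR:

def should_clip(config, fp16: bool, uses_torch_amp: bool, uses_zero_level_2_3: bool) -> bool:
    # Mirrors the branching in NoPipelineSchedule.step() (illustrative only).
    max_norm = getattr(config, 'clip_grad', 0.0)
    if max_norm <= 0.0:
        return False                       # clipping disabled unless clip_grad > 0
    if fp16 and uses_torch_amp:
        return True                        # clip after GradScaler.unscale_()
    # otherwise only plain fp32 training clips here
    return not fp16 and not uses_zero_level_2_3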
7 changes: 3 additions & 4 deletions colossalai/engine/schedule/_pipeline.py
@@ -15,7 +15,7 @@
from colossalai.utils import get_current_device
from ._base_schedule import BaseSchedule
from ._utils import convert_to_fp16
from ..amp_type import AMP_TYPE
from ..amp import AMP_TYPE


def squeeze(x: Union[Tensor, tuple, list]):
@@ -163,8 +162,7 @@ def forward_step(self, input_tensor, return_tensors, return_loss=True):
        if gpc.is_last_rank(ParallelMode.PIPELINE):
            if return_loss:
                input_tensor, label = self.load_micro_batch()
                loss_reduced = self.criterion(output_tensor, *
                                              label) / (self.num_microbatches * self.grad_accum)
                loss_reduced = self.criterion(output_tensor, *label) / (self.num_microbatches * self.grad_accum)
                return_tensors.append(
                    tuple((output_tensor, label[0], loss_reduced)))
                return loss_reduced
@@ -201,7 +200,7 @@ def backward_step(self, input_tensor, output_tensor, output_tensor_grad):
    def forward_backward_step(self, forward_only=True, return_loss=True):
        """Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
        Returns a tuple with losses if the last stage, an empty tuple otherwise.

        :return: (output, label, loss)
        """
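For context on the reflowed loss line above: each microbatch loss is divided by num_microbatches * grad_accum so that summing the gradients of every backward pass reproduces the full-batch average. A small numeric sketch with illustrative values:

num_microbatches, grad_accum = 4, 2
micro_losses = [0.9, 1.1, 1.0, 1.2, 0.8, 1.3, 1.0, 0.9]       # one loss per backward pass
scaled = [loss / (num_microbatches * grad_accum) for loss in micro_losses]
# summing the scaled losses (what accumulation effectively does to the gradients)
# equals averaging the raw losses over the whole batch
assert abs(sum(scaled) - sum(micro_losses) / len(micro_losses)) < 1e-12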
2 changes: 1 addition & 1 deletion colossalai/nn/optimizer/_utils.py
@@ -106,7 +106,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
        tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type)
        no_tensor_parallel_grads = _calc_lp(
            no_tensor_parallel_grads, norm_type)
        if gpc.is_initialized(ParallelMode.TENSOR) and len(tensor_parallel_grads) > 0:
            # Sum across all model-parallel GPUs.
            torch.distributed.all_reduce(tensor_parallel_norm,
                                         op=torch.distributed.ReduceOp.SUM,
5 changes: 5 additions & 0 deletions docs/colossalai/colossalai.engine.amp.amp_type.rst
@@ -0,0 +1,5 @@
colossalai.engine.amp.amp\_type
===============================

.. automodule:: colossalai.engine.amp.amp_type
   :members:
5 changes: 5 additions & 0 deletions docs/colossalai/colossalai.engine.amp.grad_scaler.rst
@@ -0,0 +1,5 @@
colossalai.engine.amp.grad\_scaler
==================================

.. automodule:: colossalai.engine.amp.grad_scaler
   :members:
12 changes: 12 additions & 0 deletions docs/colossalai/colossalai.engine.amp.rst
@@ -0,0 +1,12 @@
colossalai.engine.amp
=====================

.. automodule:: colossalai.engine.amp
   :members:


.. toctree::
   :maxdepth: 2

   colossalai.engine.amp.amp_type
   colossalai.engine.amp.grad_scaler
5 changes: 0 additions & 5 deletions docs/colossalai/colossalai.engine.amp_type.rst

This file was deleted.

7 changes: 1 addition & 6 deletions docs/colossalai/colossalai.engine.rst
@@ -7,11 +7,6 @@ colossalai.engine
.. toctree::
   :maxdepth: 2

   colossalai.engine.amp
   colossalai.engine.gradient_handler
   colossalai.engine.schedule


.. toctree::
   :maxdepth: 2

   colossalai.engine.amp_type