4 changes: 2 additions & 2 deletions colossalai/__init__.py
@@ -1,4 +1,4 @@
-from .initialize import init_dist, initialize
-from .nn import *
+from .initialize import (initialize, launch, launch_from_openmpi,
+                         launch_from_slurm, launch_from_torch, get_default_parser)
 
 __version__ = '0.0.1'
32 changes: 32 additions & 0 deletions colossalai/amp/__init__.py
@@ -0,0 +1,32 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from .amp_type import AMP_TYPE
from colossalai.context import Config
import torch.nn as nn
from torch.optim import Optimizer
from torch.nn.modules.loss import _Loss
from .torch_amp import convert_to_torch_amp
from .apex_amp import convert_to_apex_amp
from .naive_amp import convert_to_naive_amp


def convert_to_amp(model: nn.Module,
                   optimizer: Optimizer,
                   criterion: _Loss,
                   mode: AMP_TYPE,
                   amp_config: Config = None):
    assert isinstance(mode, AMP_TYPE), \
        f'expected the argument mode be AMP_TYPE, but got {type(mode)}'

    if amp_config is None:
        amp_config = Config()

    if mode == AMP_TYPE.TORCH:
        model, optimizer, criterion = convert_to_torch_amp(model, optimizer, criterion, amp_config)
    elif mode == AMP_TYPE.APEX:
        model, optimizer = convert_to_apex_amp(model, optimizer, amp_config)
    elif mode == AMP_TYPE.NAIVE:
        model, optimizer = convert_to_naive_amp(model, optimizer, amp_config)

    return model, optimizer, criterion
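
For orientation, a minimal usage sketch of the new entry point (not part of the diff); the model, optimizer, criterion, and the Config contents below are illustrative assumptions, not taken from this PR.

# Hypothetical usage sketch; objects and Config keys are placeholders.
import torch.nn as nn
from torch.optim import SGD
from colossalai.amp import convert_to_amp, AMP_TYPE
from colossalai.context import Config

model = nn.Linear(16, 4).cuda()
optimizer = SGD(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# TORCH mode wraps all three objects; APEX and NAIVE modes leave the
# criterion untouched, mirroring the branches in convert_to_amp above.
model, optimizer, criterion = convert_to_amp(
    model, optimizer, criterion,
    mode=AMP_TYPE.TORCH,
    amp_config=Config({'init_scale': 2 ** 16}))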
colossalai/amp/amp_type.py
@@ -7,4 +7,4 @@
 class AMP_TYPE(Enum):
     APEX = 'apex'
     TORCH = 'torch'
-    PARALLEL = 'parallel'
+    NAIVE = 'naive'
15 changes: 15 additions & 0 deletions colossalai/amp/apex_amp/__init__.py
@@ -0,0 +1,15 @@
from .apex_amp import ApexAMPOptimizer
import torch.nn as nn
from torch.optim import Optimizer
import apex.amp as apex_amp


def convert_to_apex_amp(model: nn.Module,
                        optimizer: Optimizer,
                        amp_config):
    model, optimizer = apex_amp.initialize(model, optimizer, **amp_config)
    optimizer = ApexAMPOptimizer(optimizer)
    return model, optimizer


__all__ = ['convert_to_apex_amp', 'ApexAMPOptimizer']
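
A rough usage sketch (assumed, not from the PR): amp_config is unpacked straight into apex.amp.initialize, so standard Apex options such as opt_level pass through unchanged; model, optimizer, and loss are placeholders, and step() is assumed to be delegated to the wrapped optimizer by ColossalaiOptimizer.

# Hypothetical example; 'O1' is an ordinary Apex optimization level.
model, optimizer = convert_to_apex_amp(model, optimizer,
                                       amp_config={'opt_level': 'O1'})
optimizer.backward(loss)                       # apex_amp.scale_loss under the hood
optimizer.clip_grad_norm(model, max_norm=1.0)  # clips Apex master params
optimizer.step()                               # assumed delegation to the inner optim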
23 changes: 23 additions & 0 deletions colossalai/amp/apex_amp/apex_amp.py
@@ -0,0 +1,23 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch.nn as nn
try:
    import apex.amp as apex_amp
except:
    pass
from torch import Tensor

from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils import clip_grad_norm_fp32


class ApexAMPOptimizer(ColossalaiOptimizer):

    def backward(self, loss: Tensor):
        with apex_amp.scale_loss(loss, self.optim) as scaled_loss:
            scaled_loss.backward()

    def clip_grad_norm(self, model: nn.Module, max_norm: float):
        if max_norm > 0:
            clip_grad_norm_fp32(apex_amp.master_params(self.optim), max_norm)
20 changes: 20 additions & 0 deletions colossalai/amp/naive_amp/__init__.py
@@ -0,0 +1,20 @@
import torch.nn as nn
from torch.optim import Optimizer
from colossalai.utils import is_no_pp_or_last_stage

from .naive_amp import NaiveAMPOptimizer, NaiveAMPModel


def convert_to_naive_amp(model: nn.Module,
                         optimizer: Optimizer,
                         amp_config):
    if is_no_pp_or_last_stage():
        model = NaiveAMPModel(model, output_to_fp32=True)
    else:
        model = NaiveAMPModel(model, output_to_fp32=False)

    optimizer = NaiveAMPOptimizer(optimizer, **amp_config)
    return model, optimizer


__all__ = ['convert_to_naive_amp', 'NaiveAMPOptimizer']
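
As a sketch of how the naive mode might be driven (assumptions: amp_config keys are forwarded to FP16Optimizer, and clip_grad is one of its parameters, as its config log further down suggests; model, optimizer, criterion, data, and target are placeholders):

# Hypothetical usage of the naive AMP wrappers.
model, optimizer = convert_to_naive_amp(model, optimizer,
                                        amp_config={'clip_grad': 1.0})

output = model(data)                 # inputs cast to fp16 by NaiveAMPModel
loss = criterion(output, target)     # output already back in fp32 on the last stage
optimizer.backward(loss)             # scales the loss via FP16Optimizer.scale_loss
optimizer.step()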
colossalai/amp/naive_amp/_fp16_optimizer.py
@@ -12,11 +12,9 @@
 
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.logging import get_global_dist_logger
-from colossalai.registry import OPTIMIZER_WRAPPERS
-from colossalai.utils import print_rank_0
-from ._utils import copy_tensor_parallel_attributes, clip_grad_norm_fp32, count_zeros_fp32
-from ..multi_tensor_apply import multi_tensor_applier
+from colossalai.logging import get_dist_logger
+from colossalai.utils import (print_rank_0, copy_tensor_parallel_attributes,
+                              clip_grad_norm_fp32, count_zeros_fp32, multi_tensor_applier)
 
 
 def _zero_grad_group_helper(group, set_to_none):
@@ -92,7 +90,7 @@ def __init__(self,
         self._growth_tracker = 0
         self._hysteresis_tracker = self.hysteresis
 
-        self._logger = get_global_dist_logger()
+        self._logger = get_dist_logger()
 
     @property
     def scale(self):
@@ -145,7 +143,6 @@ def load_state_dict(self, state_dict):
         self._max_scale = state_dict['max_scale']
 
 
-@OPTIMIZER_WRAPPERS.register_module
 class FP16Optimizer(Optimizer):
     """Float16 optimizer for fp16 and bf16 data types.
 
@@ -184,13 +181,13 @@ def __init__(self,
                  max_scale: int = 2 ** 32):
         # default args for compatibility
         bf16 = False
-        params_have_main_grad = False
+        params_have_main_grad = True
 
         # have a defaults for compatibility with pytorch optim
         self.defaults = optimizer.defaults
 
         # log config
-        self._logger = get_global_dist_logger()
+        self._logger = get_dist_logger()
         self._logger.info(f"\n========= FP16 Optimizer Config =========\n"
                           f"Optimizer: {optimizer.__class__.__name__}\n"
                           f"clip_grad = {clip_grad}\n"
@@ -328,6 +325,7 @@ def _copy_model_grads_to_main_grads(self):
                 else:
                     if model_param.grad is not None:
                         main_param.grad = model_param.grad.float()
+
         # For fp32 grads, we need to reset the grads to main grad.
         if self.params_have_main_grad:
             for model_group in self.fp32_from_fp32_groups:
@@ -387,10 +385,6 @@ def reload_model_params(self):
 
     @torch.no_grad()
     def step(self):
-        # for param_group in self.float16_groups:
-        #     for param in param_group:
-        #         print(param.grad is None)
-
         # Copy gradients from model params to main params.
         self._copy_model_grads_to_main_grads()
 
65 changes: 65 additions & 0 deletions colossalai/amp/naive_amp/naive_amp.py
@@ -0,0 +1,65 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch
import torch.nn as nn
from torch import Tensor
from typing import Union, List, Any, Dict
from torch.optim import Optimizer
import torch.cuda.amp as torch_amp

from colossalai.nn.optimizer import ColossalaiOptimizer
from ._fp16_optimizer import FP16Optimizer


class NaiveAMPOptimizer(ColossalaiOptimizer):

    def __init__(self, optim: Optimizer, *args, **kwargs):
        optim = FP16Optimizer(optimizer=optim, *args, **kwargs)
        super().__init__(optim)

    def backward(self, loss: Tensor):
        loss = self.optim.scale_loss(loss)
        loss.backward()

    def step(self):
        self.optim.step()

    def clip_grad_norm(self, model: nn.Module, max_norm: float):
        pass


class NaiveAMPModel(nn.Module):

    def __init__(self,
                 model: nn.Module,
                 output_to_fp32: bool = True):
        super().__init__()
        self.model = model.half()
        self._output_to_fp32 = output_to_fp32

    def _convert_to_fp16(self, input_: Any):
        if isinstance(input_, Tensor) and input_.dtype == torch.float32:
            input_ = input_.half()
        return input_

    def _convert_to_fp32(self, input_: Any):
        if isinstance(input_, Tensor) and input_.dtype == torch.float16:
            input_ = input_.float()
        return input_

    def forward(self, *args, **kwargs):
        if args:
            args = [self._convert_to_fp16(arg) for arg in args]
        if kwargs:
            for k, v in kwargs.items():
                kwargs[k] = self._convert_to_fp16(v)

        out = self.model(*args, **kwargs)

        if self._output_to_fp32:
            if isinstance(out, Tensor):
                out = self._convert_to_fp32(out)
            elif isinstance(out, (tuple, list)):
                out = [self._convert_to_fp32(val) for val in out]
        return out
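
A small illustration of the wrapper's casting behaviour (a sketch only, assuming a CUDA device is available; the Linear layer is arbitrary):

# fp32 tensor inputs are cast to fp16 before the wrapped forward pass,
# and fp16 outputs are cast back to fp32 when output_to_fp32=True.
import torch

wrapped = NaiveAMPModel(torch.nn.Linear(8, 8).cuda(), output_to_fp32=True)
x = torch.randn(2, 8, device='cuda')     # fp32 input
y = wrapped(x)                           # forward runs in half precision
assert y.dtype == torch.float32          # restored to fp32 on the way out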
18 changes: 18 additions & 0 deletions colossalai/amp/torch_amp/__init__.py
@@ -0,0 +1,18 @@
import torch.nn as nn
from torch.optim import Optimizer
from torch.nn.modules.loss import _Loss
from colossalai.context import Config
from .torch_amp import TorchAMPOptimizer, TorchAMPModel, TorchAMPLoss


def convert_to_torch_amp(model: nn.Module,
                         optimizer: Optimizer,
                         criterion: _Loss,
                         amp_config: Config):
    model = TorchAMPModel(model)
    optimizer = TorchAMPOptimizer(optimizer, **amp_config)
    criterion = TorchAMPLoss(criterion)
    return model, optimizer, criterion


__all__ = ['convert_to_torch_amp', 'TorchAMPModel', 'TorchAMPLoss', 'TorchAMPOptimizer']
colossalai/amp/torch_amp/_grad_scaler.py
@@ -1,4 +1,8 @@
-# modified from https://github.com/pytorch/pytorch/blob/master/torch/cuda/amp/grad_scaler.p
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+# modified from https://github.com/pytorch/pytorch/blob/master/torch/cuda/amp/grad_scaler.py
+# to support tensor parallel
 
 import torch
 from collections import defaultdict, abc
 import warnings
54 changes: 54 additions & 0 deletions colossalai/amp/torch_amp/torch_amp.py
@@ -0,0 +1,54 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch.nn as nn
import torch.cuda.amp as torch_amp

from torch import Tensor
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from ._grad_scaler import GradScaler

from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils import clip_grad_norm_fp32


class TorchAMPOptimizer(ColossalaiOptimizer):

    def __init__(self, optim: Optimizer, *args, **kwargs):
        super().__init__(optim)
        self.scaler = GradScaler(*args, **kwargs)

    def backward(self, loss: Tensor):
        self.scaler.scale(loss).backward()

    def step(self):
        self.scaler.step(self.optim)
        self.scaler.update()

    def clip_grad_norm(self, model: nn.Module, max_norm: float):
        if max_norm > 0.0:
            self.scaler.unscale_(self.optim)
            clip_grad_norm_fp32(model.parameters(), max_norm)


class TorchAMPModel(nn.Module):

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        self.model = model

    @torch_amp.autocast()
    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)


class TorchAMPLoss(nn.Module):

    def __init__(self, loss: _Loss):
        super().__init__()
        self.loss = loss

    @torch_amp.autocast()
    def forward(self, *args, **kwargs):
        return self.loss(*args, **kwargs)
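
Putting the three wrappers together, a typical training step might look like the following sketch (model, criterion, data, and target are placeholders; zero_grad is assumed to be forwarded to the inner optimizer by ColossalaiOptimizer):

# Hypothetical training step with the torch-amp wrappers above.
output = model(data)                            # TorchAMPModel: forward under autocast
loss = criterion(output, target)                # TorchAMPLoss: loss under autocast
optimizer.zero_grad()                           # assumed to delegate to the wrapped optim
optimizer.backward(loss)                        # scales the loss, then backward
optimizer.clip_grad_norm(model, max_norm=1.0)   # unscales before clipping
optimizer.step()                                # scaler.step + scaler.update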
10 changes: 5 additions & 5 deletions colossalai/builder/__init__.py
@@ -1,10 +1,10 @@
-from .builder import (build_schedule, build_lr_scheduler, build_model, build_optimizer, build_optimizer_wrapper,
-                      build_layer, build_loss, build_hooks, build_dataset, build_transform, build_data_sampler,
+from .builder import (build_schedule, build_lr_scheduler, build_model, build_optimizer, build_layer,
+                      build_loss, build_hooks, build_dataset, build_transform, build_data_sampler,
                       build_gradient_handler)
-from .pipeline import ModelInitializer
+from .pipeline import PipelineModelInitializer
 
 __all__ = [
-    'build_schedule', 'build_lr_scheduler', 'build_model', 'build_optimizer', 'build_optimizer_wrapper',
+    'build_schedule', 'build_lr_scheduler', 'build_model', 'build_optimizer',
     'build_layer', 'build_loss', 'build_hooks', 'build_dataset', 'build_transform', 'build_data_sampler',
-    'build_gradient_handler', 'ModelInitializer'
+    'build_gradient_handler', 'PipelineModelInitializer'
 ]