Commit efdc968

update api for better usability
1 parent 07f3e98 commit efdc968

194 files changed, +3280 −8008 lines


colossalai/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -1,4 +1,4 @@
-from .initialize import init_dist, initialize
-from .nn import *
+from .initialize import (init_dist, initialize, init_dist_from_slurm,
+                         init_dist_from_openmpi, init_dist_from_torch, get_default_parser)
 
 __version__ = '0.0.1'
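
The top-level package now re-exports the launcher-specific entry points and the default argument parser. A minimal usage sketch follows; the exact signatures of the init_dist_from_* helpers are not shown in this commit, and the idea that get_default_parser() returns a standard argparse parser is an inference from its name, not something this diff confirms.

# Hedged sketch of the newly exported launcher API. Assumption: get_default_parser()
# returns an argparse.ArgumentParser pre-populated with common launch arguments;
# the init_dist_from_* signatures are not shown in this diff.
import colossalai

parser = colossalai.get_default_parser()
args = parser.parse_args()   # required launch arguments are supplied on the command line

# Pick the entry point that matches the launcher in use, e.g. one of:
#   colossalai.init_dist_from_torch(...)    # torch.distributed.launch / torchrun
#   colossalai.init_dist_from_slurm(...)    # SLURM
#   colossalai.init_dist_from_openmpi(...)  # OpenMPI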

colossalai/amp/__init__.py

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+from .amp_type import AMP_TYPE
+from colossalai.context import Config
+import torch.nn as nn
+from torch.optim import Optimizer
+from torch.nn.modules.loss import _Loss
+from .torch_amp import convert_to_torch_amp
+from .apex_amp import convert_to_apex_amp
+from .naive_amp import convert_to_naive_amp
+
+
+def convert_to_amp(model: nn.Module,
+                   optimizer: Optimizer,
+                   criterion: _Loss,
+                   mode: AMP_TYPE,
+                   amp_config: Config = None):
+    assert isinstance(mode, AMP_TYPE), \
+        f'expected the argument mode be AMP_TYPE, but got {type(mode)}'
+
+    if amp_config is None:
+        amp_config = Config()
+
+    if mode == AMP_TYPE.TORCH:
+        model, optimizer, criterion = convert_to_torch_amp(model, optimizer, criterion, amp_config)
+    elif mode == AMP_TYPE.APEX:
+        model, optimizer = convert_to_apex_amp(model, optimizer, amp_config)
+    elif mode == AMP_TYPE.NAIVE:
+        model, optimizer = convert_to_naive_amp(model, optimizer, amp_config)
+
+    return model, optimizer, criterion
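
convert_to_amp is the single entry point that dispatches to the three backends above. A minimal, self-contained sketch of how it might be called; the toy model, optimizer and criterion are placeholders, and omitting amp_config relies on the Config() fallback shown in the function.

import torch
import torch.nn as nn

from colossalai.amp import convert_to_amp, AMP_TYPE

# Placeholder model/optimizer/criterion, not part of the commit.
model = nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# mode=AMP_TYPE.TORCH dispatches to convert_to_torch_amp; omitting amp_config
# falls back to an empty Config() inside convert_to_amp.
model, optimizer, criterion = convert_to_amp(model, optimizer, criterion,
                                             mode=AMP_TYPE.TORCH)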

colossalai/engine/amp/amp_type.py renamed to colossalai/amp/amp_type.py

Lines changed: 1 addition & 1 deletion
@@ -7,4 +7,4 @@
 class AMP_TYPE(Enum):
     APEX = 'apex'
     TORCH = 'torch'
-    PARALLEL = 'parallel'
+    NAIVE = 'naive'
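
The former AMP_TYPE.PARALLEL mode is now called AMP_TYPE.NAIVE, so any config that referenced the old name needs updating. A quick illustration using only the values visible in the enum above:

from colossalai.amp import AMP_TYPE

# The three modes available after this commit; AMP_TYPE.PARALLEL no longer exists.
assert AMP_TYPE.APEX.value == 'apex'
assert AMP_TYPE.TORCH.value == 'torch'
assert AMP_TYPE.NAIVE.value == 'naive'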

colossalai/amp/apex_amp/__init__.py

Lines changed: 15 additions & 0 deletions

@@ -0,0 +1,15 @@
+from .apex_amp import ApexAMPOptimizer
+import torch.nn as nn
+from torch.optim import Optimizer
+import apex.amp as apex_amp
+
+
+def convert_to_apex_amp(model: nn.Module,
+                        optimizer: Optimizer,
+                        amp_config):
+    model, optimizer = apex_amp.initialize(model, optimizer, **amp_config)
+    optimizer = ApexAMPOptimizer(optimizer)
+    return model, optimizer
+
+
+__all__ = ['convert_to_apex_amp', 'ApexAMPOptimizer']
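
convert_to_apex_amp simply unpacks amp_config into apex.amp.initialize and wraps the resulting optimizer, so the config keys are Apex's own keyword arguments. A sketch, assuming Apex and a CUDA device are available and that O1 mixed precision is the desired opt_level:

import torch
import torch.nn as nn

from colossalai.amp.apex_amp import convert_to_apex_amp

model = nn.Linear(16, 4).cuda()          # Apex AMP expects a CUDA model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# The dict is forwarded as apex.amp.initialize(model, optimizer, opt_level='O1').
model, optimizer = convert_to_apex_amp(model, optimizer, {'opt_level': 'O1'})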

colossalai/amp/apex_amp/apex_amp.py

Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+import torch.nn as nn
+try:
+    import apex.amp as apex_amp
+except:
+    pass
+from torch import Tensor
+
+from colossalai.nn.optimizer import ColossalaiOptimizer
+from colossalai.utils import clip_grad_norm_fp32
+
+
+class ApexAMPOptimizer(ColossalaiOptimizer):
+
+    def backward(self, loss: Tensor):
+        with apex_amp.scale_loss(loss, self.optim) as scaled_loss:
+            scaled_loss.backward()
+
+    def clip_grad_norm(self, model: nn.Module, max_norm: float):
+        if max_norm > 0:
+            clip_grad_norm_fp32(apex_amp.master_params(self.optim), max_norm)
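
ApexAMPOptimizer replaces a plain loss.backward() with Apex's scaled-loss context and clips Apex's master parameters rather than the model's fp16 ones. A hedged end-to-end sketch of a training step; it assumes Apex and a GPU are available, that ColossalaiOptimizer forwards zero_grad() and step() to the wrapped optimizer, and it uses placeholder data.

import torch
import torch.nn as nn

from colossalai.amp.apex_amp import convert_to_apex_amp

model = nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer = convert_to_apex_amp(model, optimizer, {'opt_level': 'O1'})

data = torch.randn(8, 16, device='cuda')          # placeholder batch
target = torch.randint(0, 4, (8,), device='cuda')

loss = nn.CrossEntropyLoss()(model(data), target)
optimizer.zero_grad()
optimizer.backward(loss)                          # scales the loss via apex_amp.scale_loss
optimizer.clip_grad_norm(model, max_norm=1.0)     # no-op when max_norm == 0
optimizer.step()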

colossalai/amp/naive_amp/__init__.py

Lines changed: 20 additions & 0 deletions

@@ -0,0 +1,20 @@
+import torch.nn as nn
+from torch.optim import Optimizer
+from colossalai.utils import is_no_pp_or_last_stage
+
+from .naive_amp import NaiveAMPOptimizer, NaiveAMPModel
+
+
+def convert_to_naive_amp(model: nn.Module,
+                         optimizer: Optimizer,
+                         amp_config):
+    if is_no_pp_or_last_stage():
+        model = NaiveAMPModel(model, output_to_fp32=True)
+    else:
+        model = NaiveAMPModel(model, output_to_fp32=False)
+
+    optimizer = NaiveAMPOptimizer(optimizer, **amp_config)
+    return model, optimizer
+
+
+__all__ = ['convert_to_naive_amp', 'NaiveAMPOptimizer']
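
convert_to_naive_amp casts the model to fp16 and wraps the optimizer in the relocated FP16Optimizer (see the _fp16_optimizer.py diff that follows); only the last pipeline stage, or a run without pipeline parallelism, converts outputs back to fp32. A sketch, assuming the distributed context has already been initialized (is_no_pp_or_last_stage queries it) and that 'clip_grad' is a valid FP16Optimizer keyword, which is inferred from the config log in _fp16_optimizer.py rather than stated by this diff:

import torch
import torch.nn as nn

from colossalai.amp.naive_amp import convert_to_naive_amp

# Placeholder model/optimizer; colossalai.init_dist / initialize is assumed to
# have set up the parallel context before this point.
model = nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# amp_config keys are forwarded to FP16Optimizer; 'clip_grad' is an assumption
# based on the config log message in _fp16_optimizer.py.
model, optimizer = convert_to_naive_amp(model, optimizer, {'clip_grad': 1.0})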

colossalai/nn/optimizer/fp16_optimizer.py renamed to colossalai/amp/naive_amp/_fp16_optimizer.py

Lines changed: 7 additions & 13 deletions
@@ -12,11 +12,9 @@
 
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.logging import get_global_dist_logger
-from colossalai.registry import OPTIMIZER_WRAPPERS
-from colossalai.utils import print_rank_0
-from ._utils import copy_tensor_parallel_attributes, clip_grad_norm_fp32, count_zeros_fp32
-from ..multi_tensor_apply import multi_tensor_applier
+from colossalai.logging import get_dist_logger
+from colossalai.utils import (print_rank_0, copy_tensor_parallel_attributes,
+                              clip_grad_norm_fp32, count_zeros_fp32, multi_tensor_applier)
 
 
 def _zero_grad_group_helper(group, set_to_none):

@@ -92,7 +90,7 @@ def __init__(self,
         self._growth_tracker = 0
         self._hysteresis_tracker = self.hysteresis
 
-        self._logger = get_global_dist_logger()
+        self._logger = get_dist_logger()
 
     @property
     def scale(self):

@@ -145,7 +143,6 @@ def load_state_dict(self, state_dict):
         self._max_scale = state_dict['max_scale']
 
 
-@OPTIMIZER_WRAPPERS.register_module
 class FP16Optimizer(Optimizer):
     """Float16 optimizer for fp16 and bf16 data types.

@@ -184,13 +181,13 @@ def __init__(self,
                 max_scale: int = 2 ** 32):
        # default args for compatibility
        bf16 = False
-       params_have_main_grad = False
+       params_have_main_grad = True
 
        # have a defaults for compatibility with pytorch optim
        self.defaults = optimizer.defaults
 
        # log config
-       self._logger = get_global_dist_logger()
+       self._logger = get_dist_logger()
        self._logger.info(f"\n========= FP16 Optimizer Config =========\n"
                          f"Optimizer: {optimizer.__class__.__name__}\n"
                          f"clip_grad = {clip_grad}\n"

@@ -328,6 +325,7 @@ def _copy_model_grads_to_main_grads(self):
             else:
                 if model_param.grad is not None:
                     main_param.grad = model_param.grad.float()
+
        # For fp32 grads, we need to reset the grads to main grad.
        if self.params_have_main_grad:
            for model_group in self.fp32_from_fp32_groups:

@@ -387,10 +385,6 @@ def reload_model_params(self):
 
    @torch.no_grad()
    def step(self):
-        # for param_group in self.float16_groups:
-        #     for param in param_group:
-        #         print(param.grad is None)
-
        # Copy gradients from model params to main params.
        self._copy_model_grads_to_main_grads()
 
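
Beyond the move into colossalai/amp/naive_amp, the substantive changes are that FP16Optimizer no longer registers itself with OPTIMIZER_WRAPPERS (it is now wrapped by NaiveAMPOptimizer instead), params_have_main_grad is hard-coded to True, and the helper utilities are imported from colossalai.utils rather than relative modules. The consolidated import path, taken directly from the diff above:

# New import location after this commit (previously these came from the
# optimizer package's private _utils module and the multi_tensor_apply package).
from colossalai.utils import (print_rank_0, copy_tensor_parallel_attributes,
                              clip_grad_norm_fp32, count_zeros_fp32, multi_tensor_applier)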

colossalai/amp/naive_amp/naive_amp.py

Lines changed: 65 additions & 0 deletions

@@ -0,0 +1,65 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+from typing import Union, List, Any, Dict
+from torch.optim import Optimizer
+import torch.cuda.amp as torch_amp
+
+from colossalai.nn.optimizer import ColossalaiOptimizer
+from ._fp16_optimizer import FP16Optimizer
+
+
+class NaiveAMPOptimizer(ColossalaiOptimizer):
+
+    def __init__(self, optim: Optimizer, *args, **kwargs):
+        optim = FP16Optimizer(optimizer=optim, *args, **kwargs)
+        super().__init__(optim)
+
+    def backward(self, loss: Tensor):
+        loss = self.optim.scale_loss(loss)
+        loss.backward()
+
+    def step(self):
+        self.optim.step()
+
+    def clip_grad_norm(self, model: nn.Module, max_norm: float):
+        pass
+
+
+class NaiveAMPModel(nn.Module):
+
+    def __init__(self,
+                 model: nn.Module,
+                 output_to_fp32: bool = True):
+        super().__init__()
+        self.model = model.half()
+        self._output_to_fp32 = output_to_fp32
+
+    def _convert_to_fp16(self, input_: Any):
+        if isinstance(input_, Tensor) and input_.dtype == torch.float32:
+            input_ = input_.half()
+        return input_
+
+    def _convert_to_fp32(self, input_: Any):
+        if isinstance(input_, Tensor) and input_.dtype == torch.float16:
+            input_ = input_.float()
+        return input_
+
+    def forward(self, *args, **kwargs):
+        if args:
+            args = [self._convert_to_fp16(arg) for arg in args]
+        if kwargs:
+            for k, v in kwargs.items():
+                kwargs[k] = self._convert_to_fp16(v)
+
+        out = self.model(*args, **kwargs)
+
+        if self._output_to_fp32:
+            if isinstance(out, Tensor):
+                out = self._convert_to_fp32(out)
+            elif isinstance(out, (tuple, list)):
+                out = [self._convert_to_fp32(val) for val in out]
+        return out
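
NaiveAMPModel halves the wrapped module once at construction time, casts fp32 tensor inputs to fp16 on every forward, and optionally casts tensor outputs back to fp32. A small behavior sketch, assuming a CUDA device is available (fp16 matmuls may not be supported on CPU by older PyTorch builds):

import torch
import torch.nn as nn

from colossalai.amp.naive_amp import NaiveAMPModel

wrapped = NaiveAMPModel(nn.Linear(8, 2), output_to_fp32=True).cuda()

x = torch.randn(4, 8, device='cuda')    # fp32 input
y = wrapped(x)                          # cast to fp16 internally, model runs in fp16
assert y.dtype == torch.float32         # output cast back to fp32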

colossalai/amp/torch_amp/__init__.py

Lines changed: 18 additions & 0 deletions

@@ -0,0 +1,18 @@
+import torch.nn as nn
+from torch.optim import Optimizer
+from torch.nn.modules.loss import _Loss
+from colossalai.context import Config
+from .torch_amp import TorchAMPOptimizer, TorchAMPModel, TorchAMPLoss
+
+
+def convert_to_torch_amp(model: nn.Module,
+                         optimizer: Optimizer,
+                         criterion: _Loss,
+                         amp_config: Config):
+    model = TorchAMPModel(model)
+    optimizer = TorchAMPOptimizer(optimizer, **amp_config)
+    criterion = TorchAMPLoss(criterion)
+    return model, optimizer, criterion
+
+
+__all__ = ['convert_to_torch_amp', 'TorchAMPModel', 'TorchAMPLoss', 'TorchAMPOptimizer']
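
The torch backend wraps all three training objects: the model, the criterion, and the optimizer, presumably around torch.cuda.amp autocast and GradScaler (the torch_amp.py module itself is not shown in this excerpt). A sketch that relies on the default settings by passing an empty Config, whose keys would otherwise be unpacked into TorchAMPOptimizer:

import torch
import torch.nn as nn

from colossalai.amp.torch_amp import convert_to_torch_amp
from colossalai.context import Config

model = nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# An empty Config means TorchAMPOptimizer is built with its default settings;
# the specific config keys it accepts are not shown in this excerpt.
model, optimizer, criterion = convert_to_torch_amp(model, optimizer, criterion, Config())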

colossalai/engine/amp/grad_scaler.py renamed to colossalai/amp/torch_amp/_grad_scaler.py

Lines changed: 5 additions & 1 deletion
@@ -1,4 +1,8 @@
-# modified from https://github.com/pytorch/pytorch/blob/master/torch/cuda/amp/grad_scaler.p
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+# modified from https://github.com/pytorch/pytorch/blob/master/torch/cuda/amp/grad_scaler.py
+# to support tensor parallel
+
 import torch
 from collections import defaultdict, abc
 import warnings
