
Commit 678c92d

improved consistency between trainer, engine and schedule
1 parent: 7df468a


57 files changed (+1121, -1106 lines)

README.md

Lines changed: 3 additions & 11 deletions
@@ -42,26 +42,18 @@ pip install -v --no-cache-dir --global-option="--cuda_ext" .
 
 ```python
 import colossalai
-from colossalai.engine import Engine
 from colossalai.trainer import Trainer
 from colossalai.core import global_context as gpc
 
-model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize()
-engine = Engine(
-    model=model,
-    criterion=criterion,
-    optimizer=optimizer,
-    lr_scheduler=lr_scheduler,
-    schedule=schedule
-)
+engine, train_dataloader, test_dataloader = colossalai.initialize()
 
 trainer = Trainer(engine=engine,
-                  hooks_cfg=gpc.config.hooks,
                   verbose=True)
 trainer.fit(
     train_dataloader=train_dataloader,
     test_dataloader=test_dataloader,
-    max_epochs=gpc.config.num_epochs,
+    epochs=gpc.config.num_epochs,
+    hooks_cfg=gpc.config.hooks,
     display_progress=True,
     test_interval=5
 )
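
For quick reference, the README snippet after this change reads as follows. It is assembled from the added lines above; the `gpc.config` fields (`num_epochs`, `hooks`) are whatever the user's configuration file defines.

```python
import colossalai
from colossalai.trainer import Trainer
from colossalai.core import global_context as gpc

# initialize() now builds the engine internally and returns it with the dataloaders
engine, train_dataloader, test_dataloader = colossalai.initialize()

trainer = Trainer(engine=engine,
                  verbose=True)
trainer.fit(
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    epochs=gpc.config.num_epochs,   # renamed from max_epochs
    hooks_cfg=gpc.config.hooks,     # hooks are now passed to fit() rather than Trainer()
    display_progress=True,
    test_interval=5
)
```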

colossalai/builder/__init__.py

Lines changed: 9 additions & 1 deletion
@@ -1,2 +1,10 @@
-from .builder import *
+from .builder import (build_schedule, build_lr_scheduler, build_model, build_optimizer, build_optimizer_wrapper,
+                      build_layer, build_loss, build_hooks, build_dataset, build_transform, build_data_sampler,
+                      build_gradient_handler)
 from .pipeline import ModelInitializer
+
+__all__ = [
+    'build_schedule', 'build_lr_scheduler', 'build_model', 'build_optimizer', 'build_optimizer_wrapper',
+    'build_layer', 'build_loss', 'build_hooks', 'build_dataset', 'build_transform', 'build_data_sampler',
+    'build_gradient_handler', 'ModelInitializer'
+]

colossalai/builder/builder.py

Lines changed: 12 additions & 16 deletions
@@ -181,18 +181,6 @@ def build_transform(config):
     return build_from_registry(config, TRANSFORMS)
 
 
-def build_pipe_alloc_policy(config):
-    """Returns a pipeline allocation policy object constructed from `config`.
-
-    :param config: A python dict or a :class:`colossalai.context.Config` object
-        containing information used in the construction of the return object
-    :type config: dict or :class:`colossalai.context.Config`
-    :return: A pipeline allocation policy object
-    :rtype:
-    """
-    return build_from_registry(config, PIPE_ALLOC_POLICY)
-
-
 def build_data_sampler(config, dataset):
     """Returns a data sampler object of :class:`colossalai.nn.data.sampler.BaseSampler`
     constructed from `config`.
@@ -254,8 +242,16 @@ def build_lr_scheduler(config, optimizer):
     """
     config_ = config.copy()
     mod_type = config_.pop('type')
-    # warmup epochs will overwrite warmup steps
-    # if 'warmup_epochs' in config_:
-    # warmup_epochs = config_.pop('warmup_epochs')
-    # config_['warmup_steps'] = int(num_steps_per_epoch * warmup_epochs)
     return LR_SCHEDULERS.get_module(mod_type)(optimizer, **config_)
+
+
+def build_schedule(config):
+    """Returns a schedule of :class:`colossalai.engine.schedule.BaseSchedule`.
+
+    :param config: A python dict or a :class:`colossalai.context.Config` object
+        containing information used in the construction of the return object
+    :type config: dict or :class:`colossalai.context.Config`
+    :return: An object of :class:`colossalai.engine.schedule.BaseSchedule`
+    :rtype: :class:`colossalai.engine.schedule.BaseSchedule`
+    """
+    return build_from_registry(config, SCHEDULE)
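
A minimal sketch of calling the new `build_schedule` entry point; the `'NoPipelineSchedule'` type name is an assumption for illustration and must match a class registered in the `SCHEDULE` registry.

```python
from colossalai.builder import build_schedule

# hypothetical config entry; the 'type' key selects the registered schedule class
schedule_cfg = dict(type='NoPipelineSchedule')
schedule = build_schedule(schedule_cfg)  # dispatches through build_from_registry(config, SCHEDULE)
```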

colossalai/engine/_base_engine.py

Lines changed: 79 additions & 50 deletions
@@ -1,19 +1,17 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-from typing import Optional
+from torch.nn import Module
+from torch.nn.modules.loss import _Loss
+from torch.optim import Optimizer
 
 from colossalai.builder import build_gradient_handler
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import get_global_dist_logger
 from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
                            ZeroRedundancyOptimizer_Level_3)
-from torch.nn import Module
-from torch.nn.modules.loss import _Loss
-from torch.optim import Optimizer
-
-from .schedule import BaseSchedule, NoPipelineSchedule
+from .schedule import BaseSchedule
 
 
 class Engine:
@@ -36,49 +34,80 @@ class Engine:
     def __init__(self,
                  model: Module,
                  optimizer: Optimizer,
-                 step_schedule: BaseSchedule = None,
+                 criterion: _Loss,
+                 step_schedule: BaseSchedule,
+                 gradient_handlers: list = None,
                  gradient_accumulation: int = 1,
-                 gradient_clipping: float = 0.0):
-        self.schedule = step_schedule if step_schedule is not None \
-            else NoPipelineSchedule()
-        self.schedule.initialize(model, optimizer)
-        self.grad_accum_size = gradient_accumulation
-        self.grad_accum_cur_step = 0
-        self.grad_clip = gradient_clipping
+                 gradient_clipping: float = 0.0,
+                 ):
+        self._model = model
+        self._optimizer = optimizer
+        self._criterion = criterion
+        self._schedule = step_schedule
+
+        # schedule initialize
+        self._schedule.initialize(model, optimizer)
+
+        # state
         self.training = True  # default
+
+        # gradient accumulation
+        assert gradient_accumulation > 0, 'gradient accumulation size must be larger than 0'
+        self._grad_accum_size = gradient_accumulation
+        self._grad_clip = gradient_clipping
         self._logger = get_global_dist_logger()
 
         # build gradient handler
         self._gradient_handlers = []
-        gradient_handler_cfg = []
 
-        if hasattr(gpc.config, 'gradient_handler'):
-            assert isinstance(gpc.config.gradient_handler, list), \
+        if gradient_handlers is not None:
+            assert isinstance(gradient_handlers, list), \
                 f'argument gradient_handler_cfg expected type list, ' \
-                f'but got type {type(gpc.config.gradient_handler)}'
-            gradient_handler_cfg = gpc.config.gradient_handler
+                f'but got type {type(gradient_handlers)}'
         elif isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2,
                                     ZeroRedundancyOptimizer_Level_3)):
-            gradient_handler_cfg = [dict(type='ZeROGradientHandler')]
+            gradient_handlers = [dict(type='ZeROGradientHandler')]
            self._logger.info(
                 "Training with zero is detected, ZeROGradientHandler is automatically "
                 "added even though not specified in the configuration",
                 ranks=[0])
         elif gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(
                 ParallelMode.DATA) > 1:
-            gradient_handler_cfg = [dict(type='DataParallelGradientHandler')]
+            gradient_handlers = [dict(type='DataParallelGradientHandler')]
             self._logger.info(
                 "Data parallel training is detected, DataParallelGradientHandler is automatically "
                 "added even though not specified in the configuration",
                 ranks=[0])
-        if len(gradient_handler_cfg) == 0:
+
+        if gradient_handlers is None:
             self._logger.warning(
                 "No gradient handler is set up, please make sure you do not need "
                 "to all-reduce the gradients after a training step.",
                 ranks=[0])
-        for cfg in gradient_handler_cfg:
-            handler = build_gradient_handler(cfg, model, optimizer)
-            self._gradient_handlers.append(handler)
+        else:
+            for cfg in gradient_handlers:
+                handler = build_gradient_handler(cfg, model, optimizer)
+                self._gradient_handlers.append(handler)
+
+    @property
+    def model(self):
+        return self._model
+
+    @property
+    def optimizer(self):
+        return self._optimizer
+
+    @property
+    def criterion(self):
+        return self._criterion
+
+    @property
+    def schedule(self):
+        return self._schedule
+
+    @property
+    def gradient_accumulation(self):
+        return self._grad_accum_size
 
     def handle_gradient(self):
         """Handles all-reduce operations of gradients across different parallel groups.
@@ -90,57 +119,57 @@ def train(self):
         """Sets the model to training mode.
         """
         self.training = True
+        self._model.train()
 
     def eval(self):
         """Sets the model to evaluation mode.
         """
         self.training = False
+        self._model.eval()
 
     def step(self,
              data_iter,
-             model: Module,
-             criterion: _Loss,
-             optimizer: Optimizer = None,
             is_last_iteration: bool = False,
             return_loss=True):
         """A running step based on the schedule. Usually, it runs a training or
         evaluation over a batch of dataset.
 
         :param data_iter: Data iterator of the dataset
-        :param model: The neural network model
-        :param criterion: Loss function used to calculate
-        :param optimizer: Optimizer for updating the parameters
         :param is_last_iteration: If True, this iteration is the last iteration in the epoch
         :param return_loss: loss will be returned if True
         :type data_iter: Iterator
-        :type model: Module
-        :type criterion: _Loss
-        :type optimizer: Optimizer, optional
         :type is_last_iteration: bool, optional
         :type return_loss: bool, optional
         :return: (output, lablel, loss)
         """
-        if self.training and self.grad_accum_cur_step == 0:
-            optimizer.zero_grad()
-
-        output, label, loss = self.schedule.forward_backward_step(
-            data_iter, model, criterion, optimizer,
-            forward_only=not self.training,
-            grad_accum_size=self.grad_accum_size,
-            return_loss=return_loss)
-
         if self.training:
-            self.grad_accum_cur_step += 1
-            if self.grad_accum_cur_step == self.grad_accum_size:
-                # all reduce gradients
-                self.handle_gradient()
-                self.schedule.optimizer_step(model, optimizer, self.grad_clip)
-                self.grad_accum_cur_step = 0
+            self._optimizer.zero_grad()
 
+        # differentiate training and eval with grad accum
+        if self.training:
+            for i in range(self._grad_accum_size):
+                output, label, loss = self._schedule.forward_backward_step(
+                    data_iter, self._model, self._criterion, self._optimizer,
+                    forward_only=False,
+                    grad_accum_size=self._grad_accum_size,
+                    return_loss=return_loss)
+
+                if i == self._grad_accum_size - 1:
+                    # all reduce gradients
+                    self.handle_gradient()
+                    self._schedule.optimizer_step(self._model, self._optimizer, self._grad_clip)
+        else:
+            output, label, loss = self._schedule.forward_backward_step(
+                data_iter, self._model, self._criterion, self._optimizer,
+                forward_only=True,
+                grad_accum_size=1,
+                return_loss=return_loss)
+
+        # consume the remaining dataset left out due to gradient accumulation
         if is_last_iteration:
             while True:
                 try:
-                    trash = next(data_iter)
+                    _ = next(data_iter)
                 except StopIteration:
                     break
 
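
A sketch of driving the reworked `Engine` directly, based only on the constructor and `step()` signatures shown in this diff. It assumes a distributed context has already been launched (normally via `colossalai.initialize()`), that `NoPipelineSchedule` is importable from `colossalai.engine.schedule`, and that the toy model and dataloader below stand in for real ones.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from colossalai.engine import Engine
from colossalai.engine.schedule import NoPipelineSchedule  # assumed import path

# toy stand-ins for a real model, optimizer, criterion and dataloader
model = nn.Linear(32, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
loader = DataLoader(TensorDataset(torch.randn(64, 32), torch.randint(0, 4, (64,))),
                    batch_size=16)

engine = Engine(model=model,
                optimizer=optimizer,
                criterion=criterion,              # the criterion is now held by the engine
                step_schedule=NoPipelineSchedule(),
                gradient_handlers=None,           # falls back to the auto-detection logic above
                gradient_accumulation=2,          # step() loops over the schedule internally
                gradient_clipping=1.0)

engine.train()                                    # now also switches the model to train mode
data_iter = iter(loader)
output, label, loss = engine.step(data_iter)      # model/criterion/optimizer are no longer passed per step
```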
colossalai/engine/schedule/_no_pipeline.py

Lines changed: 31 additions & 11 deletions
@@ -4,17 +4,23 @@
 try:
     import apex.amp as apex_amp
 except:
-    print('apex is required for mixed precision training')
+    pass
+
 try:
     import torch.cuda.amp as torch_amp
 except:
-    print('PyTorch amp is not supported with the current PyTorch version')
+    pass
+
+from typing import Iterable
+
+import torch.nn as nn
+from torch.optim import Optimizer
 
 from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
                            ZeroRedundancyOptimizer_Level_3)
 from colossalai.nn.optimizer._utils import clip_grad_norm_fp32
-from ._utils import convert_to_fp16, convert_to_fp32
 from ._base_schedule import BaseSchedule
+from ._utils import convert_to_fp16, convert_to_fp32
 from ..amp import AMP_TYPE, GradScaler
 
 
@@ -73,7 +79,7 @@ def __init__(
         self.fp16 = False
         self.amp_type = None
 
-    def initialize(self, model, optimizer):
+    def initialize(self, model: nn.Module, optimizer: Optimizer):
         if isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2,
                                   ZeroRedundancyOptimizer_Level_3)):
             self.use_zero_level_2_3 = True
@@ -89,16 +95,30 @@ def initialize(self, model, optimizer):
         return model, optimizer
 
     def forward_backward_step(self,
-                              data_iter,
-                              model,
-                              criterion,
-                              optimizer=None,
-                              forward_only=False,
+                              data_iter: Iterable,
+                              model: nn.Module,
+                              criterion: nn.modules.loss._Loss,
+                              optimizer: Optimizer = None,
+                              forward_only: bool = False,
                               grad_accum_size: int = 1,
-                              return_loss=True):
+                              return_loss: bool = True):
         """The process function that loads loads a batch of dataset and feeds it to the model.
         The returned labels and loss will None if :attr:`return_loss` is False.
 
+        :param data_iter: Data iterator of the dataloader, e.g. iter(dataloader)
+        :param model: Model for training and inference
+        :param criterion: Loss function for training
+        :param optimizer: Optimizer used for training
+        :param forward_only: If True, the model is run for the forward pass, else back propagation will be executed
+        :param grad_accum_size: The number of iterations for gradient accumulation
+        :param return_loss: Loss will be returned if True
+        :type data_iter: Iterator
+        :type model: torch.nn.Module
+        :type criterion: torch.nn.modules.loss._Loss
+        :type optimizer: torch.optim.Optimizer
+        :type forward_only: bool, optional
+        :type grad_accum_size: int
+        :type return_loss: bool, optional
         :return: (output, label, loss)
         """
         assert forward_only or return_loss, \
@@ -154,7 +174,7 @@ def forward_backward_step(self,
         else:
             return output, None, None
 
-    def optimizer_step(self, model, optimizer, grad_clipping: float = 0.0):
+    def optimizer_step(self, model: nn.Module, optimizer: Optimizer, grad_clipping: float = 0.0):
         # step optimizer
         if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
             if grad_clipping > 0.0:
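
For completeness, a sketch of calling the schedule on its own for an evaluation pass, following the `forward_backward_step` signature documented above; `schedule`, `model`, `criterion` and `test_dataloader` are assumed to exist and the schedule to have been initialized already (the `Engine` constructor calls `schedule.initialize()`).

```python
output, label, loss = schedule.forward_backward_step(
    iter(test_dataloader),    # data_iter, e.g. iter(dataloader) as the docstring notes
    model,
    criterion,
    optimizer=None,           # optional when no backward pass is run
    forward_only=True,        # forward pass only, no back-propagation
    grad_accum_size=1,
    return_loss=True)
```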
