
Commit 7a44322

Merge pull request #21 from 1SAA/hotfix/trainer

Changed API in Schedule, Engine

2 parents af88570 + f58b744
File tree

4 files changed: +197 additions, −288 deletions

colossalai/engine/_base_engine.py

Lines changed: 61 additions & 111 deletions
@@ -12,55 +12,40 @@
 from torch.nn import Module
 from torch.nn.modules.loss import _Loss
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler
-from torch.utils.data import DataLoader
 
 from .schedule import BaseSchedule, NoPipelineSchedule
 
 
 class Engine:
     """Basic engine class for training and evaluation. It runs a specific process method
     :meth:`step` which is based on the given :attr:`schedule` over each batch of a dataset.
+    It controls a iteration in training.
 
-    :param train_dataloader: Dataloader in training
-    :param test_dataloader: Dataloader in evaluation
     :param model: The neural network model
-    :param criterion: Criterion for calculating loss
     :param optimizer: Optimizer for updating the parameters
-    :param lr_scheduler: Learning rate scheduler ajusting learning rate during the training or evaluation
-    :param schedule: Running schedule in :meth:`step`
-    :type train_dataloader: DataLoader, optional
-    :type test_dataloader: DataLoader, optional
+    :param step_schedule: Running schedule in :meth:`step`
+    :param gradient_accumulation: Steps of gradient accumulation
+    :param gradient_clipping: The norm of gradient clipping
     :type model: Module
-    :type criterion: _Loss, optional
-    :type optimizer: Optimizer, optional
-    :type lr_scheduler: _LRScheduler, optional
-    :type schedule: BaseSchedule, optional
+    :type optimizer: Optimizer
+    :type step_schedule: BaseSchedule, optional
+    :type gradient_accumulation: int, optional
+    :type gradient_clipping: float, optional
     """
+
     def __init__(self,
-                 train_dataloader: Optional[DataLoader] = None,
-                 test_dataloader: Optional[DataLoader] = None,
-                 model: Module = None,
-                 criterion: _Loss = None,
-                 optimizer: Optimizer = None,
-                 lr_scheduler: Optional[_LRScheduler] = None,
-                 schedule: BaseSchedule = None,
+                 model: Module,
+                 optimizer: Optimizer,
+                 step_schedule: BaseSchedule = None,
                  gradient_accumulation: int = 1,
-                 lr_scheduler_step: str = 'epoch'):
-        self.train_dataloader = train_dataloader
-        self.test_dataloader = test_dataloader
-        assert model is not None, "Engine requires a model"
-        self.model = model
-        self.criterion = criterion
-        self.optimizer = optimizer
-        self.lr_scheduler = lr_scheduler
-        self.schedule = schedule if schedule is not None \
+                 gradient_clipping: float = 0.0):
+        self.schedule = step_schedule if step_schedule is not None \
             else NoPipelineSchedule()
+        self.schedule.initialize(model, optimizer)
         self.grad_accum_size = gradient_accumulation
-        self.grad_accum_step = 0
-        self.lr_step = 0 # for epoch updating
-        if lr_scheduler_step != 'epoch':
-            self.lr_step = 1
+        self.grad_accum_cur_step = 0
+        self.grad_clip = gradient_clipping
+        self.training = True # default
         self._logger = get_global_dist_logger()
 
         # build gradient handler
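
After this hunk the constructor no longer takes dataloaders, a criterion, or an LR scheduler; it only binds the model, optimizer, schedule, gradient accumulation, and gradient clipping. A minimal sketch of how the new constructor might be called, assuming a ColossalAI context has already been initialized and that Engine is importable from colossalai.engine (both assumptions, not part of this diff; model and optimizer are placeholders):

# Sketch only, not part of this commit.
# Assumes colossalai has been launched/initialized and Engine is exported
# from colossalai.engine; the model and optimizer below are placeholders.
import torch
from torch import nn

model = nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

engine = Engine(model=model,
                optimizer=optimizer,
                step_schedule=None,        # falls back to NoPipelineSchedule()
                gradient_accumulation=4,   # micro-steps accumulated per optimizer update
                gradient_clipping=1.0)     # grad-norm value handed to optimizer_step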
@@ -72,8 +57,8 @@ def __init__(self,
                 f'argument gradient_handler_cfg expected type list, ' \
                 f'but got type {type(gpc.config.gradient_handler)}'
             gradient_handler_cfg = gpc.config.gradient_handler
-        elif isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
-                                         ZeroRedundancyOptimizer_Level_3)):
+        elif isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2,
+                                    ZeroRedundancyOptimizer_Level_3)):
             gradient_handler_cfg = [dict(type='ZeROGradientHandler')]
             self._logger.info(
                 "Training with zero is detected, ZeROGradientHandler is automatically "
@@ -92,106 +77,71 @@ def __init__(self,
                 "to all-reduce the gradients after a training step.",
                 ranks=[0])
         for cfg in gradient_handler_cfg:
-            handler = build_gradient_handler(cfg, self.model, self.optimizer)
+            handler = build_gradient_handler(cfg, model, optimizer)
             self._gradient_handlers.append(handler)
 
-        self.schedule.initialize(self.train_dataloader, self.model,
-                                 self.criterion, self.optimizer)
-        self.schedule.grad_accum = self.grad_accum_size
-        # add for robustness
-        if self.optimizer is None:
-            self.forward_only = True
-        else:
-            self.forward_only = False
-
     def handle_gradient(self):
         """Handles all-reduce operations of gradients across different parallel groups.
         """
         for handler in self._gradient_handlers:
             handler.handle_gradient()
 
-    def set_dataloader(self, data: DataLoader, train: bool = True):
-        """Sets dataloader in training or evaluation.
-
-        :param data: Dataloader to be set
-        :param train: Set training dataloader if True, otherwise evaluation dataloader
-        :type data: DataLoader
-        :type train: bool
-        """
-        if train:
-            self.train_dataloader = data
-        else:
-            self.test_dataloader = data
-
-    def get_model(self):
-        """Returns the neural network model in the engine.
-        """
-        return self.model
-
-    def get_optimizer(self):
-        """Returns optimizier in the engine.
-        """
-        return self.optimizer
-
-    def get_lr_scheduler(self):
-        """Returns the learning rate scheduler in the engine.
-        """
-        return self.lr_scheduler
-
     def train(self):
         """Sets the model to training mode.
         """
-        self.forward_only = False
-        self.schedule.train(dataloader=self.train_dataloader, mode=True)
+        self.training = True
 
     def eval(self):
         """Sets the model to evaluation mode.
         """
-        self.forward_only = True
-        self.schedule.train(dataloader=self.test_dataloader, mode=False)
-
-    def is_train(self):
-        """Returns True if it is in training, otherwise False.
-        """
-        return not self.forward_only
-
-    def get_lr(self):
-        """Gets current learning rate.
-        """
-        if self.lr_scheduler is not None:
-            return self.lr_scheduler.get_lr()[0]
-        else:
-            return self.optimizer.param_groups[0]['lr']
-
-    def step(self, return_loss=True):
+        self.training = False
+
+    def step(self,
+             data_iter,
+             model: Module,
+             criterion: _Loss,
+             optimizer: Optimizer = None,
+             is_last_iteration: bool = False,
+             return_loss=True):
         """A running step based on the schedule. Usually, it runs a training or
         evaluation over a batch of dataset.
 
+        :param data_iter: Data iterator of the dataset
+        :param model: The neural network model
+        :param criterion: Loss function used to calculate
+        :param optimizer: Optimizer for updating the parameters
+        :param is_last_iteration: If True, this iteration is the last iteration in the epoch
         :param return_loss: loss will be returned if True
-        :type return_loss: bool
+        :type data_iter: Iterator
+        :type model: Module
+        :type criterion: _Loss
+        :type optimizer: Optimizer, optional
+        :type is_last_iteration: bool, optional
+        :type return_loss: bool, optional
        :return: (output, lablel, loss)
         """
-        if not self.forward_only and self.grad_accum_step == 0:
-            self.schedule.zero_grad()
+        if self.training and self.grad_accum_cur_step == 0:
+            optimizer.zero_grad()
 
         output, label, loss = self.schedule.forward_backward_step(
-            forward_only=self.forward_only, return_loss=return_loss)
-
-        if not self.forward_only:
-            self.grad_accum_step += 1
-            if self.grad_accum_step == self.grad_accum_size:
+            data_iter, model, criterion, optimizer,
+            forward_only=not self.training,
+            grad_accum_size=self.grad_accum_size,
+            return_loss=return_loss)
+
+        if self.training:
+            self.grad_accum_cur_step += 1
+            if self.grad_accum_cur_step == self.grad_accum_size:
                 # all reduce gradients
                 self.handle_gradient()
-                self.schedule.step()
-                if self.lr_scheduler is not None and self.lr_step:
-                    self.lr_scheduler.step()
-                self.grad_accum_step = 0
+                self.schedule.optimizer_step(model, optimizer, self.grad_clip)
+                self.grad_accum_cur_step = 0
 
-        return output, label, loss
+        if is_last_iteration:
+            while True:
+                try:
+                    trash = next(data_iter)
+                except StopIteration:
+                    break
 
-    def complete(self):
-        """Updating after a epoch.
-        """
-        self.schedule.consume_batch()
-        if self.lr_scheduler is not None and self.lr_step == 0:
-            self.lr_scheduler.step()
+        return output, label, loss
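
With the new step() signature the caller owns the data iterator, model, criterion, and optimizer, and signals the last iteration of an epoch so the iterator can be drained. A minimal training-loop sketch under those assumptions (engine, model, optimizer, and criterion as in the constructor sketch above; train_dataloader is a placeholder):

# Sketch only: objects below are assumed to exist and are not defined by this commit.
engine.train()
num_steps = len(train_dataloader)
data_iter = iter(train_dataloader)

for it in range(num_steps):
    output, label, loss = engine.step(
        data_iter, model, criterion, optimizer,
        is_last_iteration=(it == num_steps - 1),
        return_loss=True)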

Comments (0)