14 changes: 3 additions & 11 deletions README.md
@@ -42,26 +42,18 @@ pip install -v --no-cache-dir --global-option="--cuda_ext" .

```python
import colossalai
from colossalai.engine import Engine
from colossalai.trainer import Trainer
from colossalai.core import global_context as gpc

model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize()
engine = Engine(
model=model,
criterion=criterion,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
schedule=schedule
)
engine, train_dataloader, test_dataloader = colossalai.initialize()

trainer = Trainer(engine=engine,
hooks_cfg=gpc.config.hooks,
verbose=True)
trainer.fit(
train_dataloader=train_dataloader,
test_dataloader=test_dataloader,
max_epochs=gpc.config.num_epochs,
epochs=gpc.config.num_epochs,
hooks_cfg=gpc.config.hooks,
display_progress=True,
test_interval=5
)
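# The attributes referenced above (gpc.config.num_epochs, gpc.config.hooks) come from
# the user config file loaded during colossalai.initialize(). A minimal sketch of such
# a config; the hook names and values here are illustrative assumptions, not part of
# this PR:
#
#   num_epochs = 200
#   hooks = [
#       dict(type='LossHook'),
#       dict(type='LogMetricByEpochHook'),
#   ]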
10 changes: 9 additions & 1 deletion colossalai/builder/__init__.py
@@ -1,2 +1,10 @@
from .builder import *
from .builder import (build_schedule, build_lr_scheduler, build_model, build_optimizer, build_optimizer_wrapper,
build_layer, build_loss, build_hooks, build_dataset, build_transform, build_data_sampler,
build_gradient_handler)
from .pipeline import ModelInitializer

__all__ = [
'build_schedule', 'build_lr_scheduler', 'build_model', 'build_optimizer', 'build_optimizer_wrapper',
'build_layer', 'build_loss', 'build_hooks', 'build_dataset', 'build_transform', 'build_data_sampler',
'build_gradient_handler', 'ModelInitializer'
]
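For context, the exported builders share one calling convention: a config dict (or `colossalai.context.Config`) whose `type` key names a registered class and whose remaining keys become constructor keyword arguments. A minimal usage sketch under that assumption; the class names and config values below are illustrative, not taken from this PR:

```python
import torch

from colossalai.builder import build_lr_scheduler, build_schedule

model = torch.nn.Linear(10, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# 'type' picks the registered class; every other key is forwarded as a kwarg.
lr_scheduler = build_lr_scheduler(
    dict(type='CosineAnnealingLR', total_steps=1000),  # hypothetical registered name and kwargs
    optimizer)                                          # new signature: (config, optimizer)
schedule = build_schedule(dict(type='NoPipelineSchedule'))  # assumed registered under this name
```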
33 changes: 14 additions & 19 deletions colossalai/builder/builder.py
@@ -181,18 +181,6 @@ def build_transform(config):
return build_from_registry(config, TRANSFORMS)


def build_pipe_alloc_policy(config):
"""Returns a pipeline allocation policy object constructed from `config`.

:param config: A python dict or a :class:`colossalai.context.Config` object
containing information used in the construction of the return object
:type config: dict or :class:`colossalai.context.Config`
:return: A pipeline allocation policy object
:rtype:
"""
return build_from_registry(config, PIPE_ALLOC_POLICY)


def build_data_sampler(config, dataset):
"""Returns a data sampler object of :class:`colossalai.nn.data.sampler.BaseSampler`
constructed from `config`.
@@ -235,7 +223,7 @@ def build_optimizer_wrapper(config, optimizer, model=None):
return OPTIMIZER_WRAPPERS.get_module(mod_type)(optimizer, **config_)


def build_lr_scheduler(config, optimizer, total_steps, num_steps_per_epoch):
def build_lr_scheduler(config, optimizer):
"""Returns a learning rate scheduler object of :class:`torch.optim.lr_scheduler`
constructed from `config` and `optimizer`.

@@ -254,9 +242,16 @@ def build_lr_scheduler(config, optimizer, total_steps, num_steps_per_epoch):
"""
config_ = config.copy()
mod_type = config_.pop('type')
# warmup epochs will overwrite warmup steps
if 'warmup_epochs' in config_:
warmup_epochs = config_.pop('warmup_epochs')
config_['warmup_steps'] = int(num_steps_per_epoch * warmup_epochs)
return LR_SCHEDULERS.get_module(mod_type)(optimizer, total_steps, num_steps_per_epoch=num_steps_per_epoch,
**config_)
return LR_SCHEDULERS.get_module(mod_type)(optimizer, **config_)


def build_schedule(config):
"""Returns a schedule of :class:`colossalai.engine.schedule.BaseSchedule`.

:param config: A python dict or a :class:`colossalai.context.Config` object
containing information used in the construction of the return object
:type config: dict or :class:`colossalai.context.Config`
:return: An object of :class:`colossalai.engine.schedule.BaseSchedule`
:rtype: :class:`colossalai.engine.schedule.BaseSchedule`
"""
return build_from_registry(config, SCHEDULE)
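All of these builders ultimately reduce to the registry lookup visible in `build_lr_scheduler` above: pop the `type` key, resolve it in the matching registry, and instantiate the class with the remaining keys. A generic sketch of that pattern; `registry` here is a stand-in object, not the library's actual `Registry` implementation:

```python
def build_from_config(config, registry):
    """Instantiate a registered class from a config dict.

    The dict's 'type' entry names the class; every other entry is
    forwarded as a constructor keyword argument.
    """
    config_ = dict(config)               # copy so the caller's config is not mutated
    mod_type = config_.pop('type')       # e.g. 'CosineAnnealingLR'
    cls = registry.get_module(mod_type)  # look the class up by its registered name
    return cls(**config_)
```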
2 changes: 1 addition & 1 deletion colossalai/engine/__init__.py
@@ -1,7 +1,7 @@
from .amp_type import AMP_TYPE
from ._base_engine import Engine
from .gradient_handler import *
from .schedule import *
from .amp import *


__all__ = ['Engine']
216 changes: 111 additions & 105 deletions colossalai/engine/_base_engine.py
@@ -1,170 +1,176 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from typing import Optional
from torch.nn import Module
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer

from colossalai.builder import build_gradient_handler
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger
from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)
from torch.nn import Module
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader

from .schedule import BaseSchedule, NoPipelineSchedule
from .schedule import BaseSchedule


class Engine:
"""Basic engine class for training and evaluation. It runs a specific process method
:meth:`step` which is based on the given :attr:`schedule` over each batch of a dataset.
It controls an iteration in training.

:param train_dataloader: Dataloader in training
:param test_dataloader: Dataloader in evaluation
:param model: The neural network model
:param criterion: Criterion for calculating loss
:param optimizer: Optimizer for updating the parameters
:param lr_scheduler: Learning rate scheduler adjusting the learning rate during training or evaluation
:param schedule: Running schedule in :meth:`step`
:type train_dataloader: DataLoader, optional
:type test_dataloader: DataLoader, optional
:param step_schedule: Running schedule in :meth:`step`
:param gradient_accumulation: Steps of gradient accumulation
:param gradient_clipping: The max norm used for gradient clipping
:type model: Module
:type criterion: _Loss, optional
:type optimizer: Optimizer, optional
:type lr_scheduler: _LRScheduler, optional
:type schedule: BaseSchedule, optional
:type optimizer: Optimizer
:type step_schedule: BaseSchedule, optional
:type gradient_accumulation: int, optional
:type gradient_clipping: float, optional
"""

def __init__(self,
train_dataloader: Optional[DataLoader] = None,
test_dataloader: Optional[DataLoader] = None,
model: Module = None,
criterion: _Loss = None,
optimizer: Optimizer = None,
lr_scheduler: Optional[_LRScheduler] = None,
schedule: BaseSchedule = None):
self.train_dataloader = train_dataloader
self.test_dataloader = test_dataloader
assert model is not None, "Engine requires a model"
self.model = model
self.criterion = criterion
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.schedule = schedule if schedule is not None \
else NoPipelineSchedule()
model: Module,
optimizer: Optimizer,
criterion: _Loss,
step_schedule: BaseSchedule,
gradient_handlers: list = None,
gradient_accumulation: int = 1,
gradient_clipping: float = 0.0,
):
self._model = model
self._optimizer = optimizer
self._criterion = criterion
self._schedule = step_schedule

# schedule initialize
self._schedule.initialize(model, optimizer)

# state
self.training = True # default

# gradient accumulation
assert gradient_accumulation > 0, 'gradient accumulation size must be larger than 0'
self._grad_accum_size = gradient_accumulation
self._grad_clip = gradient_clipping
self._logger = get_global_dist_logger()

# build gradient handler
self._gradient_handlers = []
gradient_handler_cfg = []

if hasattr(gpc.config, 'gradient_handler'):
assert isinstance(gpc.config.gradient_handler, list), \
if gradient_handlers is not None:
assert isinstance(gradient_handlers, list), \
f'argument gradient_handler_cfg expected type list, ' \
f'but got type {type(gpc.config.gradient_handler)}'
gradient_handler_cfg = gpc.config.gradient_handler
elif isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)):
gradient_handler_cfg = [dict(type='ZeROGradientHandler')]
f'but got type {type(gradient_handlers)}'
elif isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)):
gradient_handlers = [dict(type='ZeROGradientHandler')]
self._logger.info(
"Training with zero is detected, ZeROGradientHandler is automatically "
"added even though not specified in the configuration",
ranks=[0])
elif gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(
ParallelMode.DATA) > 1:
gradient_handler_cfg = [dict(type='DataParallelGradientHandler')]
gradient_handlers = [dict(type='DataParallelGradientHandler')]
self._logger.info(
"Data parallel training is detected, DataParallelGradientHandler is automatically "
"added even though not specified in the configuration",
ranks=[0])
if len(gradient_handler_cfg) == 0:

if gradient_handlers is None:
self._logger.warning(
"No gradient handler is set up, please make sure you do not need "
"to all-reduce the gradients after a training step.",
ranks=[0])
for cfg in gradient_handler_cfg:
handler = build_gradient_handler(cfg, self.model, self.optimizer)
self._gradient_handlers.append(handler)
else:
for cfg in gradient_handlers:
handler = build_gradient_handler(cfg, model, optimizer)
self._gradient_handlers.append(handler)

self.schedule.initialize(self.train_dataloader, self.model,
self.criterion, self.optimizer,
self.lr_scheduler)
self.forward_only = False
@property
def model(self):
return self._model

def handle_gradient(self):
"""Handles all-reduce operations of gradients across different parallel groups.
"""
for handler in self._gradient_handlers:
handler.handle_gradient()
@property
def optimizer(self):
return self._optimizer

def set_dataloader(self, data: DataLoader, train: bool = True):
"""Sets dataloader in training or evaluation.
@property
def criterion(self):
return self._criterion

:param data: Dataloader to be set
:param train: Set training dataloader if True, otherwise evaluation dataloader
:type data: DataLoader
:type train: bool
"""
if train:
self.train_dataloader = data
else:
self.test_dataloader = data
@property
def schedule(self):
return self._schedule

def get_model(self):
"""Returns the neural network model in the engine.
"""
return self.model
def get_optimizer(self):
"""Returns optimizier in the engine.
"""
return self.optimizer
@property
def gradient_accumulation(self):
return self._grad_accum_size

def get_lr_scheduler(self):
"""Returns the learning rate scheduler in the engine.
def handle_gradient(self):
"""Handles all-reduce operations of gradients across different parallel groups.
"""
return self.lr_scheduler
for handler in self._gradient_handlers:
handler.handle_gradient()

def train(self):
"""Sets the model to training mode.
"""
self.forward_only = False
self.schedule.train(dataloader=self.train_dataloader, mode=True)
self.training = True
self._model.train()

def eval(self):
"""Sets the model to evaluation mode.
"""
self.forward_only = True
self.schedule.train(dataloader=self.test_dataloader, mode=False)
self.training = False
self._model.eval()

def is_train(self):
"""Returns True if it is in training, otherwise False.
"""
return not self.forward_only

def get_lr(self):
"""Gets current learning rate.
"""
return self.schedule.get_lr()

def step(self, return_loss=True):
def step(self,
data_iter,
is_last_iteration: bool = False,
return_loss=True):
"""A running step based on the schedule. Usually, it runs a training or
evaluation over a batch of the dataset.

:param data_iter: Data iterator of the dataset
:param is_last_iteration: If True, this iteration is the last iteration in the epoch
:param return_loss: loss will be returned if True
:type return_loss: bool
:type data_iter: Iterator
:type is_last_iteration: bool, optional
:type return_loss: bool, optional
:return: (output, label, loss)
"""
self.schedule.zero_grad(forward_only=self.forward_only)

output, label, loss = self.schedule.forward_backward_step(
forward_only=self.forward_only, return_loss=return_loss)

if not self.forward_only:
# all reduce gradients
self.handle_gradient()

self.schedule.step()
if self.training:
self._optimizer.zero_grad()

# differentiate training and eval with grad accum
if self.training:
for i in range(self._grad_accum_size):
output, label, loss = self._schedule.forward_backward_step(
data_iter, self._model, self._criterion, self._optimizer,
forward_only=False,
grad_accum_size=self._grad_accum_size,
return_loss=return_loss)

if i == self._grad_accum_size - 1:
# all reduce gradients
self.handle_gradient()
self._schedule.optimizer_step(self._model, self._optimizer, self._grad_clip)
else:
output, label, loss = self._schedule.forward_backward_step(
data_iter, self._model, self._criterion, self._optimizer,
forward_only=True,
grad_accum_size=1,
return_loss=return_loss)

# consume the remaining dataset left out due to gradient accumulation
if is_last_iteration:
while True:
try:
_ = next(data_iter)
except StopIteration:
break

return output, label, loss
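The reworked `step` couples gradient accumulation with gradient synchronization and clipping: gradients are accumulated over `gradient_accumulation` micro-batches, all-reduced once at the last micro-batch, then applied with an optional norm clip. A framework-free sketch of that control flow in plain PyTorch, with hypothetical names and the usual loss scaling assumed; it is not the engine's actual code:

```python
import torch

def accumulation_step(model, optimizer, criterion, data_iter,
                      grad_accum_size=4, grad_clip=1.0):
    """One logical training step made of `grad_accum_size` micro-batches."""
    optimizer.zero_grad()
    for i in range(grad_accum_size):
        inputs, labels = next(data_iter)
        # Scale the loss so the accumulated gradient matches one large batch.
        loss = criterion(model(inputs), labels) / grad_accum_size
        loss.backward()
        if i == grad_accum_size - 1:
            # The engine all-reduces gradients across parallel groups here
            # (handle_gradient) before the parameter update.
            if grad_clip > 0.0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()
    return loss
```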
2 changes: 2 additions & 0 deletions colossalai/engine/amp/__init__.py
@@ -0,0 +1,2 @@
from .grad_scaler import GradScaler
from .amp_type import AMP_TYPE
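The exported `GradScaler` appears to follow the standard dynamic loss-scaling workflow for mixed precision; the sketch below uses `torch.cuda.amp.GradScaler` as a stand-in and assumes the interfaces match, which is not verified against this module:

```python
import torch

model = torch.nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = torch.nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler()    # stand-in for colossalai.engine.amp.GradScaler

for _ in range(10):
    inputs = torch.randn(8, 16, device='cuda')
    labels = torch.randint(0, 4, (8,), device='cuda')
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():     # run the forward pass in reduced precision where safe
        loss = criterion(model(inputs), labels)
    scaler.scale(loss).backward()       # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)              # unscales gradients; skips the step on inf/nan
    scaler.update()                     # adjust the loss scale for the next iteration
```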
File renamed without changes.