11#!/usr/bin/env python
22# -*- encoding: utf-8 -*-
33
4- from typing import Optional
4+ from torch .nn import Module
5+ from torch .nn .modules .loss import _Loss
6+ from torch .optim import Optimizer
57
68from colossalai .builder import build_gradient_handler
79from colossalai .context import ParallelMode
810from colossalai .core import global_context as gpc
911from colossalai .logging import get_global_dist_logger
1012from colossalai .nn import (ZeroRedundancyOptimizer_Level_2 ,
1113 ZeroRedundancyOptimizer_Level_3 )
12- from torch .nn import Module
13- from torch .nn .modules .loss import _Loss
14- from torch .optim import Optimizer
15-
16- from .schedule import BaseSchedule , NoPipelineSchedule
14+ from .schedule import BaseSchedule
1715
1816
1917class Engine :
@@ -36,49 +34,80 @@ class Engine:
3634 def __init__ (self ,
3735 model : Module ,
3836 optimizer : Optimizer ,
39- step_schedule : BaseSchedule = None ,
37+ criterion : _Loss ,
38+ step_schedule : BaseSchedule ,
39+ gradient_handlers : list = None ,
4040 gradient_accumulation : int = 1 ,
41- gradient_clipping : float = 0.0 ):
42- self .schedule = step_schedule if step_schedule is not None \
43- else NoPipelineSchedule ()
44- self .schedule .initialize (model , optimizer )
45- self .grad_accum_size = gradient_accumulation
46- self .grad_accum_cur_step = 0
47- self .grad_clip = gradient_clipping
41+ gradient_clipping : float = 0.0 ,
42+ ):
43+ self ._model = model
44+ self ._optimizer = optimizer
45+ self ._criterion = criterion
46+ self ._schedule = step_schedule
47+
48+ # schedule initialize
49+ self ._schedule .initialize (model , optimizer )
50+
51+ # state
4852 self .training = True # default
53+
54+ # gradient accumulation
55+ assert gradient_accumulation > 0 , 'gradient accumulation size must be larger than 0'
56+ self ._grad_accum_size = gradient_accumulation
57+ self ._grad_clip = gradient_clipping
4958 self ._logger = get_global_dist_logger ()
5059
5160 # build gradient handler
5261 self ._gradient_handlers = []
53- gradient_handler_cfg = []
5462
55- if hasattr ( gpc . config , 'gradient_handler' ) :
56- assert isinstance (gpc . config . gradient_handler , list ), \
63+ if gradient_handlers is not None :
64+ assert isinstance (gradient_handlers , list ), \
5765 f'argument gradient_handler_cfg expected type list, ' \
58- f'but got type { type (gpc .config .gradient_handler )} '
59- gradient_handler_cfg = gpc .config .gradient_handler
66+ f'but got type { type (gradient_handlers )} '
6067 elif isinstance (optimizer , (ZeroRedundancyOptimizer_Level_2 ,
6168 ZeroRedundancyOptimizer_Level_3 )):
62- gradient_handler_cfg = [dict (type = 'ZeROGradientHandler' )]
69+ gradient_handlers = [dict (type = 'ZeROGradientHandler' )]
6370 self ._logger .info (
6471 "Training with zero is detected, ZeROGradientHandler is automatically "
6572 "added even though not specified in the configuration" ,
6673 ranks = [0 ])
6774 elif gpc .is_initialized (ParallelMode .DATA ) and gpc .get_world_size (
6875 ParallelMode .DATA ) > 1 :
69- gradient_handler_cfg = [dict (type = 'DataParallelGradientHandler' )]
76+ gradient_handlers = [dict (type = 'DataParallelGradientHandler' )]
7077 self ._logger .info (
7178 "Data parallel training is detected, DataParallelGradientHandler is automatically "
7279 "added even though not specified in the configuration" ,
7380 ranks = [0 ])
74- if len (gradient_handler_cfg ) == 0 :
81+
82+ if gradient_handlers is None :
7583 self ._logger .warning (
7684 "No gradient handler is set up, please make sure you do not need "
7785 "to all-reduce the gradients after a training step." ,
7886 ranks = [0 ])
79- for cfg in gradient_handler_cfg :
80- handler = build_gradient_handler (cfg , model , optimizer )
81- self ._gradient_handlers .append (handler )
87+ else :
88+ for cfg in gradient_handlers :
89+ handler = build_gradient_handler (cfg , model , optimizer )
90+ self ._gradient_handlers .append (handler )
91+
    @property
    def model(self):
        """Read-only access to the wrapped model."""
        return self._model

    @property
    def optimizer(self):
        """Read-only access to the optimizer."""
        return self._optimizer

    @property
    def criterion(self):
        """Read-only access to the loss function."""
        return self._criterion

    @property
    def schedule(self):
        """Read-only access to the step schedule."""
        return self._schedule

    @property
    def gradient_accumulation(self):
        """Number of micro-batches accumulated before one optimizer step."""
        return self._grad_accum_size
82111
83112 def handle_gradient (self ):
84113 """Handles all-reduce operations of gradients across different parallel groups.
@@ -90,57 +119,57 @@ def train(self):
90119 """Sets the model to training mode.
91120 """
92121 self .training = True
122+ self ._model .train ()
93123
94124 def eval (self ):
95125 """Sets the model to evaluation mode.
96126 """
97127 self .training = False
128+ self ._model .eval ()
98129
99130 def step (self ,
100131 data_iter ,
101- model : Module ,
102- criterion : _Loss ,
103- optimizer : Optimizer = None ,
104132 is_last_iteration : bool = False ,
105133 return_loss = True ):
106134 """A running step based on the schedule. Usually, it runs a training or
107135 evaluation over a batch of dataset.
108136
109137 :param data_iter: Data iterator of the dataset
110- :param model: The neural network model
111- :param criterion: Loss function used to calculate
112- :param optimizer: Optimizer for updating the parameters
113138 :param is_last_iteration: If True, this iteration is the last iteration in the epoch
114139 :param return_loss: loss will be returned if True
115140 :type data_iter: Iterator
116- :type model: Module
117- :type criterion: _Loss
118- :type optimizer: Optimizer, optional
119141 :type is_last_iteration: bool, optional
120142 :type return_loss: bool, optional
121143 :return: (output, lablel, loss)
122144 """
123- if self .training and self .grad_accum_cur_step == 0 :
124- optimizer .zero_grad ()
125-
126- output , label , loss = self .schedule .forward_backward_step (
127- data_iter , model , criterion , optimizer ,
128- forward_only = not self .training ,
129- grad_accum_size = self .grad_accum_size ,
130- return_loss = return_loss )
131-
132145 if self .training :
133- self .grad_accum_cur_step += 1
134- if self .grad_accum_cur_step == self .grad_accum_size :
135- # all reduce gradients
136- self .handle_gradient ()
137- self .schedule .optimizer_step (model , optimizer , self .grad_clip )
138- self .grad_accum_cur_step = 0
146+ self ._optimizer .zero_grad ()
139147
148+ # differentiate training and eval with grad accum
149+ if self .training :
150+ for i in range (self ._grad_accum_size ):
151+ output , label , loss = self ._schedule .forward_backward_step (
152+ data_iter , self ._model , self ._criterion , self ._optimizer ,
153+ forward_only = False ,
154+ grad_accum_size = self ._grad_accum_size ,
155+ return_loss = return_loss )
156+
157+ if i == self ._grad_accum_size - 1 :
158+ # all reduce gradients
159+ self .handle_gradient ()
160+ self ._schedule .optimizer_step (self ._model , self ._optimizer , self ._grad_clip )
161+ else :
162+ output , label , loss = self ._schedule .forward_backward_step (
163+ data_iter , self ._model , self ._criterion , self ._optimizer ,
164+ forward_only = True ,
165+ grad_accum_size = 1 ,
166+ return_loss = return_loss )
167+
168+ # consume the remaining dataset left out due to gradient accumulation
140169 if is_last_iteration :
141170 while True :
142171 try :
143- trash = next (data_iter )
172+ _ = next (data_iter )
144173 except StopIteration :
145174 break
146175
0 commit comments