 from torch.nn import Module
 from torch.nn.modules.loss import _Loss
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler
-from torch.utils.data import DataLoader

 from .schedule import BaseSchedule, NoPipelineSchedule


 class Engine:
     """Basic engine class for training and evaluation. It runs a specific process method
     :meth:`step` which is based on the given :attr:`schedule` over each batch of a dataset.
+    It controls an iteration in training.

-    :param train_dataloader: Dataloader in training
-    :param test_dataloader: Dataloader in evaluation
     :param model: The neural network model
-    :param criterion: Criterion for calculating loss
     :param optimizer: Optimizer for updating the parameters
-    :param lr_scheduler: Learning rate scheduler ajusting learning rate during the training or evaluation
-    :param schedule: Running schedule in :meth:`step`
-    :type train_dataloader: DataLoader, optional
-    :type test_dataloader: DataLoader, optional
+    :param step_schedule: Running schedule in :meth:`step`
+    :param gradient_accumulation: Steps of gradient accumulation
+    :param gradient_clipping: The norm of gradient clipping
     :type model: Module
-    :type criterion: _Loss, optional
-    :type optimizer: Optimizer, optional
-    :type lr_scheduler: _LRScheduler, optional
-    :type schedule: BaseSchedule, optional
+    :type optimizer: Optimizer
+    :type step_schedule: BaseSchedule, optional
+    :type gradient_accumulation: int, optional
+    :type gradient_clipping: float, optional
     """
+
     def __init__(self,
-                 train_dataloader: Optional[DataLoader] = None,
-                 test_dataloader: Optional[DataLoader] = None,
-                 model: Module = None,
-                 criterion: _Loss = None,
-                 optimizer: Optimizer = None,
-                 lr_scheduler: Optional[_LRScheduler] = None,
-                 schedule: BaseSchedule = None,
+                 model: Module,
+                 optimizer: Optimizer,
+                 step_schedule: BaseSchedule = None,
                  gradient_accumulation: int = 1,
-                 lr_scheduler_step: str = 'epoch'):
-        self.train_dataloader = train_dataloader
-        self.test_dataloader = test_dataloader
-        assert model is not None, "Engine requires a model"
-        self.model = model
-        self.criterion = criterion
-        self.optimizer = optimizer
-        self.lr_scheduler = lr_scheduler
-        self.schedule = schedule if schedule is not None \
+                 gradient_clipping: float = 0.0):
+        self.schedule = step_schedule if step_schedule is not None \
             else NoPipelineSchedule()
+        self.schedule.initialize(model, optimizer)
         self.grad_accum_size = gradient_accumulation
-        self.grad_accum_step = 0
-        self.lr_step = 0  # for epoch updating
-        if lr_scheduler_step != 'epoch':
-            self.lr_step = 1
+        self.grad_accum_cur_step = 0
+        self.grad_clip = gradient_clipping
+        self.training = True  # default
         self._logger = get_global_dist_logger()

         # build gradient handler
@@ -72,8 +57,8 @@ def __init__(self,
                 f'argument gradient_handler_cfg expected type list, ' \
                 f'but got type {type(gpc.config.gradient_handler)}'
             gradient_handler_cfg = gpc.config.gradient_handler
-        elif isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
-                                         ZeroRedundancyOptimizer_Level_3)):
+        elif isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2,
+                                    ZeroRedundancyOptimizer_Level_3)):
             gradient_handler_cfg = [dict(type='ZeROGradientHandler')]
             self._logger.info(
                 "Training with zero is detected, ZeROGradientHandler is automatically "
@@ -92,106 +77,71 @@ def __init__(self,
                 "to all-reduce the gradients after a training step.",
                 ranks=[0])
         for cfg in gradient_handler_cfg:
-            handler = build_gradient_handler(cfg, self.model, self.optimizer)
+            handler = build_gradient_handler(cfg, model, optimizer)
             self._gradient_handlers.append(handler)

-        self.schedule.initialize(self.train_dataloader, self.model,
-                                 self.criterion, self.optimizer)
-        self.schedule.grad_accum = self.grad_accum_size
-        # add for robustness
-        if self.optimizer is None:
-            self.forward_only = True
-        else:
-            self.forward_only = False
-
     def handle_gradient(self):
         """Handles all-reduce operations of gradients across different parallel groups.
         """
         for handler in self._gradient_handlers:
             handler.handle_gradient()

-    def set_dataloader(self, data: DataLoader, train: bool = True):
-        """Sets dataloader in training or evaluation.
-
-        :param data: Dataloader to be set
-        :param train: Set training dataloader if True, otherwise evaluation dataloader
-        :type data: DataLoader
-        :type train: bool
-        """
-        if train:
-            self.train_dataloader = data
-        else:
-            self.test_dataloader = data
-
-    def get_model(self):
-        """Returns the neural network model in the engine.
-        """
-        return self.model
-
-    def get_optimizer(self):
-        """Returns optimizier in the engine.
-        """
-        return self.optimizer
-
-    def get_lr_scheduler(self):
-        """Returns the learning rate scheduler in the engine.
-        """
-        return self.lr_scheduler
-
     def train(self):
         """Sets the model to training mode.
         """
-        self.forward_only = False
-        self.schedule.train(dataloader=self.train_dataloader, mode=True)
+        self.training = True

     def eval(self):
         """Sets the model to evaluation mode.
         """
-        self.forward_only = True
-        self.schedule.train(dataloader=self.test_dataloader, mode=False)
-
-    def is_train(self):
-        """Returns True if it is in training, otherwise False.
-        """
-        return not self.forward_only
-
-    def get_lr(self):
-        """Gets current learning rate.
-        """
-        if self.lr_scheduler is not None:
-            return self.lr_scheduler.get_lr()[0]
-        else:
-            return self.optimizer.param_groups[0]['lr']
-
-    def step(self, return_loss=True):
+        self.training = False
+
+    def step(self,
+             data_iter,
+             model: Module,
+             criterion: _Loss,
+             optimizer: Optimizer = None,
+             is_last_iteration: bool = False,
+             return_loss=True):
         """A running step based on the schedule. Usually, it runs a training or
         evaluation over a batch of dataset.

+        :param data_iter: Data iterator of the dataset
+        :param model: The neural network model
+        :param criterion: Loss function used to calculate the loss
+        :param optimizer: Optimizer for updating the parameters
+        :param is_last_iteration: If True, this iteration is the last iteration in the epoch
        :param return_loss: loss will be returned if True
-        :type return_loss: bool
+        :type data_iter: Iterator
+        :type model: Module
+        :type criterion: _Loss
+        :type optimizer: Optimizer, optional
+        :type is_last_iteration: bool, optional
+        :type return_loss: bool, optional
        :return: (output, label, loss)
        """
-        if not self.forward_only and self.grad_accum_step == 0:
-            self.schedule.zero_grad()
+        if self.training and self.grad_accum_cur_step == 0:
+            optimizer.zero_grad()

        output, label, loss = self.schedule.forward_backward_step(
-            forward_only=self.forward_only, return_loss=return_loss)
-
-        if not self.forward_only:
-            self.grad_accum_step += 1
-            if self.grad_accum_step == self.grad_accum_size:
+            data_iter, model, criterion, optimizer,
+            forward_only=not self.training,
+            grad_accum_size=self.grad_accum_size,
+            return_loss=return_loss)
+
+        if self.training:
+            self.grad_accum_cur_step += 1
+            if self.grad_accum_cur_step == self.grad_accum_size:
                # all reduce gradients
                self.handle_gradient()
-                self.schedule.step()
-                if self.lr_scheduler is not None and self.lr_step:
-                    self.lr_scheduler.step()
-                self.grad_accum_step = 0
+                self.schedule.optimizer_step(model, optimizer, self.grad_clip)
+                self.grad_accum_cur_step = 0

-        return output, label, loss
+        if is_last_iteration:
+            while True:
+                try:
+                    trash = next(data_iter)
+                except StopIteration:
+                    break

-    def complete(self):
-        """Updating after a epoch.
-        """
-        self.schedule.consume_batch()
-        if self.lr_scheduler is not None and self.lr_step == 0:
-            self.lr_scheduler.step()
+        return output, label, loss
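For context, here is a minimal usage sketch of the refactored API introduced by this diff: the engine is now built from only a model and an optimizer, while the dataloader, criterion, and optimizer are passed to each call of step. It assumes that model, criterion, optimizer, and train_dataloader are constructed elsewhere and that Engine is importable from the engine package; those names and the import path are illustrative assumptions, not part of the diff.

    # Illustrative sketch only: the names model, criterion, optimizer, and
    # train_dataloader are assumed to exist; the import path is an assumption.
    from colossalai.engine import Engine

    engine = Engine(model=model,
                    optimizer=optimizer,
                    step_schedule=None,       # falls back to NoPipelineSchedule()
                    gradient_accumulation=1,  # update parameters every step
                    gradient_clipping=0.0)    # "norm of gradient clipping" per the docstring
    engine.train()                            # sets engine.training = True

    # Drive one epoch; step() zeroes gradients, runs forward/backward via the
    # schedule, and calls optimizer_step once the accumulation count is reached.
    data_iter = iter(train_dataloader)
    num_steps = len(train_dataloader)
    for i in range(num_steps):
        output, label, loss = engine.step(data_iter,
                                          model,
                                          criterion,
                                          optimizer,
                                          is_last_iteration=(i == num_steps - 1),
                                          return_loss=True)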