 
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.engine.amp_type import AMP_TYPE
 from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
                            ZeroRedundancyOptimizer_Level_3)
+from colossalai.nn.optimizer._utils import clip_grad_norm_fp32
 from ._utils import convert_to_fp16
 from ._base_schedule import BaseSchedule
+from ..amp import AMP_TYPE, GradScaler
 
 
 class NoPipelineSchedule(BaseSchedule):
@@ -30,6 +31,7 @@ class NoPipelineSchedule(BaseSchedule):
     :type amp_type: AMP_TYPE
     :type amp_config: dict
     """
+
     def __init__(
         self,
         amp_type: AMP_TYPE = None,
@@ -101,7 +103,7 @@ def initialize(self,
 
         if self.fp16:
             if self.amp_type == AMP_TYPE.TORCH:
-                self._torch_amp_scaler = torch_amp.GradScaler(**self.amp_cfg)
+                self._torch_amp_scaler = GradScaler(**self.amp_cfg)
             elif self.amp_type == AMP_TYPE.APEX:
                 self.model, self.optimizer = apex_amp.initialize(
                     self.model, self.optimizer, **self.amp_cfg)
@@ -175,9 +177,16 @@ def forward_backward_step(self, forward_only=False, return_loss=True):
     def step(self):
         # step optimizer
         if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
+            if getattr(gpc.config, 'clip_grad', 0.0) > 0.0:
+                self._torch_amp_scaler.unscale_(self.optimizer)
+                clip_grad_norm_fp32(self.model.parameters(),
+                                    gpc.config.clip_grad)
             self._torch_amp_scaler.step(self.optimizer)
             self._torch_amp_scaler.update()
         else:
+            if not self.fp16 and not self.use_zero_level_2_3 and getattr(gpc.config, 'clip_grad', 0.0) > 0.0:
+                clip_grad_norm_fp32(self.model.parameters(),
+                                    gpc.config.clip_grad)
             self.optimizer.step()
 
         # update lr scheduler
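
For reference, the ordering introduced in step() follows the standard torch.cuda.amp recipe: GradScaler keeps gradients multiplied by the loss scale, so they must be unscaled before their norm is measured and clipped, and only then handed to scaler.step(). The sketch below is not ColossalAI's implementation; it assumes a CUDA device and substitutes torch.nn.utils.clip_grad_norm_ and a toy nn.Linear model for clip_grad_norm_fp32 and the scheduled model, purely to illustrate the pattern the diff applies.

```python
import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler, autocast

# Toy model and optimizer for illustration only.
model = nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
scaler = GradScaler()
max_norm = 1.0  # plays the role of gpc.config.clip_grad above

data = torch.randn(8, 16, device='cuda')
target = torch.randn(8, 4, device='cuda')

with autocast():
    loss = nn.functional.mse_loss(model(data), target)

# backward() produces gradients scaled by the current loss scale.
scaler.scale(loss).backward()

if max_norm > 0.0:
    # Restore gradients to their true magnitude before clipping; otherwise the
    # clip threshold would be compared against loss-scaled values.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

# scaler.step() notices unscale_ was already called and does not unscale again;
# it also skips the optimizer step if any gradient overflowed to inf/NaN.
scaler.step(optimizer)
scaler.update()
```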