File tree (expand file tree / collapse file tree): 1 file changed, +4 -18 lines changed
PyTorch/LanguageModeling/BERT (expand file tree / collapse file tree): 1 file changed, +4 -18 lines changed
Original file line number | Diff line number | Diff line change
@@ -119,21 +119,7 @@ def step(self, closure=None):
119119 loss = None
120120 if closure is not None :
121121 loss = closure ()
122-
123- grad_list = []
124- for group in self .param_groups :
125- for p in group ['params' ]:
126- if p .grad is None :
127- continue
128- grad_list .append (p )
129-
130- dummy_overflow_buf = torch .cuda .IntTensor ([0 ])
131- global_grad_norm = multi_tensor_applier (
132- multi_tensor_l2norm ,
133- dummy_overflow_buf ,
134- [grad_list ],
135- False )[0 ].item ()
136-
122+
137123 for group in self .param_groups :
138124 for p in group ['params' ]:
139125 if p .grad is None :
@@ -156,8 +142,8 @@ def step(self, closure=None):
156142 beta1 , beta2 = group ['b1' ], group ['b2' ]
157143
158144 # Add grad clipping
159- if global_grad_norm > group ['max_grad_norm' ]:
160- p = p * group ['max_grad_norm' ] / global_grad_norm
145+ if group ['max_grad_norm' ] > 0 :
146+ clip_grad_norm_ ( p , group ['max_grad_norm' ])
161147
162148 # Decay the first and second moment running average coefficient
163149 # In-place operations to update the averages at the same time
@@ -186,4 +172,4 @@ def step(self, closure=None):
186172
187173 state ['step' ] += 1
188174
189- return loss
175+ return loss
You can’t perform that action at this time.
0 commit comments