
Commit 5bdca05

Support float16 when using ClipGradByGlobalNorm. (#33565)
This PR supports gradient clipping (ClipGradByGlobalNorm) when training with AMP (automatic mixed precision).
1 parent 11965bc commit 5bdca05
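
As a quick illustration of what the change enables, here is a minimal static-graph sketch adapted from the unit test added in this commit. The clip_norm value and the bare create_parameter setup are illustrative only, and paddle.enable_static() is assumed for Paddle 2.x; this is a sketch, not part of the commit itself.

import paddle
import paddle.fluid as fluid

paddle.enable_static()

prog = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program=prog, startup_program=startup_program):
    # float16 parameters standing in for AMP gradients (illustrative shapes/values)
    x = fluid.default_main_program().global_block().create_parameter(
        name="x", shape=[2, 3], dtype="float16")
    y = fluid.default_main_program().global_block().create_parameter(
        name="y", shape=[2, 3], dtype="float16")

    # The clip now accepts float16 gradients: squared norms are accumulated,
    # cast to float32 for the global sum, and the scale is cast back to float16.
    clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0)
    params_grads = clip([(x, y), (y, x)])
    print([op.type for op in x.block.ops])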

File tree

6 files changed (+105, −42 lines):
- paddle/fluid/operators/squared_l2_norm_op.cc
- paddle/fluid/operators/squared_l2_norm_op.cu
- python/paddle/fluid/clip.py
- python/paddle/fluid/layers/tensor.py
- python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py
- python/paddle/fluid/tests/unittests/test_gradient_clip.py


paddle/fluid/operators/squared_l2_norm_op.cc

Lines changed: 4 additions & 2 deletions

@@ -93,7 +93,9 @@ REGISTER_OPERATOR(squared_l2_norm, ops::SquaredL2NormOp,
 REGISTER_OPERATOR(squared_l2_norm_grad, ops::SquaredL2NormGradOp);
 REGISTER_OP_CPU_KERNEL(
     squared_l2_norm,
-    ops::SquaredL2NormKernel<paddle::platform::CPUDeviceContext, float>);
+    ops::SquaredL2NormKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::SquaredL2NormKernel<paddle::platform::CPUDeviceContext, double>);
 REGISTER_OP_CPU_KERNEL(
     squared_l2_norm_grad,
-    ops::SquaredL2NormGradKernel<paddle::platform::CPUDeviceContext, float>);
+    ops::SquaredL2NormGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::SquaredL2NormGradKernel<paddle::platform::CPUDeviceContext, double>);

paddle/fluid/operators/squared_l2_norm_op.cu

Lines changed: 4 additions & 2 deletions

@@ -16,7 +16,9 @@ limitations under the License. */
 namespace ops = paddle::operators;
 REGISTER_OP_CUDA_KERNEL(
     squared_l2_norm,
-    ops::SquaredL2NormKernel<paddle::platform::CUDADeviceContext, float>);
+    ops::SquaredL2NormKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::SquaredL2NormKernel<paddle::platform::CUDADeviceContext, double>);
 REGISTER_OP_CUDA_KERNEL(
     squared_l2_norm_grad,
-    ops::SquaredL2NormGradKernel<paddle::platform::CUDADeviceContext, float>);
+    ops::SquaredL2NormGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::SquaredL2NormGradKernel<paddle::platform::CUDADeviceContext, double>);

python/paddle/fluid/clip.py

Lines changed: 36 additions & 8 deletions

@@ -40,7 +40,7 @@ def _squared_l2_norm(x):
     This OP returns the squared L2 norm of a tensor.
     """
 
-    if core.is_compiled_with_xpu():
+    if core.is_compiled_with_xpu() or x.dtype == core.VarDesc.VarType.FP16:
         square = layers.square(x)
         sum_square = layers.reduce_sum(square)
         return sum_square
@@ -49,7 +49,7 @@ def _squared_l2_norm(x):
         return core.ops.squared_l2_norm(x)
 
     op_type = 'squared_l2_norm'
-    check_variable_and_dtype(x, 'x', ['float32'], op_type)
+    check_variable_and_dtype(x, 'x', ['float32', 'float64'], op_type)
     helper = LayerHelper(op_type, **locals())
     out = helper.create_variable_for_type_inference(x.dtype)
 
@@ -476,6 +476,8 @@ def _dygraph_clip(self, params_grads):
     def _static_clip(self, params_grads):
         params_and_grads = []
         sum_square_list = []
+        sum_square_list_fp16 = []
+        sum_square_list_fp32 = []
         with framework.name_scope('gradient_clip'):
             for p, g in params_grads:
                 if g is None:
@@ -488,16 +490,39 @@ def _static_clip(self, params_grads):
                         merge_grad = layers.merge_selected_rows(g)
                         merge_grad = layers.get_tensor_from_selected_rows(
                             merge_grad)
-
                     sum_square = _squared_l2_norm(merge_grad)
-                    sum_square_list.append(sum_square)
+                    if sum_square.dtype == core.VarDesc.VarType.FP16:
+                        sum_square_list_fp16.append(sum_square)
+                    elif sum_square.dtype == core.VarDesc.VarType.FP32:
+                        sum_square_list_fp32.append(sum_square)
+                    else:
+                        sum_square_list.append(sum_square)
 
             # all parameters have been filterd out
-            if len(sum_square_list) == 0:
+            if len(sum_square_list) + len(sum_square_list_fp16) + len(
+                    sum_square_list_fp32) == 0:
                 return params_grads
 
             with p.block.program._optimized_guard([p, g]):
-                global_norm_var = layers.sums(sum_square_list)
+                sum_dtype = 'float64' if len(sum_square_list) > 0 else "float32"
+
+                global_norm_var = []
+                if len(sum_square_list_fp16) > 0:
+                    global_norm_var_fp16 = layers.sums(sum_square_list_fp16)
+                    global_norm_var.append(
+                        global_norm_var_fp16.astype(sum_dtype))
+                if len(sum_square_list_fp32) > 0:
+                    global_norm_var_fp32 = layers.sums(sum_square_list_fp32)
+                    if sum_dtype == 'float32':
+                        global_norm_var.append(global_norm_var_fp32)
+                    else:
+                        global_norm_var.append(
+                            global_norm_var_fp32.astype(sum_dtype))
+                if len(sum_square_list) > 0:
+                    # fp64
+                    global_norm_var_other_dtype = layers.sums(sum_square_list)
+                    global_norm_var.append(global_norm_var_other_dtype)
+                global_norm_var = layers.sums(global_norm_var)
                 global_norm_var = layers.sqrt(x=global_norm_var)
                 max_global_norm = layers.fill_constant(
                     shape=[1],
@@ -507,7 +532,6 @@ def _static_clip(self, params_grads):
                     x=max_global_norm,
                     y=layers.elementwise_max(
                         x=max_global_norm, y=global_norm_var))
-
             param_new_grad_name_dict = dict()
             for p, g in params_grads:
                 if g is None:
@@ -518,11 +542,15 @@ def _static_clip(self, params_grads):
 
                 with p.block.program._optimized_guard([p, g]):
                     # inplace
+                    scale_input = (scale_var.astype('float16')
+                                   if g.dtype == core.VarDesc.VarType.FP16 else
+                                   scale_var)
                     p.block.append_op(
                         type='elementwise_mul',
                         inputs={'X': g,
-                                'Y': scale_var},
+                                'Y': scale_input},
                         outputs={'Out': g})
+
                 param_new_grad_name_dict[p.name] = g.name
                 params_and_grads.append((p, g))
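
The dtype handling in _static_clip above boils down to: accumulate squared norms separately per dtype, promote each partial sum to a common accumulation dtype (float64 if any float64 gradient exists, otherwise float32), take the square root of the combined sum, and cast the resulting scale back to float16 only when multiplying float16 gradients. A rough NumPy sketch of that idea follows; it is a hypothetical standalone helper for illustration, not code from this commit.

import numpy as np

def clip_by_global_norm(grads, clip_norm):
    # Group squared L2 norms by gradient dtype, accumulating each group
    # in its own precision (mirroring sum_square_list_fp16/fp32 above).
    groups = {}
    for g in grads:
        groups.setdefault(g.dtype, []).append(
            np.sum(np.square(g), dtype=g.dtype))

    # Promote to float64 only when a float64 gradient is present,
    # otherwise accumulate the global sum in float32.
    sum_dtype = np.float64 if np.dtype('float64') in groups else np.float32
    global_sq_norm = sum(
        np.asarray(sum(group), dtype=sum_dtype) for group in groups.values())
    global_norm = np.sqrt(global_sq_norm)

    # scale = clip_norm / max(clip_norm, global_norm), cast per gradient dtype
    clip_norm = np.asarray(clip_norm, dtype=sum_dtype)
    scale = clip_norm / np.maximum(clip_norm, global_norm)
    return [g * scale.astype(g.dtype) for g in grads]

grads = [np.random.randn(2, 3).astype('float16'),
         np.random.randn(4).astype('float32')]
clipped = clip_by_global_norm(grads, clip_norm=1.0)

The per-dtype partial sums are promoted before the final combination, so the global norm itself is always computed in at least float32.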

python/paddle/fluid/layers/tensor.py

Lines changed: 2 additions & 2 deletions

@@ -538,10 +538,10 @@ def sums(input, out=None):
     if isinstance(input, list) or isinstance(input, tuple):
         for input_section in input:
             check_variable_and_dtype(input_section, "input", \
-                ['float32', 'float64', 'int32', 'int64'], 'sums')
+                ['float16', 'float32', 'float64', 'int32', 'int64'], 'sums')
     else:
         check_variable_and_dtype(input, "input", \
-            ['float32', 'float64', 'int32', 'int64'], 'sums')
+            ['float16', 'float32', 'float64', 'int32', 'int64'], 'sums')
 
     helper = LayerHelper('sum', **locals())
     if out is None:

python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py

Lines changed: 4 additions & 3 deletions

@@ -266,9 +266,10 @@ def test_sharding_gradient_clip(self):
             'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
             'c_reduce_sum', 'c_reduce_sum', 'c_sync_comm_stream',
             'squared_l2_norm', 'squared_l2_norm', 'squared_l2_norm', 'sum',
-            'c_allreduce_sum', 'sqrt', 'fill_constant', 'elementwise_max',
-            'elementwise_div', 'elementwise_mul', 'elementwise_mul',
-            'elementwise_mul', 'momentum', 'momentum', 'momentum'
+            'c_allreduce_sum', 'sum', 'c_allreduce_sum', 'sqrt',
+            'fill_constant', 'elementwise_max', 'elementwise_div',
+            'elementwise_mul', 'elementwise_mul', 'elementwise_mul', 'momentum',
+            'momentum', 'momentum'
         ])
 
     def test_sharding_clone_for_test(self):

python/paddle/fluid/tests/unittests/test_gradient_clip.py

Lines changed: 55 additions & 25 deletions

@@ -71,14 +71,18 @@ def get_places(self):
     def check_clip_result(self, out, out_clip):
         pass
 
-    def check_gradient_clip(self, place):
+    def check_gradient_clip(self, place, dtype='float32'):
         prog = fluid.Program()
         startup_program = fluid.Program()
         with fluid.program_guard(
                 main_program=prog, startup_program=startup_program):
             image = fluid.data(name="a", shape=[-1, 784], dtype='float32')
             label = fluid.data(name="b", shape=[-1, 1], dtype='int64')
-            hidden = fluid.layers.fc(input=image, size=32, act='relu')
+            if dtype != 'float32':
+                image_cast = paddle.cast(image, dtype)
+                hidden = fluid.layers.fc(input=image_cast, size=32, act='relu')
+            else:
+                hidden = fluid.layers.fc(input=image, size=32, act='relu')
             predict = fluid.layers.fc(input=hidden, size=10, act='softmax')
 
             cost = fluid.layers.cross_entropy(input=predict, label=label)
@@ -176,6 +180,15 @@ def func(params_grads):
         self.clip_gradient = func
         self.check_gradient_clip(fluid.CPUPlace())
 
+    # test whether the ouput is right when use grad_clip under float64
+    def test_new_gradient_clip_fp64(self):
+        def func(params_grads):
+            clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=self.clip_norm)
+            return clip(params_grads)
+
+        self.clip_gradient = func
+        self.check_gradient_clip(fluid.CPUPlace(), "float64")
+
     # invoke 'set_gradient_clip' in a wrong order
     def test_wrong_API_order(self):
         def backward_func(cost):
@@ -192,36 +205,53 @@ def backward_func(cost):
         for place in self.get_places():
             self.check_sparse_gradient_clip(place)
 
-    # if grad is None or not need clip
-    def test_none_grad(self):
-        clip = fluid.clip.GradientClipByGlobalNorm(self.clip_norm)
-        x = fluid.default_main_program().global_block().create_parameter(
-            name="x", shape=[2, 3], dtype="float32")
-        y = fluid.default_main_program().global_block().create_parameter(
-            name="y", shape=[2, 3], dtype="float32")
-
-        # (x, None) should not be returned
-        params_grads = [(x, None), (x, y), (y, x)]
-        params_grads = clip(params_grads)
-        self.assertTrue(
-            len(params_grads) == 2,
-            "ClipByGlobalNorm: when grad is None, it shouldn't be returned by gradient clip!"
-        )
-
-        ops = [op.type for op in x.block.ops]
-        self.assertListEqual(ops, [
-            'squared_l2_norm', 'squared_l2_norm', 'sum', 'sqrt',
-            'fill_constant', 'elementwise_max', 'elementwise_div',
-            'elementwise_mul', 'elementwise_mul'
-        ])
-
     # raise typeError
     def test_tpyeError(self):
         # the type of optimizer(grad_clip=) must be an instance of GradientClipBase's derived class
         with self.assertRaises(TypeError):
             sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.1,
                                                 grad_clip="test")
 
+    # if grad is None or not need clip
+    def test_none_grad_fp32(self):
+        ops = self._test_none_grad_helper("float32")
+        self.assertListEqual(ops, [
+            'squared_l2_norm', 'squared_l2_norm', 'sum', 'sum', 'sqrt',
+            'fill_constant', 'elementwise_max', 'elementwise_div',
+            'elementwise_mul', 'elementwise_mul'
+        ])
+
+    def test_none_grad_fp16(self):
+        ops = self._test_none_grad_helper("float16")
+        self.assertListEqual(ops, [
+            'square', 'reduce_sum', 'square', 'reduce_sum', 'sum', 'cast',
+            'sum', 'sqrt', 'fill_constant', 'elementwise_max',
+            'elementwise_div', 'cast', 'elementwise_mul', 'cast',
+            'elementwise_mul'
+        ])
+
+    def _test_none_grad_helper(self, dtype):
+        prog = fluid.Program()
+        startup_program = fluid.Program()
+        with fluid.program_guard(
+                main_program=prog, startup_program=startup_program):
+            clip = fluid.clip.GradientClipByGlobalNorm(self.clip_norm)
+            x = fluid.default_main_program().global_block().create_parameter(
+                name="x", shape=[2, 3], dtype=dtype)
+            y = fluid.default_main_program().global_block().create_parameter(
+                name="y", shape=[2, 3], dtype=dtype)
+
+            # (x, None) should not be returned
+            params_grads = [(x, None), (x, y), (y, x)]
+            params_grads = clip(params_grads)
+            self.assertTrue(
+                len(params_grads) == 2,
+                "ClipByGlobalNorm: when grad is None, it shouldn't be returned by gradient clip!"
+            )
+
+            ops = [op.type for op in x.block.ops]
+            return ops
+
 
 class TestGradientClipByNorm(TestGradientClip):
     def init(self):
