Commit 1f445bf

Support FP16 for more ops (#38123)
* support FP16 for more ops
* add amp list tests
* refine reduce_mean_grad
* fix OP benchmark ci
* fix fp16 reduce_mean
* update ut, but still have some problems
* remove mean/reduce_mean fp16 kernel
1 parent f895560 commit 1f445bf
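In user-facing terms, this commit makes elementwise_min and its gradient dispatchable for float16 tensors on CUDA devices, and lets lookup_table / lookup_table_v2 be re-enabled for FP16 through the AMP custom white list. A minimal sketch of what the elementwise_min part enables, assuming a CUDA build on a GPU with FP16 support (it mirrors the new unit test at the bottom of this commit):

import numpy as np
import paddle
import paddle.fluid as fluid

place = paddle.CUDAPlace(0)
with fluid.dygraph.guard(place):
    x = paddle.to_tensor(np.random.random((13, 17)).astype(np.float16))
    y = paddle.to_tensor(np.random.random((13, 17)).astype(np.float16))
    x.stop_gradient = False
    y.stop_gradient = False
    z = fluid.layers.elementwise_min(x, y)  # now backed by a float16 CUDA kernel
    x_g, y_g = paddle.grad([z], [x, y])     # gradient kernel is registered for float16 too
    print(z.dtype, x_g.dtype, y_g.dtype)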

File tree

6 files changed: +122 −3 lines

paddle/fluid/inference/tests/api/ipu_resnet50_test.cc

Lines changed: 1 addition & 1 deletion
@@ -112,4 +112,4 @@ TEST(Analyzer_Resnet50_ipu, compare_results_2_batch) {
 }
 
 }  // namespace inference
-}  // namespace paddle
+}  // namespace paddle

paddle/fluid/operators/elementwise/elementwise_min_op.cu

Lines changed: 4 additions & 0 deletions
@@ -41,12 +41,16 @@ namespace ops = paddle::operators;
 
 REGISTER_OP_CUDA_KERNEL(
     elementwise_min,
+    ops::ElementwiseMinKernel<paddle::platform::CUDADeviceContext,
+                              paddle::platform::float16>,
     ops::ElementwiseMinKernel<paddle::platform::CUDADeviceContext, float>,
     ops::ElementwiseMinKernel<paddle::platform::CUDADeviceContext, double>,
     ops::ElementwiseMinKernel<paddle::platform::CUDADeviceContext, int>,
     ops::ElementwiseMinKernel<paddle::platform::CUDADeviceContext, int64_t>);
 REGISTER_OP_CUDA_KERNEL(
     elementwise_min_grad,
+    ops::ElementwiseMinGradKernel<paddle::platform::CUDADeviceContext,
+                                  paddle::platform::float16>,
     ops::ElementwiseMinGradKernel<paddle::platform::CUDADeviceContext, float>,
     ops::ElementwiseMinGradKernel<paddle::platform::CUDADeviceContext, double>,
     ops::ElementwiseMinGradKernel<paddle::platform::CUDADeviceContext, int>,

paddle/fluid/operators/elementwise/elementwise_min_op.h

Lines changed: 23 additions & 0 deletions
@@ -19,6 +19,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/elementwise/elementwise_op.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
 #include "paddle/fluid/platform/eigen_ext.h"
+#include "paddle/fluid/platform/float16.h"
 
 namespace paddle {
 namespace operators {
@@ -67,6 +68,28 @@ struct MinGradDy {
   }
 };
 
+#ifdef PADDLE_CUDA_FP16
+template <>
+struct MinGradDx<platform::float16> {
+  HOSTDEVICE platform::float16 operator()(platform::float16 x,
+                                          platform::float16 y,
+                                          platform::float16 out,
+                                          platform::float16 dout) const {
+    return x < y ? dout : static_cast<platform::float16>(0);
+  }
+};
+
+template <>
+struct MinGradDy<platform::float16> {
+  HOSTDEVICE platform::float16 operator()(platform::float16 x,
+                                          platform::float16 y,
+                                          platform::float16 out,
+                                          platform::float16 dout) const {
+    return x >= y ? dout : static_cast<platform::float16>(0);
+  }
+};
+#endif
+
 template <typename DeviceContext, typename T>
 class ElementwiseMinGradKernel : public ElemwiseGradKernel<T> {
  public:
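For reference, the MinGradDx / MinGradDy specializations added above encode the usual subgradient of min for float16: the upstream gradient flows to whichever input is the smaller one, with ties routed to y (dx uses x < y, dy uses x >= y). A small NumPy sketch of the same rule, for illustration only:

import numpy as np

x = np.array([1.0, 5.0, 3.0], dtype=np.float16)
y = np.array([2.0, 4.0, 3.0], dtype=np.float16)
dout = np.ones_like(x)  # upstream gradient of out = minimum(x, y)

dx = np.where(x < y, dout, 0).astype(np.float16)   # -> [1, 0, 0]
dy = np.where(x >= y, dout, 0).astype(np.float16)  # -> [0, 1, 1]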

python/paddle/fluid/contrib/mixed_precision/fp16_lists.py

Lines changed: 6 additions & 2 deletions
@@ -17,6 +17,9 @@
 
 __all__ = ["CustomOpLists", "AutoMixedPrecisionLists"]
 
+# lookup_table fp16 is slower than fp32, though fp16 is supported.
+_extra_unsupported_fp16_list = {'lookup_table', 'lookup_table_v2'}
+
 
 class AutoMixedPrecisionLists(object):
     """
@@ -60,6 +63,8 @@ def _update_list(self):
                 elif op_name in self.gray_list:
                     self.gray_list.remove(op_name)
                 self.white_list.add(op_name)
+                if op_name in _extra_unsupported_fp16_list:
+                    self.unsupported_list.remove(op_name)
         if self._custom_black_list:
             for op_name in self._custom_black_list:
                 if op_name in self.white_list:
@@ -170,7 +175,6 @@ def _update_list(self):
 _, _, _sys_unsupported_fp16_list = core.op_supported_infos(
     'GPU', core.VarDesc.VarType.FP16)
 
-unsupported_fp16_list = {'lookup_table',
-                         'lookup_table_v2'} | _sys_unsupported_fp16_list
+unsupported_fp16_list = _extra_unsupported_fp16_list | _sys_unsupported_fp16_list
 
 CustomOpLists = AutoMixedPrecisionLists
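With this change, putting lookup_table or lookup_table_v2 into custom_white_list not only whitelists the op but also drops it from unsupported_list, so AMP will actually cast it to FP16 (the new test below verifies exactly this). A hedged sketch of how such a list is typically consumed; the decorate entry point and its amp_lists argument come from the existing fluid AMP API, not from this commit:

import paddle.fluid as fluid
from paddle.fluid.contrib.mixed_precision import decorate
from paddle.fluid.contrib.mixed_precision.fp16_lists import AutoMixedPrecisionLists

# Opt the embedding lookups back into FP16 even though they are slower there by default.
amp_lists = AutoMixedPrecisionLists(
    custom_white_list={'lookup_table', 'lookup_table_v2'})

optimizer = fluid.optimizer.SGD(learning_rate=0.01)
# decorate() wraps the optimizer with loss scaling and the FP16 cast pass;
# amp_lists controls which ops are allowed to run in float16.
optimizer = decorate(optimizer, amp_lists=amp_lists)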
New test file for the AMP op lists

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import unittest
+from paddle.fluid.contrib.mixed_precision.fp16_lists import AutoMixedPrecisionLists
+
+
+class TestAMPList(unittest.TestCase):
+    def test_main(self):
+        custom_white_list = [
+            'lookup_table',
+            'lookup_table_v2',
+        ]
+        amp_list = AutoMixedPrecisionLists(custom_white_list=custom_white_list)
+        for op in custom_white_list:
+            self.assertTrue(op in amp_list.white_list)
+            self.assertTrue(op not in amp_list.black_list)
+            self.assertTrue(op not in amp_list.unsupported_list)
+
+
+if __name__ == "__main__":
+    unittest.main()

python/paddle/fluid/tests/unittests/test_elementwise_min_op.py

Lines changed: 54 additions & 0 deletions
@@ -17,6 +17,11 @@
 import unittest
 import numpy as np
 from op_test import OpTest, skip_check_grad_ci
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.core as core
+
+paddle.enable_static()
 
 
 class TestElementwiseOp(OpTest):
@@ -142,5 +147,54 @@ def setUp(self):
         self.outputs = {'Out': np.minimum(self.inputs['X'], self.inputs['Y'])}
 
 
+class TestElementwiseMinOpFP16(unittest.TestCase):
+    def get_out_and_grad(self, x_np, y_np, axis, place, use_fp32=False):
+        assert x_np.dtype == np.float16
+        assert y_np.dtype == np.float16
+        if use_fp32:
+            x_np = x_np.astype(np.float32)
+            y_np = y_np.astype(np.float32)
+        dtype = np.float16
+
+        with fluid.dygraph.guard(place):
+            x = paddle.to_tensor(x_np)
+            y = paddle.to_tensor(y_np)
+            x.stop_gradient = False
+            y.stop_gradient = False
+            z = fluid.layers.elementwise_min(x, y, axis)
+            x_g, y_g = paddle.grad([z], [x, y])
+            return z.numpy().astype(dtype), x_g.numpy().astype(
+                dtype), y_g.numpy().astype(dtype)
+
+    def check_main(self, x_shape, y_shape, axis=-1):
+        if not paddle.is_compiled_with_cuda():
+            return
+        place = paddle.CUDAPlace(0)
+        if not core.is_float16_supported(place):
+            return
+
+        x_np = np.random.random(size=x_shape).astype(np.float16)
+        y_np = np.random.random(size=y_shape).astype(np.float16)
+
+        z_1, x_g_1, y_g_1 = self.get_out_and_grad(x_np, y_np, axis, place,
+                                                  False)
+        z_2, x_g_2, y_g_2 = self.get_out_and_grad(x_np, y_np, axis, place, True)
+        self.assertTrue(np.array_equal(z_1, z_2), "{} vs {}".format(z_1, z_2))
+        self.assertTrue(
+            np.array_equal(x_g_1, x_g_2), "{} vs {}".format(x_g_1, x_g_2))
+        self.assertTrue(
+            np.array_equal(y_g_1, y_g_2), "{} vs {}".format(y_g_1, y_g_2))
+
+    def test_main(self):
+        self.check_main((13, 17), (13, 17))
+        self.check_main((10, 3, 4), (1, ))
+        self.check_main((100, ), (100, ))
+        self.check_main((100, 3, 2), (100, ), 0)
+        self.check_main((2, 100, 3), (100, ), 1)
+        self.check_main((2, 3, 100), (100, ))
+        self.check_main((2, 25, 4, 1), (25, 4), 1)
+        self.check_main((2, 10, 2, 5), (2, 10, 1, 5))
+
+
 if __name__ == '__main__':
     unittest.main()
