Commit 0c43ce2

Update BF16 amp list (#39304)
* amp list updated
* tests updated
* gray list updated
1 parent ebd1474 · commit 0c43ce2

File tree: 3 files changed, +40 −17 lines

python/paddle/fluid/contrib/mixed_precision/bf16/amp_lists.py

Lines changed: 9 additions & 6 deletions
@@ -83,15 +83,18 @@ def _update_list(self):
 bf16_initializer_list = {'fill_constant', 'uniform_random'}
 
 # always bf16
-bf16_list = {'elementwise_add', 'mul'}
+bf16_list = {
+    'conv2d',
+    'matmul',
+    'matmul_v2',
+    'mul',
+}
 
 # depends on the prev_op type
 gray_list = {
-    'cast',
-    'fill_constant',
-    'reduce_mean',
-    'reshape2',
-    'scale',
+    'elementwise_add', 'elementwise_sub', 'elementwise_mul', 'elementwise_div',
+    'relu', 'layer_norm', 'slice', 'concat', 'uniform_random', 'reshape2',
+    'transpose2', 'pool2d', 'sigmoid', 'cast', 'scale', 'fill_constant', 'split'
 }
 
 _, _, _sys_unsupported_bf16_list = core.op_supported_infos(
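
For context, a minimal sketch of how these lists are typically consumed (the constructor argument below is the one exercised by the updated tests; the exact resolution order is an assumption, not spelled out in this diff):

import paddle.static.amp as amp

# Ops on bf16_list always run in bf16; gray_list ops follow the dtype of
# their inputs; everything else stays fp32. A custom fp32 entry overrides
# the default bf16 placement for that op.
amp_lists = amp.bf16.AutoMixedPrecisionListsBF16(
    custom_fp32_list={'matmul_v2'})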

python/paddle/fluid/contrib/tests/test_bf16_utils.py

Lines changed: 8 additions & 8 deletions
@@ -57,20 +57,20 @@ def test_amp_lists_3(self):
         self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16({'lstm'})
 
     def test_amp_lists_4(self):
-        # 4. w=None, b={'elementwise_add'}
-        self.bf16_list.remove('elementwise_add')
-        self.fp32_list.add('elementwise_add')
+        # 4. w=None, b={'matmul_v2'}
+        self.bf16_list.remove('matmul_v2')
+        self.fp32_list.add('matmul_v2')
 
         self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16(
-            custom_fp32_list={'elementwise_add'})
+            custom_fp32_list={'matmul_v2'})
 
     def test_amp_lists_5(self):
-        # 5. w=None, b={'elementwise_add'}
-        self.fp32_list.add('elementwise_add')
-        self.bf16_list.remove('elementwise_add')
+        # 5. w=None, b={'matmul_v2'}
+        self.fp32_list.add('matmul_v2')
+        self.bf16_list.remove('matmul_v2')
 
         self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16(
-            custom_fp32_list={'elementwise_add'})
+            custom_fp32_list={'matmul_v2'})
 
     def test_amp_lists_6(self):
         # 6. w=None, b={'lstm'}
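
These tests had to switch from 'elementwise_add' (now on the gray list) to 'matmul_v2' (now on the bf16 list). The invariant they assert, as a standalone sketch (the bf16_list/fp32_list attribute names are assumed to match the module's own naming):

import paddle.static.amp as amp

lists = amp.bf16.AutoMixedPrecisionListsBF16(custom_fp32_list={'matmul_v2'})

# Moving an op to the custom fp32 list should take it off the bf16 list.
assert 'matmul_v2' in lists.fp32_list
assert 'matmul_v2' not in lists.bf16_list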

python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,28 @@
 import contextlib
 import unittest
 import numpy as np
+import struct
 import paddle.fluid.layers as layers
 import paddle.static.amp as amp
 from paddle.fluid import core
 
 paddle.enable_static()
 
 
+def convert_uint16_to_float(in_list):
+    if in_list.dtype == np.uint16:
+        in_list = np.asarray(in_list)
+        out = np.vectorize(
+            lambda x: struct.unpack('<f', struct.pack('<I', x << 16))[0],
+            otypes=[np.float32])(in_list.flat)
+        return np.reshape(out, in_list.shape)
+    else:
+        return in_list
+
+
+cutf = convert_uint16_to_float
+
+
 @unittest.skipIf(not core.supports_bfloat16(),
                  "place does not support BF16 evaluation")
 class TestModelCastBF16(unittest.TestCase):
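
The new helper relies on bf16's layout: a bf16 value stored as uint16 is exactly the high 16 bits of an IEEE-754 float32, so shifting left by 16 and reinterpreting the bytes recovers the float. A standalone check of the bit trick:

import struct

import numpy as np

raw = np.uint16(0x3FC0)  # bf16 bit pattern for 1.5
# Widen to a Python int before shifting so the uint16 does not overflow.
val = struct.unpack('<f', struct.pack('<I', int(raw) << 16))[0]
assert val == 1.5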
@@ -111,10 +126,13 @@ def _graph_common(self, _amp_fun, startup_prog=None):
                 'tt_bf16': nn_bf16,
             },
             fetch_list=[ret_bf16, ret, ret_fp32bf16],
-            amp_fun=lambda prog: amp.bf16.rewrite_program_bf16(prog))
+            amp_fun=_amp_fun,
+            startup_prog=startup_prog)
 
-        self.assertTrue(np.allclose(static_ret_bf16, static_ret, 1e-2))
-        self.assertTrue(np.allclose(static_ret_bf16, ret_fp32bf16, 1e-2))
+        self.assertTrue(
+            np.allclose(cutf(static_ret_bf16), cutf(static_ret), 1e-2))
+        self.assertTrue(
+            np.allclose(cutf(static_ret_bf16), cutf(ret_fp32bf16), 1e-2))
 
         with self.static_graph():
             t = layers.data(name='t', shape=[size, size], dtype='float32')
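
Fetched bf16 tensors come back as uint16 arrays while fp32 results are plain floats, which is why both sides now pass through cutf before np.allclose (whose third positional argument, 1e-2, is the relative tolerance). A usage sketch, assuming the convert_uint16_to_float helper from the first hunk is in scope:

import numpy as np

bf16_out = np.array([0x3FC0, 0x4000], dtype=np.uint16)  # 1.5 and 2.0 in bf16
fp32_out = np.array([1.5, 2.0], dtype=np.float32)

# uint16 input gets decoded; float input passes through unchanged.
assert np.allclose(convert_uint16_to_float(bf16_out),
                   convert_uint16_to_float(fp32_out), 1e-2)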
@@ -141,6 +159,7 @@ def test_graph_rewrite(self):
         self._graph_common(lambda prog: amp.bf16.rewrite_program_bf16(
             prog,
             amp.bf16.AutoMixedPrecisionListsBF16(
+                custom_bf16_list={'elementwise_add'},
                 custom_fp32_varnames={'elementwise_add_0.tmp_0'})
         ))

@@ -149,6 +168,7 @@
             prog,
             startup_prog,
             amp.bf16.AutoMixedPrecisionListsBF16(
+                custom_bf16_list={'elementwise_add'},
                 custom_fp32_list={'elementwise_mul'}),
             use_bf16_guard=True
         ), startup_prog=fluid.default_startup_program())
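
Since elementwise_add moved from the always-bf16 list to the gray list, both tests re-promote it via custom_bf16_list while still pinning either an op type or a single variable to fp32. A minimal sketch using only the arguments shown in the diff:

import paddle.static.amp as amp

# Run elementwise_add in bf16 everywhere, except for the one output
# variable named 'elementwise_add_0.tmp_0', which stays fp32.
lists = amp.bf16.AutoMixedPrecisionListsBF16(
    custom_bf16_list={'elementwise_add'},
    custom_fp32_varnames={'elementwise_add_0.tmp_0'})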
