
Commit f88af20: Combine amp and qat (#33484)

* Combine amp and qat
* add unit test

1 parent: 0905dee

File tree

5 files changed: +267 additions, -19 deletions

paddle/fluid/imperative/amp_auto_cast.cc

Lines changed: 15 additions & 2 deletions

@@ -141,7 +141,7 @@ static inline std::shared_ptr<imperative::VarBase> CastToFP32(
 }
 
 static inline framework::proto::VarType::Type GetPromoteType(
-    const NameVarBaseMap& ins) {
+    const std::string& op_type, const NameVarBaseMap& ins) {
   auto dst_type = framework::proto::VarType::FP16;
   for (const auto& pair : ins) {
     for (const auto& var : pair.second) {
@@ -151,6 +151,18 @@ static inline framework::proto::VarType::Type GetPromoteType(
       }
     }
   }
+
+  // NOTE(juncai): moving_average_abs_max_scale only consider the
+  // dtype of input(X)
+  if (op_type == "moving_average_abs_max_scale") {
+    for (const auto& pair : ins) {
+      if (pair.first == "X" &&
+          pair.second.front()->DataType() == framework::proto::VarType::FP16) {
+        dst_type = framework::proto::VarType::FP16;
+      }
+    }
+  }
+
   return dst_type;
 }
 
@@ -183,7 +195,8 @@ NameVarBaseMap AutoCastInputs(const std::string& op_type,
     }
     return new_ins;
   } else {
-    auto dst_type = GetPromoteType(ins);
+    auto dst_type = GetPromoteType(op_type, ins);
+
     // NOTE(zhiqiu): if the op has op fp16 kernel, fall back to fp32.
     if (dst_type == framework::proto::VarType::FP16 &&
         AmpOperators::Instance().GetMutableUnsupportedFp16Ops()->count(
paddle/fluid/operators/fake_quantize_op.cu

Lines changed: 27 additions & 17 deletions

@@ -25,18 +25,19 @@ __global__ void FindAbsMaxKernel(const T* in, const int n, T* out) {
   int bid = threadIdx.x + blockIdx.x * blockDim.x;
   int tid = threadIdx.x;
 
-  extern __shared__ T shared_max_data[];
+  extern __shared__ char* shared_max_data_tmp[];
+  auto shared_max_data = reinterpret_cast<T*>(shared_max_data_tmp);
   if (gridDim.x > 1) {
     shared_max_data[tid] = T(0);
     for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
-      T tmp = fabs(in[i]);
+      T tmp = abs(in[i]);
       if (tmp > shared_max_data[tid]) {
         shared_max_data[tid] = tmp;
       }
     }
   } else {
     if (bid < n) {
-      shared_max_data[tid] = fabs(in[bid]);
+      shared_max_data[tid] = abs(in[bid]);
     } else {
       shared_max_data[tid] = T(0);
     }
@@ -73,6 +74,8 @@ struct FindAbsMaxFunctor<platform::CUDADeviceContext, T> {
 };
 
 template struct FindAbsMaxFunctor<platform::CUDADeviceContext, float>;
+template struct FindAbsMaxFunctor<platform::CUDADeviceContext,
+                                  paddle::platform::float16>;
 
 template <typename T>
 __global__ void FindChannelAbsMaxKernelQuantAxis0(const T* in, const int n,
@@ -213,13 +216,16 @@ __global__ void ClipAndQuantDequantKernel(const T* in, const T* scale,
   int tid = threadIdx.x;
 
   T s = scale[0];
-  T inv_s = inverse(s);
+  T bin_cnt_t = static_cast<T>(bin_cnt);
+
   for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
     T x = in[i];
-    T v = x > s ? s : x;
-    v = v < -s ? -s : v;
-    v = bin_cnt * inv_s * v;
-    out[i] = round(v) * s / bin_cnt;
+    x = x > s ? s : x;
+    x = x < -s ? -s : x;
+    x = (bin_cnt_t / s) * x;
+
+    x = static_cast<T>(round(static_cast<float>(x)));
+    out[i] = (x * s) / bin_cnt_t;
   }
 }
 
@@ -261,9 +267,6 @@ struct ClipAndFakeQuantDequantFunctor<platform::CUDADeviceContext, T> {
   }
 };
 
-template struct ClipAndFakeQuantDequantFunctor<platform::CUDADeviceContext,
-                                               float>;
-
 // ChannelClipAndQuantKernel for quant_axis is 0
 template <typename T>
 __global__ void ChannelClipAndQuantKernelQuantAxis0(const T* in, const T* scale,
@@ -423,8 +426,10 @@ struct FindMovingAverageAbsMaxFunctor<platform::CUDADeviceContext, T> {
     memory::Copy(platform::CPUPlace(), &scale, gpu_place, cur_scale, sizeof(T),
                  ctx.stream());
     ctx.Wait();
-    state = rate * state + 1;
-    accum = rate * accum + scale;
+
+    T rate_t = static_cast<T>(rate);
+    state = rate_t * state + static_cast<T>(1.0);
+    accum = rate_t * accum + scale;
     scale = accum / state;
 
     memory::Copy(gpu_place, out_accum->mutable_data<T>(gpu_place),
@@ -527,10 +532,12 @@ template struct ChannelClipFakeQuantDequantFunctor<platform::CUDADeviceContext,
 
 namespace ops = paddle::operators;
 using CUDA = paddle::platform::CUDADeviceContext;
+using float16 = paddle::platform::float16;
 REGISTER_OP_CUDA_KERNEL(fake_quantize_abs_max,
                         ops::FakeQuantizeAbsMaxKernel<CUDA, float>);
 REGISTER_OP_CUDA_KERNEL(fake_quantize_dequantize_abs_max,
-                        ops::FakeQuantizeDequantizeAbsMaxKernel<CUDA, float>);
+                        ops::FakeQuantizeDequantizeAbsMaxKernel<CUDA, float>,
+                        ops::FakeQuantizeDequantizeAbsMaxKernel<CUDA, float16>);
 REGISTER_OP_CUDA_KERNEL(fake_channel_wise_quantize_abs_max,
                         ops::FakeChannelWiseQuantizeAbsMaxKernel<CUDA, float>);
 REGISTER_OP_CUDA_KERNEL(fake_quantize_range_abs_max,
@@ -539,12 +546,15 @@ REGISTER_OP_CUDA_KERNEL(
     fake_quantize_moving_average_abs_max,
     ops::FakeQuantizeMovingAverageAbsMaxKernel<CUDA, float>);
 REGISTER_OP_CUDA_KERNEL(moving_average_abs_max_scale,
-                        ops::MovingAverageAbsMaxScaleKernel<CUDA, float>);
+                        ops::MovingAverageAbsMaxScaleKernel<CUDA, float>,
+                        ops::MovingAverageAbsMaxScaleKernel<CUDA, float16>);
 REGISTER_OP_CUDA_KERNEL(
     fake_quantize_dequantize_moving_average_abs_max,
-    ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel<CUDA, float>);
+    ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel<CUDA, float>,
+    ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel<CUDA, float16>);
 REGISTER_OP_CUDA_KERNEL(stright_throuth_estimator_grad,
-                        ops::StrightThroughEstimatorGradKernel<CUDA, float>);
+                        ops::StrightThroughEstimatorGradKernel<CUDA, float>,
+                        ops::StrightThroughEstimatorGradKernel<CUDA, float16>);
 REGISTER_OP_CUDA_KERNEL(
     fake_channel_wise_quantize_dequantize_abs_max,
     ops::FakeChannelWiseQuantizeDequantizeAbsMaxKernel<CUDA, float>);
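The rewritten ClipAndQuantDequantKernel drops the inverse() helper and rounds through fp32, which is what makes the float16 instantiation viable. A NumPy sketch of the same clip, scale, round, rescale computation (a hypothetical helper mirroring the kernel; bin_cnt = 127 is an assumed 8-bit bin count):

import numpy as np

def fake_quant_dequant(x, scale, bin_cnt=127):
    v = np.clip(x, -scale, scale)        # x = max(min(x, s), -s)
    v = (bin_cnt / scale) * v            # scale into [-bin_cnt, bin_cnt]
    v = np.round(v.astype(np.float32))   # round in fp32, as the kernel does
    return v.astype(x.dtype) * scale / bin_cnt  # dequantize back

x = np.array([0.9, -0.3, 1.5], dtype=np.float16)
print(fake_quant_dequant(x, scale=np.float16(1.0)))  # ~[0.898, -0.299, 1.0]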

python/paddle/fluid/contrib/slim/tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -127,6 +127,7 @@ if(WIN32)
     list(REMOVE_ITEM TEST_OPS test_post_training_quantization_lstm_model)
     list(REMOVE_ITEM TEST_OPS test_weight_quantization_mobilenetv1)
     list(REMOVE_ITEM TEST_OPS test_quantize_transpiler_v2)
+    list(REMOVE_ITEM TEST_OPS test_imperative_qat_amp)
 endif()
 
 if(LINUX AND WITH_MKLDNN)
python/paddle/fluid/contrib/slim/tests/test_imperative_qat_amp.py (new file)

Lines changed: 222 additions & 0 deletions

@@ -0,0 +1,222 @@
+# copyright (c) 2018 paddlepaddle authors. all rights reserved.
+#
+# licensed under the apache license, version 2.0 (the "license");
+# you may not use this file except in compliance with the license.
+# you may obtain a copy of the license at
+#
+#     http://www.apache.org/licenses/license-2.0
+#
+# unless required by applicable law or agreed to in writing, software
+# distributed under the license is distributed on an "as is" basis,
+# without warranties or conditions of any kind, either express or implied.
+# see the license for the specific language governing permissions and
+# limitations under the license.
+
+from __future__ import print_function
+
+import os
+import numpy as np
+import random
+import shutil
+import time
+import unittest
+import logging
+
+import paddle
+import paddle.fluid as fluid
+from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware
+from paddle.fluid.log_helper import get_logger
+from paddle.dataset.common import download
+
+from imperative_test_utils import fix_model_dict, ImperativeLenet
+
+os.environ["CPU_NUM"] = "1"
+if paddle.is_compiled_with_cuda():
+    fluid.set_flags({"FLAGS_cudnn_deterministic": True})
+
+_logger = get_logger(
+    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
+
+
+class TestImperativeQatAmp(unittest.TestCase):
+    """
+    Test the combination of qat and amp.
+    """
+
+    @classmethod
+    def setUpClass(cls):
+        timestamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
+        cls.root_path = os.path.join(os.getcwd(),
+                                     "imperative_qat_amp_" + timestamp)
+        cls.save_path = os.path.join(cls.root_path, "model")
+
+        cls.download_path = 'dygraph_int8/download'
+        cls.cache_folder = os.path.expanduser('~/.cache/paddle/dataset/' +
+                                              cls.download_path)
+
+        cls.lenet_url = "https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/lenet_pretrained.tar.gz"
+        cls.lenet_md5 = "953b802fb73b52fae42896e3c24f0afb"
+
+        seed = 1
+        np.random.seed(seed)
+        paddle.static.default_main_program().random_seed = seed
+        paddle.static.default_startup_program().random_seed = seed
+
+    @classmethod
+    def tearDownClass(cls):
+        try:
+            shutil.rmtree(cls.root_path)
+        except Exception as e:
+            print("Failed to delete {} due to {}".format(cls.root_path, str(e)))
+
+    def cache_unzipping(self, target_folder, zip_path):
+        if not os.path.exists(target_folder):
+            cmd = 'mkdir {0} && tar xf {1} -C {0}'.format(target_folder,
+                                                          zip_path)
+            os.system(cmd)
+
+    def download_model(self, data_url, data_md5, folder_name):
+        download(data_url, self.download_path, data_md5)
+        file_name = data_url.split('/')[-1]
+        zip_path = os.path.join(self.cache_folder, file_name)
+        print('Data is downloaded at {0}'.format(zip_path))
+
+        data_cache_folder = os.path.join(self.cache_folder, folder_name)
+        self.cache_unzipping(data_cache_folder, zip_path)
+        return data_cache_folder
+
+    def set_vars(self):
+        self.qat = ImperativeQuantAware()
+
+        self.train_batch_num = 30
+        self.train_batch_size = 32
+        self.test_batch_num = 100
+        self.test_batch_size = 32
+        self.eval_acc_top1 = 0.99
+
+    def model_train(self, model, batch_num=-1, batch_size=32, use_amp=False):
+        model.train()
+
+        train_reader = paddle.batch(
+            paddle.dataset.mnist.train(), batch_size=batch_size)
+        adam = paddle.optimizer.Adam(
+            learning_rate=0.001, parameters=model.parameters())
+        scaler = paddle.amp.GradScaler(init_loss_scaling=500)
+
+        for batch_id, data in enumerate(train_reader()):
+            x_data = np.array([x[0].reshape(1, 28, 28)
+                               for x in data]).astype('float32')
+            y_data = np.array(
+                [x[1] for x in data]).astype('int64').reshape(-1, 1)
+
+            img = paddle.to_tensor(x_data)
+            label = paddle.to_tensor(y_data)
+
+            if use_amp:
+                with paddle.amp.auto_cast():
+                    out = model(img)
+                    acc = fluid.layers.accuracy(out, label)
+                    loss = fluid.layers.cross_entropy(out, label)
+                    avg_loss = fluid.layers.mean(loss)
+                scaled_loss = scaler.scale(avg_loss)
+                scaled_loss.backward()
+
+                scaler.minimize(adam, scaled_loss)
+                adam.clear_gradients()
+            else:
+                out = model(img)
+                acc = fluid.layers.accuracy(out, label)
+                loss = fluid.layers.cross_entropy(out, label)
+                avg_loss = fluid.layers.mean(loss)
+                avg_loss.backward()
+
+                adam.minimize(avg_loss)
+                model.clear_gradients()
+
+            if batch_id % 100 == 0:
+                _logger.info("Train | step {}: loss = {:}, acc= {:}".format(
+                    batch_id, avg_loss.numpy(), acc.numpy()))
+
+            if batch_num > 0 and batch_id + 1 >= batch_num:
+                break
+
+    def model_test(self, model, batch_num=-1, batch_size=32, use_amp=False):
+        model.eval()
+
+        test_reader = paddle.batch(
+            paddle.dataset.mnist.test(), batch_size=batch_size)
+
+        acc_top1_list = []
+        for batch_id, data in enumerate(test_reader()):
+            x_data = np.array([x[0].reshape(1, 28, 28)
+                               for x in data]).astype('float32')
+            y_data = np.array(
+                [x[1] for x in data]).astype('int64').reshape(-1, 1)
+
+            img = paddle.to_tensor(x_data)
+            label = paddle.to_tensor(y_data)
+
+            with paddle.amp.auto_cast(use_amp):
+                out = model(img)
+                acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
+                acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
+
+            acc_top1_list.append(float(acc_top1.numpy()))
+            if batch_id % 100 == 0:
+                _logger.info("Test | At step {}: acc1 = {:}, acc5 = {:}".format(
+                    batch_id, acc_top1.numpy(), acc_top5.numpy()))
+
+            if batch_num > 0 and batch_id + 1 >= batch_num:
+                break
+
+        acc_top1 = sum(acc_top1_list) / len(acc_top1_list)
+        return acc_top1
+
+    def test_ptq(self):
+        start_time = time.time()
+
+        self.set_vars()
+
+        params_path = self.download_model(self.lenet_url, self.lenet_md5,
+                                          "lenet")
+        params_path += "/lenet_pretrained/lenet.pdparams"
+
+        with fluid.dygraph.guard():
+            model = ImperativeLenet()
+            model_state_dict = paddle.load(params_path)
+            model.set_state_dict(model_state_dict)
+
+            _logger.info("Test fp32 model")
+            fp32_acc_top1 = self.model_test(model, self.test_batch_num,
+                                            self.test_batch_size)
+
+            self.qat.quantize(model)
+
+            use_amp = True
+            self.model_train(model, self.train_batch_num, self.train_batch_size,
+                             use_amp)
+
+            _logger.info("Test int8 model")
+            int8_acc_top1 = self.model_test(model, self.test_batch_num,
+                                            self.test_batch_size, use_amp)
+
+            _logger.info('fp32_acc_top1: %f, int8_acc_top1: %f' %
+                         (fp32_acc_top1, int8_acc_top1))
+            self.assertTrue(
+                int8_acc_top1 > fp32_acc_top1 - 0.01,
+                msg='fp32_acc_top1: %f, int8_acc_top1: %f' %
+                (fp32_acc_top1, int8_acc_top1))
+
+            input_spec = [
+                paddle.static.InputSpec(
+                    shape=[None, 1, 28, 28], dtype='float32')
+            ]
+            paddle.jit.save(layer=model, path=self.save_path, input_spec=input_spec)
+            print('Quantized model saved in {%s}' % self.save_path)
+
+        end_time = time.time()
+        print("total time: %ss" % (end_time - start_time))
+
+
+if __name__ == '__main__':
+    unittest.main()
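Distilled from the test above, the QAT-plus-AMP recipe itself is short. A minimal sketch, assuming a CUDA build of Paddle 2.x; the toy Sequential model stands in for ImperativeLenet:

import paddle
from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware

# Toy stand-in for ImperativeLenet.
model = paddle.nn.Sequential(
    paddle.nn.Conv2D(1, 6, 3), paddle.nn.ReLU(), paddle.nn.Flatten(),
    paddle.nn.Linear(6 * 26 * 26, 10))
ImperativeQuantAware().quantize(model)  # insert fake quant/dequant in place

adam = paddle.optimizer.Adam(
    learning_rate=0.001, parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=500)

img = paddle.rand([4, 1, 28, 28])
label = paddle.randint(0, 10, [4, 1])
with paddle.amp.auto_cast():  # fp16 where safe; the quant ops now join in
    loss = paddle.nn.functional.cross_entropy(model(img), label)
scaled = scaler.scale(loss)   # scale the loss to avoid fp16 underflow
scaled.backward()
scaler.minimize(adam, scaled)
adam.clear_gradients()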

python/paddle/fluid/dygraph/amp/auto_cast.py

Lines changed: 2 additions & 0 deletions

@@ -29,6 +29,8 @@
     'matmul',
     'matmul_v2',
     'mul',
+    'fake_quantize_dequantize_abs_max',
+    'fake_quantize_dequantize_moving_average_abs_max',
 }
 
 # The set of ops that support fp16 calculation and are considered numerically-
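This set is amp's fp16 list, so the two fake quant-dequant ops now execute in fp16 inside auto_cast instead of forcing a cast back to fp32. User code can make the same kind of per-scope adjustment through the custom_white_list/custom_black_list parameters of paddle.amp.auto_cast; a hedged sketch (the op name below is only an example):

import paddle

model = paddle.nn.Linear(4, 4)
x = paddle.rand([2, 4])
# Per-scope tweak: treat elementwise_add as an fp16 (white-list) op here.
with paddle.amp.auto_cast(custom_white_list={'elementwise_add'}):
    y = model(x) + x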
