Commit 9c6102b

Add save&load API to SmoothQuant ipex model (#1673)
Signed-off-by: Cheng, Zixuan <zixuan.cheng@intel.com>
1 parent e81a2dd commit 9c6102b
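
This commit adds a save/load path for SmoothQuant-quantized IPEX models plus a recover_model_from_json helper. A minimal sketch of the intended flow, pieced together from the tests added below; the toy model, the calibration run_fn, and the quantize/get_default_sq_config import path are assumptions for illustration, not part of the commit:

import copy

import torch

# Import path assumed from the 3.x torch API exercised by the new tests.
from neural_compressor.torch.quantization import get_default_sq_config, quantize
from neural_compressor.torch.algorithms.smooth_quant import load, recover_model_from_json

model = torch.nn.Sequential(torch.nn.Linear(3, 3), torch.nn.Linear(3, 3))  # placeholder toy model
example_inputs = torch.zeros([1, 3])


def run_fn(model):
    # single calibration pass over the example inputs
    model(example_inputs)


q_model = quantize(copy.deepcopy(model), quant_config=get_default_sq_config(), run_fn=run_fn, example_inputs=example_inputs)
q_model.save("saved_results")  # TorchScript weights plus qconfig.json

loaded_model = load("saved_results")  # reload the traced, frozen model
recovered = recover_model_from_json(copy.deepcopy(model), "saved_results/qconfig.json", example_inputs=example_inputs)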

File tree

7 files changed, +207 -7 lines changed


neural_compressor/torch/algorithms/smooth_quant/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -15,3 +15,4 @@
 
 from .utility import *
 from .smooth_quant import smooth_quantize
+from .save_load import save, load, recover_model_from_json

neural_compressor/torch/algorithms/smooth_quant/save_load.py

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint:disable=import-error
import torch

try:
    import intel_extension_for_pytorch as ipex
except:
    assert False, "Please install IPEX for smooth quantization."

from neural_compressor.torch.algorithms.static_quant import load, save


def recover_model_from_json(model, json_file_path, example_inputs):  # pragma: no cover
    """Recover ipex model from JSON file.

    Args:
        model (object): fp32 model that needs quantization.
        json_file_path (json): configuration JSON file for ipex.
        example_inputs (tuple or torch.Tensor or dict): example inputs that will be passed to the ipex function.

    Returns:
        (object): quantized model
    """
    from torch.ao.quantization.observer import MinMaxObserver

    if ipex.__version__ >= "2.1.100":
        qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5, act_observer=MinMaxObserver)
    else:
        qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5, act_observer=MinMaxObserver())
    if isinstance(example_inputs, dict):
        model = ipex.quantization.prepare(model, qconfig, example_kwarg_inputs=example_inputs, inplace=True)
    else:
        model = ipex.quantization.prepare(model, qconfig, example_inputs=example_inputs, inplace=True)

    model.load_qconf_summary(qconf_summary=json_file_path)
    model = ipex.quantization.convert(model, inplace=True)
    with torch.no_grad():
        try:
            if isinstance(example_inputs, dict):
                # pylint: disable=E1120,E1123
                model = torch.jit.trace(model, example_kwarg_inputs=example_inputs)
            else:
                model = torch.jit.trace(model, example_inputs)
            model = torch.jit.freeze(model.eval())
        except:
            if isinstance(example_inputs, dict):
                # pylint: disable=E1120,E1123
                model = torch.jit.trace(model, example_kwarg_inputs=example_inputs, strict=False, check_trace=False)
            else:
                model = torch.jit.trace(model, example_inputs, strict=False)
            model = torch.jit.freeze(model.eval())
        if isinstance(example_inputs, dict):
            model(**example_inputs)
            model(**example_inputs)
        elif isinstance(example_inputs, tuple) or isinstance(example_inputs, list):
            model(*example_inputs)
            model(*example_inputs)
        else:
            model(example_inputs)
            model(example_inputs)
    return model
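
recover_model_from_json accepts example_inputs as a tuple, a tensor, or a dict; for a dict it drives prepare and trace through example_kwarg_inputs and warms up the traced model with keyword expansion, otherwise positional calls are used. A small usage sketch for the tensor case; the placeholder model and the pre-existing "ipex.json" qconf summary are illustrative only:

import torch

from neural_compressor.torch.algorithms.smooth_quant import recover_model_from_json

fp32_model = torch.nn.Sequential(torch.nn.Linear(3, 3))  # placeholder FP32 model
example_inputs = torch.zeros([1, 3])

# "ipex.json" must be a qconf summary written earlier, e.g. via save_qconf_summary()
# on a prepared IPEX model, or the qconfig.json produced by q_model.save(...).
recovered = recover_model_from_json(fp32_model, "ipex.json", example_inputs=example_inputs)
out = recovered(example_inputs)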

neural_compressor/torch/algorithms/static_quant/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -15,3 +15,4 @@
 
 from .utility import *
 from .static_quant import static_quantize
+from .save_load import save, load

neural_compressor/torch/algorithms/static_quant/save_load.py

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint:disable=import-error
import json
import os

import torch

try:
    import intel_extension_for_pytorch as ipex
except:
    assert False, "Please install IPEX for static quantization."

from neural_compressor.torch.utils import QCONFIG_NAME, WEIGHT_NAME, logger


def save(model, output_dir="./saved_results"):
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)

    qmodel_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), WEIGHT_NAME)
    qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), QCONFIG_NAME)
    model.ori_save(qmodel_file_path)
    with open(qconfig_file_path, "w") as f:
        json.dump(model.tune_cfg, f, indent=4)

    logger.info("Save quantized model to {}.".format(qmodel_file_path))
    logger.info("Save configuration of quantized model to {}.".format(qconfig_file_path))


def load(output_dir="./saved_results"):
    qmodel_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), WEIGHT_NAME)
    model = torch.jit.load(qmodel_file_path)
    model = torch.jit.freeze(model.eval())
    logger.info("Quantized model loading successful.")
    return model
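
save() reuses the quantized module's own TorchScript serialization (exposed as ori_save by the algorithm entries below) for the weights and dumps model.tune_cfg as JSON next to it, under the WEIGHT_NAME and QCONFIG_NAME constants from neural_compressor.torch.utils; load() only needs the weights file and returns a frozen torch.jit.ScriptModule. A minimal round trip, assuming q_model came out of the static or smooth quant entry so that ori_save and tune_cfg exist:

import torch

from neural_compressor.torch.algorithms.static_quant import load

q_model.save("saved_results")  # TorchScript weights plus the tune_cfg JSON (qconfig.json in the tests)
loaded_model = load("saved_results")  # torch.jit.load followed by torch.jit.freeze
assert isinstance(loaded_model, torch.jit.ScriptModule)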

neural_compressor/torch/quantization/algorithm_entry.py

Lines changed: 6 additions & 2 deletions
@@ -121,7 +121,7 @@ def static_quant_entry(
     model: torch.nn.Module, configs_mapping: Dict[Tuple[str, callable], StaticQuantConfig], *args, **kwargs
 ) -> torch.nn.Module:
     logger.info("Quantize model with the static quant algorithm.")
-    from neural_compressor.torch.algorithms.static_quant import static_quantize
+    from neural_compressor.torch.algorithms.static_quant import save, static_quantize
 
     # convert the user config into internal format
     quant_config_mapping = {}
@@ -157,6 +157,8 @@ def static_quant_entry(
         inplace=inplace,
     )
     logger.info("Static quantization done.")
+    q_model.ori_save = q_model.save
+    q_model.save = MethodType(save, q_model)
     return q_model
 
 
@@ -167,7 +169,7 @@ def smooth_quant_entry(
     model: torch.nn.Module, configs_mapping: Dict[Tuple[str, callable], SmoothQuantConfig], *args, **kwargs
 ) -> torch.nn.Module:
     logger.info("Quantize model with the smooth quant algorithm.")
-    from neural_compressor.torch.algorithms.smooth_quant import smooth_quantize
+    from neural_compressor.torch.algorithms.smooth_quant import save, smooth_quantize
 
     # convert the user config into internal format
     quant_config_mapping = {}
@@ -214,6 +216,8 @@ def smooth_quant_entry(
         inplace=inplace,
     )
     logger.info("Smooth quantization done.")
+    q_model.ori_save = q_model.save
+    q_model.save = MethodType(save, q_model)
     return q_model
 
 
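
Both entry points graft the new save helper onto the ScriptModule returned by IPEX: the module's built-in save is kept as ori_save (save_load.py calls it to write the TorchScript weights) and save is rebound to the helper via MethodType, so q_model.save("saved_results") also records tune_cfg. A standalone sketch of the rebinding pattern; the Dummy class and new_save function are illustrative only, and MethodType is assumed to come from the standard types module:

from types import MethodType


def new_save(self, output_dir="./saved_results"):
    # call the original bound method first, then do the extra bookkeeping
    self.ori_save(output_dir)
    print("also wrote extra config for {}".format(output_dir))


class Dummy:
    def save(self, output_dir="./saved_results"):
        print("original save to {}".format(output_dir))


obj = Dummy()
obj.ori_save = obj.save               # keep a handle on the original bound method
obj.save = MethodType(new_save, obj)  # rebind save so obj is passed as self automatically
obj.save("saved_results")             # original save runs, then the extra step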

test/3x/torch/quantization/test_smooth_quant.py

Lines changed: 32 additions & 5 deletions
@@ -84,7 +84,7 @@ def run_fn(model):
         assert torch.allclose(output1, output2, atol=2e-2), "Accuracy gap atol > 0.02 is unexpected. Please check."
 
     @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX")
-    def test_sq_ipex_save_load(self):
+    def test_sq_ipex_accuracy(self):
         from intel_extension_for_pytorch.quantization import convert, prepare
 
         example_inputs = torch.zeros([1, 3])
@@ -96,6 +96,7 @@ def run_fn(model):
             model(example_inputs)
 
         run_fn(user_model)
+        user_model.save_qconf_summary(qconf_summary="ipex.json")
         with torch.no_grad():
             user_model = convert(user_model.eval(), inplace=True).eval()
             user_model(example_inputs)
@@ -109,12 +110,38 @@ def run_fn(model):
         quant_config = get_default_sq_config()
         q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
         assert q_model is not None, "Quantization failed!"
+        q_model.save("saved_results")
+
+        inc_out = q_model(example_inputs)
+        # set a big atol to avoid random issue
+        assert torch.allclose(inc_out, ipex_out, atol=2e-02), "Unexpected result. Please double check."
+
+        from neural_compressor.torch.algorithms.smooth_quant import recover_model_from_json
+
+        fp32_model = copy.deepcopy(model)
+        ipex_model = recover_model_from_json(fp32_model, "ipex.json", example_inputs=example_inputs)
+        ipex_out = ipex_model(example_inputs)
+        assert torch.allclose(inc_out, ipex_out, atol=2e-02), "Unexpected result. Please double check."
+
+    @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX")
+    def test_sq_save_load(self):
+        fp32_model = copy.deepcopy(model)
+        quant_config = get_default_sq_config()
+        example_inputs = torch.zeros([1, 3])
+        q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
+        assert q_model is not None, "Quantization failed!"
+        q_model.save("saved_results")
         inc_out = q_model(example_inputs)
-        q_model.save("saved")
 
-        # load
-        loaded_model = torch.jit.load("saved")
+        from neural_compressor.torch.algorithms.smooth_quant import load, recover_model_from_json
+
+        # load using saved model
+        loaded_model = load("saved_results")
         loaded_out = loaded_model(example_inputs)
-        assert torch.allclose(inc_out, ipex_out, atol=1e-05), "Unexpected result. Please double check."
+        # set a big atol to avoid random issue
+        assert torch.allclose(inc_out, loaded_out, atol=2e-02), "Unexpected result. Please double check."
 
+        # compare saved json file
+        loaded_model = recover_model_from_json(fp32_model, "saved_results/qconfig.json", example_inputs=example_inputs)
+        loaded_out = loaded_model(example_inputs)
         assert torch.allclose(inc_out, loaded_out, atol=1e-05), "Unexpected result. Please double check."

test/3x/torch/quantization/test_static_quant.py

Lines changed: 45 additions & 0 deletions
@@ -92,3 +92,48 @@ def run_fn(model):
         output2 = q_model(example_inputs)
         # set a big atol to avoid random issue
         assert torch.allclose(output1, output2, atol=2e-2), "Accuracy gap atol > 0.02 is unexpected. Please check."
+
+    @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX")
+    def test_static_quant_save_load(self):
+        from intel_extension_for_pytorch.quantization import convert, prepare
+
+        example_inputs = torch.zeros(1, 30)
+        try:
+            qconfig = ipex.quantization.default_static_qconfig_mapping
+        except:
+            from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig
+
+            qconfig = QConfig(
+                activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
+                weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric),
+            )
+        user_model = copy.deepcopy(self.fp32_model)
+        user_model = prepare(user_model.eval(), qconfig, example_inputs=example_inputs, inplace=True)
+
+        def run_fn(model):
+            model(example_inputs)
+
+        run_fn(user_model)
+        with torch.no_grad():
+            user_model = convert(user_model.eval(), inplace=True).eval()
+            user_model(example_inputs)
+            user_model = torch.jit.trace(user_model.eval(), example_inputs, strict=False)
+            user_model = torch.jit.freeze(user_model.eval())
+            user_model(example_inputs)
+            user_model(example_inputs)
+        ipex_out = user_model(example_inputs)
+
+        fp32_model = copy.deepcopy(self.fp32_model)
+        quant_config = get_default_static_config()
+        q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
+        assert q_model is not None, "Quantization failed!"
+        inc_out = q_model(example_inputs)
+        # set a big atol to avoid random issue
+        assert torch.allclose(inc_out, ipex_out, atol=2e-02), "Unexpected result. Please double check."
+        q_model.save("saved_results")
+
+        from neural_compressor.torch.algorithms.static_quant import load
+
+        # load
+        loaded_model = load("saved_results")
+        assert isinstance(loaded_model, torch.jit.ScriptModule)
