
Commit 7a4715c

Support PT2E save and load (#1918)
Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>
1 parent 34f0a9f commit 7a4715c

File tree

5 files changed: +69 -3 lines changed


neural_compressor/torch/algorithms/pt2e_quant/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -14,3 +14,4 @@


 from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8PT2EQuantizer
+from .save_load import save, load
neural_compressor/torch/algorithms/pt2e_quant/save_load.py (new file)

Lines changed: 42 additions & 0 deletions

@@ -0,0 +1,42 @@
+# Copyright (c) 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+
+import torch
+
+from neural_compressor.common.utils import load_config_mapping, save_config_mapping
+from neural_compressor.torch.utils import QCONFIG_NAME, WEIGHT_NAME, logger
+
+
+def save(model, example_inputs, output_dir="./saved_results"):
+    os.makedirs(output_dir, exist_ok=True)
+    qmodel_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), WEIGHT_NAME)
+    qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), QCONFIG_NAME)
+    quantized_ep = torch.export.export(model, example_inputs)
+    torch.export.save(quantized_ep, qmodel_file_path)
+    for key, op_config in model.qconfig.items():
+        model.qconfig[key] = op_config.to_dict()
+    with open(qconfig_file_path, "w") as f:
+        json.dump(model.qconfig, f, indent=4)
+
+    logger.info("Save quantized model to {}.".format(qmodel_file_path))
+    logger.info("Save configuration of quantized model to {}.".format(qconfig_file_path))
+
+
+def load(output_dir="./saved_results"):
+    qmodel_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), WEIGHT_NAME)
+    loaded_quantized_ep = torch.export.load(qmodel_file_path)
+    return loaded_quantized_ep.module()
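For context, the new helpers are a thin wrapper around the torch.export serialization round trip: save() re-exports the quantized model to an ExportedProgram, writes it with torch.export.save, and dumps the per-op qconfig to JSON; load() restores the ExportedProgram with torch.export.load and returns its GraphModule. A minimal sketch of that round trip, using a plain nn.Linear stand-in and an illustrative model.pt2 path (neither is part of this commit):

import torch

model = torch.nn.Linear(4, 2).eval()
example_inputs = (torch.randn(1, 4),)

# Re-export the model and serialize the resulting ExportedProgram,
# as save() does with the quantized model and the WEIGHT_NAME path.
exported_program = torch.export.export(model, example_inputs)
torch.export.save(exported_program, "model.pt2")

# Deserialize and recover a callable GraphModule, as load() does.
loaded_program = torch.export.load("model.pt2")
restored_model = loaded_program.module()
assert torch.allclose(model(*example_inputs), restored_model(*example_inputs))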

neural_compressor/torch/quantization/algorithm_entry.py

Lines changed: 6 additions & 0 deletions

@@ -210,6 +210,7 @@ def static_quant_entry(
 def pt2e_dynamic_quant_entry(model: torch.nn.Module, configs_mapping, mode: Mode, *args, **kwargs) -> torch.nn.Module:
     logger.info("Quantize model with the PT2E static quant algorithm.")
     from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8PT2EQuantizer
+    from neural_compressor.torch.algorithms.pt2e_quant.save_load import save

     run_fn = kwargs.get("run_fn", None)
     example_inputs = kwargs.get("example_inputs", None)
@@ -221,6 +222,8 @@ def pt2e_dynamic_quant_entry(model: torch.nn.Module, configs_mapping, mode: Mode
     model = w8a8_quantizer.execute(
         model, mode=mode, run_fn=run_fn, example_inputs=example_inputs, inplace=inplace
     )
+    model.qconfig = configs_mapping
+    model.save = MethodType(save, model)
     return model


@@ -230,6 +233,7 @@ def pt2e_dynamic_quant_entry(model: torch.nn.Module, configs_mapping, mode: Mode
 def pt2e_static_quant_entry(model: torch.nn.Module, configs_mapping, mode: Mode, *args, **kwargs) -> torch.nn.Module:
     logger.info("Quantize model with the PT2E static quant algorithm.")
     from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8PT2EQuantizer
+    from neural_compressor.torch.algorithms.pt2e_quant.save_load import save

     run_fn = kwargs.get("run_fn", None)
     example_inputs = kwargs.get("example_inputs", None)
@@ -240,6 +244,8 @@ def pt2e_static_quant_entry(model: torch.nn.Module, configs_mapping, mode: Mode,
     model = w8a8_quantizer.execute(
         model, mode=mode, run_fn=run_fn, example_inputs=example_inputs, inplace=inplace
     )
+    model.qconfig = configs_mapping
+    model.save = MethodType(save, model)
     return model
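Both entry points expose the new save() as a bound method on the returned model: types.MethodType binds the free function to the model instance, so callers can write model.save(example_inputs=...) and the model is passed implicitly as the first argument. A small sketch of the pattern (the Dummy class and the print are illustrative only, not from the repository):

from types import MethodType


def save(model, example_inputs, output_dir="./saved_results"):
    # Free function: the model arrives explicitly as the first argument.
    print(f"saving {type(model).__name__} to {output_dir}")


class Dummy:
    pass


m = Dummy()
m.save = MethodType(save, m)  # bind save() to this instance
m.save(example_inputs=None)   # equivalent to save(m, example_inputs=None)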

neural_compressor/torch/quantization/load_entry.py

Lines changed: 5 additions & 0 deletions

@@ -85,6 +85,10 @@ def load(model_name_or_path, original_model=None, format="default", device="cpu"
         from neural_compressor.torch.algorithms import static_quant

         return static_quant.load(model_name_or_path)
+    elif "static_quant" in per_op_qconfig.keys() or "pt2e_dynamic_quant" in per_op_qconfig.keys():  # PT2E
+        from neural_compressor.torch.algorithms import pt2e_quant
+
+        return pt2e_quant.load(model_name_or_path)
     else:
         config_mapping = load_config_mapping(qconfig_file_path, ConfigRegistry.get_all_configs()["torch"])
         # select load function
@@ -102,6 +106,7 @@ def load(model_name_or_path, original_model=None, format="default", device="cpu"
         from neural_compressor.torch.algorithms import habana_fp8

         return habana_fp8.load(model_name_or_path, original_model)
+
     elif format == LoadFormat.HUGGINGFACE.value:
         # now only support load huggingface WOQ causal language model
         from neural_compressor.torch.algorithms import weight_only
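The loader keeps its existing key-based dispatch: it inspects the saved qconfig and, when a PT2E key ("static_quant" or "pt2e_dynamic_quant") is present, routes to pt2e_quant.load instead of the config-mapping path. A rough sketch of that decision follows; the literal dict standing in for per_op_qconfig is an assumption about the JSON written by save(), not part of the diff:

# Illustrative stand-in for the qconfig JSON parsed by the loader (assumed shape).
per_op_qconfig = {"static_quant": {"example_op": {"w_dtype": "int8"}}}

if "static_quant" in per_op_qconfig.keys() or "pt2e_dynamic_quant" in per_op_qconfig.keys():  # PT2E
    print("route to pt2e_quant.load()")
else:
    print("fall back to config-mapping based loading")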

test/3x/torch/quantization/test_pt2e_quant.py

Lines changed: 15 additions & 3 deletions

@@ -1,6 +1,4 @@
-import os
-import unittest
-from unittest.mock import patch
+import shutil

 import pytest
 import torch
@@ -33,6 +31,8 @@ def _is_ipex_imported():


 class TestPT2EQuantization:
+    def teardown_class(self):
+        shutil.rmtree("saved_results", ignore_errors=True)

     @staticmethod
     def get_toy_model():
@@ -114,6 +114,18 @@ def calib_fn(model):
         config.freezing = True
         q_model_out = q_model(*example_inputs)
         assert torch.allclose(float_model_output, q_model_out, atol=1e-2), "Quantization failed!"
+
+        # test save and load
+        q_model.save(
+            example_inputs=example_inputs,
+            output_dir="./saved_results",
+        )
+        from neural_compressor.torch.quantization import load
+
+        loaded_quantized_model = load("./saved_results")
+        loaded_q_model_out = loaded_quantized_model(*example_inputs)
+        assert torch.equal(loaded_q_model_out, q_model_out)
+
         opt_model = torch.compile(q_model)
         out = opt_model(*example_inputs)
         logger.warning("out shape is %s", out.shape)
