
Commit 3fe2fd9

fix bf16 symbolic_trace bug (#1892)

Authored by xin3he and Xuehao Sun

Description: fix a bf16 symbolic_trace bug: the bf16_symbolic_trace path caused abnormal recursive calls and lost necessary module attributes. By moving the BF16 fallback ahead of quantization and removing bf16_symbolic_trace, we fix it.

Signed-off-by: xin3he <xin3.he@intel.com>
Co-authored-by: Sun, Xuehao <xuehao.sun@intel.com>

1 parent e080e06

File tree

4 files changed: +20 −49 lines

neural_compressor/adaptor/pytorch.py (12 additions, 8 deletions)
@@ -1242,6 +1242,8 @@ def _combine_capability(self, bf16_ops, q_capability):
                 q_capability["opwise"][bf16_op] = [bf16_config, fp32_config]
                 if bf16_op[1] not in q_capability["optypewise"]:
                     q_capability["optypewise"][bf16_op[1]] = [bf16_config, fp32_config]
+                if bf16_op[1] in q_capability["optypewise"] and bf16_config not in q_capability["optypewise"][bf16_op[1]]:
+                    q_capability["optypewise"][bf16_op[1]].append(bf16_config)
         return q_capability

     def get_fused_list(self, model):
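The two added lines cover the case where an op type already has an "optypewise" entry (typically created by the int8 capability pass) but that entry does not yet list a bf16 option; with only the old branch, such types never gained bf16 as a type-level tuning choice. Below is a minimal, self-contained sketch of the effect; the surrounding capability data is made up, and only the dict layout and the new branch come from the diff above.

bf16_config = {"weight": {"dtype": "bf16"}, "activation": {"dtype": "bf16"}}
fp32_config = {"weight": {"dtype": "fp32"}, "activation": {"dtype": "fp32"}}

# Hypothetical capability dict: "Conv2d" is already registered type-wise,
# so the old `not in` branch never fires for it.
q_capability = {"opwise": {}, "optypewise": {"Conv2d": [fp32_config]}}

for bf16_op in [("conv", "Conv2d")]:
    q_capability["opwise"][bf16_op] = [bf16_config, fp32_config]
    if bf16_op[1] not in q_capability["optypewise"]:
        q_capability["optypewise"][bf16_op[1]] = [bf16_config, fp32_config]
    # new branch from this commit: existing type entries also gain the bf16 option
    if bf16_op[1] in q_capability["optypewise"] and bf16_config not in q_capability["optypewise"][bf16_op[1]]:
        q_capability["optypewise"][bf16_op[1]].append(bf16_config)

print(q_capability["optypewise"]["Conv2d"])  # [fp32_config, bf16_config]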
@@ -3579,6 +3581,16 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
            return q_model

        self.tune_cfg["fx_sub_module_list"] = self.sub_module_list
+
+        # BF16 fallback
+        if (
+            len(self.tune_cfg["bf16_ops_list"]) > 0
+            and self.version.release >= Version("1.11.0").release
+            and self.use_bf16
+            and (CpuInfo().bf16 or os.getenv("FORCE_BF16") == "1")
+        ):  # pragma: no cover
+            q_model._model = torch_utils.bf16_convert.Convert(q_model._model, self.tune_cfg)
+
        if self.approach == "quant_aware_training":
            q_model._model.train()
            if self.sub_module_list is None:
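The BF16 fallback itself is unchanged; only its position moved, so it now runs right after fx_sub_module_list is recorded and before the FX-specific conversion branches. Restated as a standalone predicate for illustration (the helper name and arguments below are hypothetical, not adaptor API; only the four conditions come from the diff):

import os

from packaging.version import Version


def bf16_fallback_enabled(torch_version, use_bf16, cpu_has_bf16, bf16_ops_list):
    # Same four conditions as the guard in the diff, written as a plain function.
    return (
        len(bf16_ops_list) > 0
        and Version(torch_version).release >= Version("1.11.0").release
        and use_bf16
        and (cpu_has_bf16 or os.getenv("FORCE_BF16") == "1")
    )


os.environ["FORCE_BF16"] = "1"  # force the fallback on CPUs without native bf16
print(bf16_fallback_enabled("2.1.0", True, False, [("conv", "Conv2d")]))   # True
print(bf16_fallback_enabled("1.10.2", True, True, [("conv", "Conv2d")]))  # False: PyTorch < 1.11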
@@ -3665,14 +3677,6 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
                    self.sub_module_list, q_model._model, prefix="", custom_config=self.prepare_custom_config_dict
                )

-            if (
-                len(self.tune_cfg["bf16_ops_list"]) > 0
-                and self.version.release >= Version("1.11.0").release
-                and self.use_bf16
-                and (CpuInfo().bf16 or os.getenv("FORCE_BF16") == "1")
-            ):  # pragma: no cover
-                q_model._model = torch_utils.bf16_convert.Convert(q_model._model, self.tune_cfg)
-
        self.fused_dict = self.get_fused_list(q_model.model)
        q_model.is_quantized = True
        q_model.q_config = copy.deepcopy(self.tune_cfg)

neural_compressor/adaptor/pytorch_cpu.yaml (1 addition, 1 deletion)
@@ -19,7 +19,7 @@
    name: '1.11'

  bf16: ['Linear', 'bmm', 'mm', 'baddbmm', 'addmm', 'addbmm',
-         '_convolution', 'LSTM', 'LSTMCell', 'GRU', 'GRUCell']
+         'Conv1d', 'Conv2d', 'Conv3d', 'LSTM', 'LSTMCell', 'GRU', 'GRUCell']
  fp32: ['*'] # `*` means all op types.
  int8: &1_11_capabilities {
    'static': &cap_s8_1_11 {
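The YAML change replaces the low-level '_convolution' entry with the nn.Module class names Conv1d/Conv2d/Conv3d, which is what module-level op types report. A rough sketch of how such a list matches modules by class name follows; the matching loop and toy model are illustrative, not the adaptor's actual query code.

import torch.nn as nn

# bf16-capable op types as listed in pytorch_cpu.yaml after this change
bf16_types = ['Linear', 'bmm', 'mm', 'baddbmm', 'addmm', 'addbmm',
              'Conv1d', 'Conv2d', 'Conv3d', 'LSTM', 'LSTMCell', 'GRU', 'GRUCell']

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Flatten(), nn.Linear(8 * 30 * 30, 4))

# Module op types show up as class names such as "Conv2d", which the old
# '_convolution' entry could never equal.
for name, module in model.named_modules():
    if type(module).__name__ in bf16_types:
        print(name, type(module).__name__, "-> bf16 candidate")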

neural_compressor/adaptor/torch_utils/bf16_convert.py (4 additions, 31 deletions)
@@ -17,7 +17,6 @@
 """Bf16 Convert for Torch Utils."""
 import torch
 import torch.nn as nn
-from torch.fx import symbolic_trace

 from ...utils import logger

@@ -28,6 +27,7 @@ class BF16ModuleWrapper(nn.Module):
     def __init__(self, module):
         """Init a BF16ModuleWrapper object."""
         super(BF16ModuleWrapper, self).__init__()
+        module = module.bfloat16()
         self.add_module("module", module)
         self.train(module.training)
         # WA for TransformerEncoder to access its Linear's weights and bias
@@ -38,7 +38,6 @@ def __init__(self, module):
     def forward(self, X):
         """Convert dtype."""
         X = X.to(torch.bfloat16)
-        self.module.bfloat16()
         X = self.module(X)
         return X.float()

@@ -54,44 +53,18 @@ def Convert(model, tune_cfg):
         mixed_precision_model (object): model with mixed precision.
     """
     bf16_ops_list = tune_cfg["bf16_ops_list"]
-    fx_sub_module_list = tune_cfg["fx_sub_module_list"] if "fx_sub_module_list" in tune_cfg.keys() else []
     if len(bf16_ops_list) > 0:
         logger.info("Convert operators to bfloat16")
     mixed_precision_model = _bf16_wrapper_model(model, bf16_ops_list)
-    if fx_sub_module_list is not None and len(fx_sub_module_list) > 0:
-        mixed_precision_model = bf16_symbolic_trace(mixed_precision_model, fx_sub_module_list)
     return mixed_precision_model


 def _bf16_wrapper_model(model, bf16_ops_list, prefix=""):
     for name, child in model.named_children():
         op_name = prefix + "." + name if prefix != "" else name
         for bf16_op_name in bf16_ops_list:
-            if op_name == bf16_op_name[0]:
+            if op_name == bf16_op_name[0] or op_name == bf16_op_name[0].split(".module")[0]:
                 child = BF16ModuleWrapper(child)
-            else:
-                _bf16_wrapper_model(child, bf16_ops_list, op_name)
-        setattr(model, name, child)
-    return model
-
-
-def bf16_symbolic_trace(model, fx_sub_module_list, prefix=""):
-    """Symbolic trace for bf16 models.
-
-    Args:
-        model (object): the input model.
-        fx_sub_module_list (list): _description_
-        prefix (str): prefix of op name.
-
-    Returns:
-        model (object)
-    """
-    for name, child in model.named_children():
-        op_name = prefix + "." + name if prefix != "" else name
-        for fx_sub_module_name in fx_sub_module_list:
-            if op_name == fx_sub_module_name:
-                child = symbolic_trace(child)
-            else:
-                bf16_symbolic_trace(child, fx_sub_module_list, op_name)
-        setattr(model, name, child)
+        setattr(model, name, child)
+        _bf16_wrapper_model(child, bf16_ops_list, op_name)
     return model
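After this change, BF16ModuleWrapper casts the wrapped module's parameters once in __init__ and casts activations in forward, and _bf16_wrapper_model always recurses into children instead of recursing only from the else branch (which previously re-entered the walk once per non-matching list entry). A small sketch of the public Convert entry point on a toy model; it assumes a neural_compressor build that includes this commit, and the op names and model here are made up.

import torch
import torch.nn as nn

from neural_compressor.adaptor.torch_utils.bf16_convert import Convert

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())

# After this commit Convert only needs "bf16_ops_list"; entries are (op_name, op_type)
# tuples, and "0" is the Conv2d child of the Sequential above.
tune_cfg = {"bf16_ops_list": [("0", "Conv2d")]}
model = Convert(model, tune_cfg)

print(type(model[0]).__name__)        # BF16ModuleWrapper
print(model[0].module.weight.dtype)   # torch.bfloat16 (cast once in __init__ now)
out = model(torch.randn(1, 3, 16, 16))  # bf16 Conv kernels need PyTorch >= 1.11 on CPU
print(out.dtype)                      # torch.float32 (forward casts the output back)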

test/adaptor/pytorch_adaptor/test_adaptor_pytorch_2x.py (3 additions, 9 deletions)
@@ -392,21 +392,15 @@ def test_fx_sub_module_quant(self):
         "Please use PyTroch 1.11 or higher version for mixed precision with pytorch_fx or pytorch backend",
     )
     def test_mix_precision(self):
+        os.environ["FORCE_BF16"] = "1"
         model_origin = DynamicControlModel()
-        # run fx_quant in neural_compressor and save the quantized GraphModule
         dataset = Datasets("pytorch")["dummy"]((100, 3, 224, 224))
         dataloader = DataLoader("pytorch", dataset)
         set_workspace("./saved")
+        # fx mode usually has .module suffix due to tracing of the entire model fails, so use conv.* to leverage re.match
+        ptq_fx_op_name_list["conv.*"] = {"weight": {"dtype": "bf16"}, "activation": {"dtype": "bf16"}}
         conf = PostTrainingQuantConfig(op_name_dict=ptq_fx_op_name_list)
         q_model = quantization.fit(model_origin, conf, calib_dataloader=dataloader, calib_func=eval_func)
-        tune_cfg = q_model.q_config
-        tune_cfg["op"][("conv.module", "Conv2d")].clear()
-        tune_cfg["op"][("conv.module", "Conv2d")] = {"weight": {"dtype": "bf16"}, "activation": {"dtype": "bf16"}}
-        tune_cfg["bf16_ops_list"].append(("conv.module", "Conv2d"))
-        from neural_compressor.adaptor.torch_utils.bf16_convert import Convert
-
-        q_model._model = Convert(q_model._model, tune_cfg)
-
         self.assertEqual(q_model._model.conv.module.module.weight.dtype, torch.bfloat16)
         self.assertEqual(q_model._model.conv.module.module.bias.dtype, torch.bfloat16)
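For reference, the user-facing flow the updated test exercises looks roughly like the following. The toy model, shapes, and the "0.*" op-name pattern are made up, and whether an op actually ends up in bf16 still depends on the tuning result and on bf16 support (or FORCE_BF16) on the machine.

import os

import torch
from neural_compressor import quantization
from neural_compressor.config import PostTrainingQuantConfig
from neural_compressor.data import DataLoader, Datasets

os.environ["FORCE_BF16"] = "1"  # allow the bf16 path on CPUs without native support

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
dataset = Datasets("pytorch")["dummy"]((16, 3, 32, 32))
dataloader = DataLoader("pytorch", dataset)

# Request bf16 for every op whose name matches the regex "0.*" (the Conv2d here).
op_name_dict = {"0.*": {"weight": {"dtype": "bf16"}, "activation": {"dtype": "bf16"}}}
conf = PostTrainingQuantConfig(op_name_dict=op_name_dict)
q_model = quantization.fit(model, conf, calib_dataloader=dataloader)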
