fix bf16 symbolic_trace bug #1892

Merged · 6 commits · Jul 9, 2024
20 changes: 12 additions & 8 deletions neural_compressor/adaptor/pytorch.py
@@ -1242,6 +1242,8 @@ def _combine_capability(self, bf16_ops, q_capability):
q_capability["opwise"][bf16_op] = [bf16_config, fp32_config]
if bf16_op[1] not in q_capability["optypewise"]:
q_capability["optypewise"][bf16_op[1]] = [bf16_config, fp32_config]
if bf16_op[1] in q_capability["optypewise"] and bf16_config not in q_capability["optypewise"][bf16_op[1]]:
q_capability["optypewise"][bf16_op[1]].append(bf16_config)
return q_capability

def get_fused_list(self, model):
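The two added lines extend the capability merge so that op types already present in `optypewise` also gain the bf16 option, instead of only op types that were missing. A minimal sketch of that merge behaviour, with made-up config dicts and op names rather than the adaptor's real data:

```python
# Illustrative sketch only -- the config dicts and op names are invented.
bf16_config = {"weight": {"dtype": "bf16"}, "activation": {"dtype": "bf16"}}
fp32_config = {"weight": {"dtype": "fp32"}, "activation": {"dtype": "fp32"}}

q_capability = {"opwise": {}, "optypewise": {"Conv2d": [fp32_config]}}

for bf16_op in [("conv1", "Conv2d")]:
    q_capability["opwise"].setdefault(bf16_op, [bf16_config, fp32_config])
    if bf16_op[1] not in q_capability["optypewise"]:
        q_capability["optypewise"][bf16_op[1]] = [bf16_config, fp32_config]
    elif bf16_config not in q_capability["optypewise"][bf16_op[1]]:
        # the branch this PR adds: existing op types also advertise bf16
        q_capability["optypewise"][bf16_op[1]].append(bf16_config)

print(q_capability["optypewise"]["Conv2d"])  # [fp32_config, bf16_config]
```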
@@ -3579,6 +3581,16 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
return q_model

self.tune_cfg["fx_sub_module_list"] = self.sub_module_list

# BF16 fallback
if (
len(self.tune_cfg["bf16_ops_list"]) > 0
and self.version.release >= Version("1.11.0").release
and self.use_bf16
and (CpuInfo().bf16 or os.getenv("FORCE_BF16") == "1")
): # pragma: no cover
q_model._model = torch_utils.bf16_convert.Convert(q_model._model, self.tune_cfg)

if self.approach == "quant_aware_training":
q_model._model.train()
if self.sub_module_list is None:
@@ -3665,14 +3677,6 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
self.sub_module_list, q_model._model, prefix="", custom_config=self.prepare_custom_config_dict
)

if (
len(self.tune_cfg["bf16_ops_list"]) > 0
and self.version.release >= Version("1.11.0").release
and self.use_bf16
and (CpuInfo().bf16 or os.getenv("FORCE_BF16") == "1")
): # pragma: no cover
q_model._model = torch_utils.bf16_convert.Convert(q_model._model, self.tune_cfg)

self.fused_dict = self.get_fused_list(q_model.model)
q_model.is_quantized = True
q_model.q_config = copy.deepcopy(self.tune_cfg)
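The BF16 fallback block itself is unchanged logic; the two hunks above simply move it ahead of the quant-aware-training and sub-module handling in `quantize()`. The guard still requires PyTorch 1.11+ and either native bf16 support (`CpuInfo().bf16`) or the `FORCE_BF16` override. A hedged usage note, assuming the variable has to be set before `quantization.fit()` evaluates that guard:

```python
import os

# Force the bf16 fallback path on hardware without native bf16 support.
# Assumption: export this before quantization.fit() runs the guard above.
os.environ["FORCE_BF16"] = "1"
```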
2 changes: 1 addition & 1 deletion neural_compressor/adaptor/pytorch_cpu.yaml
@@ -19,7 +19,7 @@
name: '1.11'

bf16: ['Linear', 'bmm', 'mm', 'baddbmm', 'addmm', 'addbmm',
'_convolution', 'LSTM', 'LSTMCell', 'GRU', 'GRUCell']
'Conv1d', 'Conv2d', 'Conv3d', 'LSTM', 'LSTMCell', 'GRU', 'GRUCell']
fp32: ['*'] # `*` means all op types.
int8: &1_11_capabilities {
'static': &cap_s8_1_11 {
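In the bf16 op-type list, the ATen-level `'_convolution'` entry is replaced with the module-level names `Conv1d`/`Conv2d`/`Conv3d`. The snippet below is a hypothetical illustration (not code from this repository) of why module-level names are the ones that can match: op types for `nn.Module` children are typically their class names.

```python
import torch.nn as nn

# Module class names are "Conv2d" etc.; the ATen op name "_convolution"
# never appears when walking the named children of a model.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
print([type(m).__name__ for m in model.children()])  # ['Conv2d', 'ReLU']
```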
35 changes: 4 additions & 31 deletions neural_compressor/adaptor/torch_utils/bf16_convert.py
@@ -17,7 +17,6 @@
"""Bf16 Convert for Torch Utils."""
import torch
import torch.nn as nn
from torch.fx import symbolic_trace

from ...utils import logger

@@ -28,6 +27,7 @@ class BF16ModuleWrapper(nn.Module):
def __init__(self, module):
"""Init a BF16ModuleWrapper object."""
super(BF16ModuleWrapper, self).__init__()
module = module.bfloat16()
self.add_module("module", module)
self.train(module.training)
# WA for TransformerEncoder to access its Linear's weights and bias
@@ -38,7 +38,6 @@ def __init__(self, module):
def forward(self, X):
"""Convert dtype."""
X = X.to(torch.bfloat16)
self.module.bfloat16()
X = self.module(X)
return X.float()
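Taken together, the two wrapper changes cast the wrapped module's parameters to bfloat16 once at construction instead of on every forward call; inputs are still cast per call and outputs cast back to float32. A minimal standalone sketch of that behaviour (a simplified stand-in, not the class above):

```python
import torch
import torch.nn as nn


class _Bf16Sketch(nn.Module):
    """Simplified stand-in for BF16ModuleWrapper, for illustration only."""

    def __init__(self, module):
        super().__init__()
        # parameters are converted once here, not inside forward()
        self.add_module("module", module.bfloat16())

    def forward(self, x):
        return self.module(x.to(torch.bfloat16)).float()


wrapped = _Bf16Sketch(nn.Linear(4, 2))
print(wrapped.module.weight.dtype)       # torch.bfloat16
print(wrapped(torch.randn(1, 4)).dtype)  # torch.float32, cast back for callers
```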

@@ -54,44 +53,18 @@ def Convert(model, tune_cfg):
mixed_precision_model (object): model with mixed precision.
"""
bf16_ops_list = tune_cfg["bf16_ops_list"]
fx_sub_module_list = tune_cfg["fx_sub_module_list"] if "fx_sub_module_list" in tune_cfg.keys() else []
if len(bf16_ops_list) > 0:
logger.info("Convert operators to bfloat16")
mixed_precision_model = _bf16_wrapper_model(model, bf16_ops_list)
if fx_sub_module_list is not None and len(fx_sub_module_list) > 0:
mixed_precision_model = bf16_symbolic_trace(mixed_precision_model, fx_sub_module_list)
return mixed_precision_model


def _bf16_wrapper_model(model, bf16_ops_list, prefix=""):
for name, child in model.named_children():
op_name = prefix + "." + name if prefix != "" else name
for bf16_op_name in bf16_ops_list:
if op_name == bf16_op_name[0]:
if op_name == bf16_op_name[0] or op_name == bf16_op_name[0].split(".module")[0]:
child = BF16ModuleWrapper(child)
else:
_bf16_wrapper_model(child, bf16_ops_list, op_name)
setattr(model, name, child)
return model


def bf16_symbolic_trace(model, fx_sub_module_list, prefix=""):
"""Symbolic trace for bf16 models.

Args:
model (object): the input model.
fx_sub_module_list (list): _description_
prefix (str): prefix of op name.

Returns:
model (object)
"""
for name, child in model.named_children():
op_name = prefix + "." + name if prefix != "" else name
for fx_sub_module_name in fx_sub_module_list:
if op_name == fx_sub_module_name:
child = symbolic_trace(child)
else:
bf16_symbolic_trace(child, fx_sub_module_list, op_name)
setattr(model, name, child)
setattr(model, name, child)
_bf16_wrapper_model(child, bf16_ops_list, op_name)
return model
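With the symbolic-trace pass removed, `Convert()` only consumes `bf16_ops_list` from the tuning config, and the new `split(".module")` check lets fx-style names carrying a `.module` suffix match the underlying module. A hedged usage sketch with a made-up one-layer model and config:

```python
import torch.nn as nn

from neural_compressor.adaptor.torch_utils.bf16_convert import Convert

# Hypothetical minimal model and tune_cfg for illustration only.
model = nn.Sequential()
model.add_module("conv", nn.Conv2d(3, 8, 3))
tune_cfg = {"bf16_ops_list": [("conv", "Conv2d")]}

model = Convert(model, tune_cfg)
print(type(model.conv).__name__)       # BF16ModuleWrapper
print(model.conv.module.weight.dtype)  # torch.bfloat16
```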
12 changes: 3 additions & 9 deletions test/adaptor/pytorch_adaptor/test_adaptor_pytorch_2x.py
@@ -392,21 +392,15 @@ def test_fx_sub_module_quant(self):
"Please use PyTroch 1.11 or higher version for mixed precision with pytorch_fx or pytorch backend",
)
def test_mix_precision(self):
os.environ["FORCE_BF16"] = "1"
model_origin = DynamicControlModel()
# run fx_quant in neural_compressor and save the quantized GraphModule
dataset = Datasets("pytorch")["dummy"]((100, 3, 224, 224))
dataloader = DataLoader("pytorch", dataset)
set_workspace("./saved")
# fx mode usually has a .module suffix because tracing the entire model fails, so use conv.* to leverage re.match
ptq_fx_op_name_list["conv.*"] = {"weight": {"dtype": "bf16"}, "activation": {"dtype": "bf16"}}
conf = PostTrainingQuantConfig(op_name_dict=ptq_fx_op_name_list)
q_model = quantization.fit(model_origin, conf, calib_dataloader=dataloader, calib_func=eval_func)
tune_cfg = q_model.q_config
tune_cfg["op"][("conv.module", "Conv2d")].clear()
tune_cfg["op"][("conv.module", "Conv2d")] = {"weight": {"dtype": "bf16"}, "activation": {"dtype": "bf16"}}
tune_cfg["bf16_ops_list"].append(("conv.module", "Conv2d"))
from neural_compressor.adaptor.torch_utils.bf16_convert import Convert

q_model._model = Convert(q_model._model, tune_cfg)

self.assertEqual(q_model._model.conv.module.module.weight.dtype, torch.bfloat16)
self.assertEqual(q_model._model.conv.module.module.bias.dtype, torch.bfloat16)
