
Fix transformers rtn layer-wise quant #2008


Merged 33 commits on Sep 30, 2024
Commits
619dec5
load unquant module
Kaihui-intel Sep 14, 2024
52762e3
skip empty module
Kaihui-intel Sep 14, 2024
d255278
load ln_
Kaihui-intel Sep 14, 2024
a515fd3
load module
Kaihui-intel Sep 19, 2024
cce5bf9
remove rtn lw hook
Kaihui-intel Sep 23, 2024
20576c6
remove breakpoint
Kaihui-intel Sep 23, 2024
91d561d
resolve confilict
Kaihui-intel Sep 23, 2024
db2f59f
add ut for use/no-use layer wise
Kaihui-intel Sep 23, 2024
e28bc9a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 23, 2024
6448f7e
Merge branch 'kaihui/transformers_lw' of https://github.com/intel/neu…
Kaihui-intel Sep 23, 2024
429ef15
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 23, 2024
e3320b2
add lw check before convert
Kaihui-intel Sep 23, 2024
162b7e6
fix lw check
Kaihui-intel Sep 23, 2024
9c3257d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 23, 2024
9b63b94
fix lw check
Kaihui-intel Sep 23, 2024
275c6d5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 23, 2024
dc2c4b2
fix load oom
Kaihui-intel Sep 25, 2024
0e823b2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 25, 2024
dea8512
update xpu model_type list
Kaihui-intel Sep 27, 2024
44c312d
fix llama3 oom
Kaihui-intel Sep 29, 2024
448fc0b
fix trust_remote_code
Kaihui-intel Sep 29, 2024
eb9dce3
resolve trust remote code
Kaihui-intel Sep 29, 2024
9292ab0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 29, 2024
fd9c387
fix empty model trust_remote_code
Kaihui-intel Sep 29, 2024
35e58e6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 29, 2024
79361af
fix trust_remote_code
Kaihui-intel Sep 29, 2024
8f4cee3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 29, 2024
9b82256
resolve trust remote code
Kaihui-intel Sep 29, 2024
83332a6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 29, 2024
4402d61
update params
Kaihui-intel Sep 29, 2024
ad22bb4
fix load empty model
Kaihui-intel Sep 29, 2024
43f095f
update readme
Kaihui-intel Sep 30, 2024
adaadce
Merge branch 'master' into kaihui/transformers_lw
XuehaoSun Sep 30, 2024
@@ -116,13 +116,18 @@ Pytorch and Intel-extension-for-pytorch version for intel GPU > 2.1 are required
```bash
pip install -r requirements_GPU.txt
pip install transformers==4.38.1 # llama use 4.38.1
source /opt/intel/oneapi/setvars.sh
git clone https://github.com/intel/intel-extension-for-pytorch.git ipex-gpu
cd ipex-gpu
git submodule update --init --recursive
export USE_AOT_DEVLIST='pvc,ats-m150'
export BUILD_WITH_CPU=OFF

export LD_LIBRARY_PATH=${CONDA_PREFIX}/lib/:$LD_LIBRARY_PATH
export OCL_ICD_VENDORS=/etc/OpenCL/vendors
export CCL_ROOT=${CONDA_PREFIX}
source /opt/intel/oneapi/setvars.sh --force
export LLM_ACC_TEST=1

python setup.py install
```
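A quick way to confirm the resulting XPU environment is usable (a minimal check, assuming the build above finished and the oneAPI variables are sourced in the current shell):

```bash
# Should print the torch/IPEX versions and True if the XPU device is visible.
python -c "import torch, intel_extension_for_pytorch as ipex; print(torch.__version__, ipex.__version__, torch.xpu.is_available())"
```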

@@ -200,7 +200,7 @@
tokenizer.save_pretrained(args.output_dir)

enable_optimize_transformers = False
opt_gpu_model_type_list = ["llama", "gptj", "mistral", "qwen"]
opt_gpu_model_type_list = ["llama", "gptj", "mistral", "qwen", "phi3"]

if config.model_type in opt_gpu_model_type_list:
enable_optimize_transformers = True
10 changes: 3 additions & 7 deletions neural_compressor/torch/algorithms/weight_only/rtn.py
@@ -130,17 +130,16 @@ def convert(

if use_layer_wise:
from neural_compressor.common.utils import DEFAULT_WORKSPACE
from neural_compressor.torch.algorithms.layer_wise.utils import get_path, load_module, register_weight_hooks
from neural_compressor.torch.algorithms.layer_wise.utils import get_path, load_module

if model_path == "":
model_path = model.path
assert model_path, "model_path should not be None."
model_path = get_path(model_path)

register_weight_hooks(model, model_path, device=device, clean_weight=True)

for name, m in model.named_modules():

if use_layer_wise and len(list(m.named_children())) == 0:
load_module(model, name, model_path, device=device)
if not isinstance(m, supported_layers):
continue
if name in weight_config: # pragma: no cover
@@ -192,9 +191,6 @@ def convert(
logger.debug(f"RTN quantized module:{name, m}")
logger.debug(log_msg)

if use_layer_wise:
load_module(model, name, model_path, device=device)

# for only group_dim is 0 or only `transformers.Conv1D`, we need transpose weight.
if is_transformers_imported():
transpose = (group_dim == 0) ^ (isinstance(m, transformers.Conv1D))
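Net effect of the two rtn.py hunks above: the up-front `register_weight_hooks` call is gone, and each leaf module's weights are instead pulled from the checkpoint right before the module is visited. A condensed, illustrative sketch of the resulting loop (names follow the diff; the quantization body itself is elided):

```python
from neural_compressor.torch.algorithms.layer_wise.utils import load_module

def rtn_layer_wise_sketch(model, model_path, supported_layers, device="cpu"):
    """Illustrative outline of the revised convert() loop, not the real implementation."""
    for name, m in model.named_modules():
        if len(list(m.named_children())) == 0:
            # Materialize this leaf module's weights from the on-disk checkpoint.
            load_module(model, name, model_path, device=device)
        if not isinstance(m, supported_layers):
            continue
        # ... RTN weight quantization of m proceeds here, as in the hunk above ...
```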
4 changes: 2 additions & 2 deletions neural_compressor/torch/utils/utility.py
@@ -331,11 +331,11 @@ def load_empty_model(pretrained_model_name_or_path, cls=None, **kwargs):
if cls.__base__ == _BaseAutoModelClass:
config = AutoConfig.from_pretrained(path, **kwargs)
with init_empty_weights():
model = cls.from_config(config)
model = cls.from_config(config, **kwargs)
else: # pragma: no cover
config = cls.config_class.from_pretrained(path, **kwargs)
with init_empty_weights():
model = cls(config)
model = cls(config, **kwargs)
model.tie_weights()
model.eval()
model.path = pretrained_model_name_or_path
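With `**kwargs` now forwarded to `from_config`, options such as `trust_remote_code` reach the empty-model construction. A hedged usage sketch (the model id is illustrative):

```python
from neural_compressor.torch import load_empty_model  # import path as used in modeling_auto.py below

# Build a meta-device shell for a model that ships custom code; no weights are loaded yet.
empty_model = load_empty_model("microsoft/phi-2", trust_remote_code=True)
print(next(empty_model.parameters()).device)  # expected to be the meta device
```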
28 changes: 27 additions & 1 deletion neural_compressor/transformers/models/modeling_auto.py
@@ -134,7 +134,33 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
(RtnConfig, AwqConfig, TeqConfig, GPTQConfig, AutoRoundConfig),
):
logger.info("Applying Weight Only Quantization.")
if use_xpu:
# set use_layer_wise on client
if hasattr(quantization_config, "use_layer_wise"):
import neural_compressor.torch.utils as torch_utils

process_type = torch_utils.get_processor_type_from_user_config()
if process_type == torch_utils.ProcessorType.Client:
quantization_config.use_layer_wise = True

if hasattr(quantization_config, "use_layer_wise") and quantization_config.use_layer_wise:
from transformers.dynamic_module_utils import resolve_trust_remote_code

from neural_compressor.torch import load_empty_model

trust_remote_code = kwargs.get("trust_remote_code", None)
has_remote_code = hasattr(config, "auto_map") and cls.ORIG_MODEL.__name__ in config.auto_map
has_local_code = type(config) in cls.ORIG_MODEL._model_mapping.keys()
trust_remote_code = resolve_trust_remote_code(
trust_remote_code,
pretrained_model_name_or_path,
has_local_code,
has_remote_code,
)

model = load_empty_model(pretrained_model_name_or_path, trust_remote_code=trust_remote_code)
if use_cpu:
quantization_config.post_init_cpu()
elif use_xpu:
# TODO: if low_cpu_mem_uasge is True, gptj will have accuracy issue on CPU device.
kwargs["low_cpu_mem_usage"] = True
kwargs["device_map"] = "cpu"
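For reference, the user-facing call that reaches this branch looks roughly as follows (it mirrors the test added below; the import path is assumed from the 3.x transformers-like API and the model id is illustrative):

```python
from neural_compressor.transformers import AutoModelForCausalLM, RtnConfig

# On client machines, use_layer_wise is now switched on automatically; setting it
# explicitly has the same effect and keeps peak memory low during quantization.
woq_config = RtnConfig(bits=4, group_size=16, use_layer_wise=True)
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",  # illustrative model id
    quantization_config=woq_config,
)
```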
12 changes: 6 additions & 6 deletions neural_compressor/transformers/quantization/utils.py
@@ -153,7 +153,6 @@ def _replace_linear(
"fp16": ipex.quantization.WoqLowpMode.FP16,
"int8": ipex.quantization.WoqLowpMode.INT8,
}

ipex_qconfig_mapping = ipex.quantization.get_weight_only_quant_qconfig_mapping(
weight_dtype=weight_dtype[quantization_config.bits],
lowp_mode=compute_dtype[quantization_config.compute_dtype],
@@ -366,11 +365,6 @@ def convert_to_quantized_model(model, config, device="cpu"):

# mapping to INC config
dtype = "int4" if config.weight_dtype == "int4_fullrange" else config.weight_dtype
import neural_compressor.torch.utils as torch_utils

process_type = torch_utils.get_processor_type_from_user_config()
if process_type == torch_utils.ProcessorType.Client:
config.use_layer_wise = True
if config.quant_method.value == "rtn":
quant_config = RTNConfig(
dtype=dtype,
@@ -529,6 +523,12 @@ def convert_to_quantized_model(model, config, device="cpu"):
if orig_dtype != torch.float32:
q_model.to(dtype=orig_dtype)

if config.use_layer_wise and not (q_model.device == device or q_model.device.type == device):
logger.warning(
"Do not convert device to avoid out of memory. Recommend using saved quantized model to inference."
)
return q_model

return q_model.to(device)


33 changes: 33 additions & 0 deletions test/3x/torch/quantization/weight_only/test_transfomers.py
@@ -115,6 +115,39 @@ def test_save_load(self):
loaded_output = loaded_model(dummy_input)[0]
assert torch.equal(woq_output, loaded_output), "loaded output should be same. Please double check."

def test_use_layer_wise(self):
model_name_or_path = self.model_name_or_path

fp32_model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
dummy_input = fp32_model.dummy_inputs["input_ids"]

# RTN
# use_layer_wise=True
woq_config = RtnConfig(bits=4, group_size=16, use_layer_wise=True)
woq_model = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
quantization_config=woq_config,
)
woq_output = woq_model(dummy_input)[0]

# save
output_dir = "./transformers_tmp"
woq_model.save_pretrained(output_dir)

# load
loaded_model = AutoModelForCausalLM.from_pretrained(output_dir)
loaded_output = loaded_model(dummy_input)[0]
assert torch.equal(woq_output, loaded_output), "loaded output should be same. Please double check."

# use_layer_wise=False
woq_config = RtnConfig(bits=4, group_size=16, use_layer_wise=False)
woq_model = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
quantization_config=woq_config,
)
woq_output2 = woq_model(dummy_input)[0]
assert torch.equal(woq_output, woq_output2), "use_layer_wise output should be same. Please double check."

def test_loading_autoawq_model(self):
user_model = AutoModelForCausalLM.from_pretrained(self.autoawq_model)
tokenizer = AutoTokenizer.from_pretrained(self.autoawq_model)
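To exercise the new case locally, something like the following should work (assumes a development checkout with pytest and the test requirements installed):

```bash
pytest test/3x/torch/quantization/weight_only/test_transfomers.py -k use_layer_wise -v
```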