
Support directly loading gptq models from huggingface #9391

Merged
7 commits merged into intel-analytics:main on Nov 14, 2023

Conversation

@yangw1234 (Contributor) commented Nov 9, 2023

Description

Support directly loading GPTQ models from Hugging Face.

Many models on Hugging Face are published in GPTQ format. It would be nice to load them directly with bigdl-llm.

Install:

BUILD_CUDA_EXT=0 pip install git+https://github.com/PanQiWei/AutoGPTQ.git@1de9ab6
pip install optimum==0.14.0

Usage:

import torch

from bigdl.llm.transformers import AutoModelForCausalLM
from transformers import GPTQConfig

quantization_config = GPTQConfig(
    bits=4,
    use_exllama=False,
)

# Load the model in 4 bit,
# which converts the relevant layers in the model into INT4 format
model = AutoModelForCausalLM.from_pretrained(model_path,  # repo id or local path of the GPTQ checkpoint
                                             load_in_4bit=True,  # loads into asym_int4; if using load_in_low_bit instead, it must be "asym_int4"
                                             torch_dtype=torch.float,
                                             trust_remote_code=True,
                                             quantization_config=quantization_config)
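
For completeness, a minimal generation sketch on top of the loaded model (the tokenizer comes from the same GPTQ repo; the prompt and token count are arbitrary):

```python
from transformers import AutoTokenizer

# tokenizer shipped with the GPTQ checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

prompt = "What is AI?"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

# greedy generation on the INT4-converted model
output = model.generate(input_ids, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```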

Limitations:

- Only works with 4-bit quantization and desc_act=False (act order disabled).
- The GPU version is really slow; investigation needed.
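
To see whether a given checkpoint fits these constraints before loading, its quantization settings can be inspected via transformers (a sketch; `model_path` is the GPTQ repo id or local path):

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

# the quantization settings are stored in the checkpoint's config.json;
# depending on the transformers version this is a plain dict or a config object
q_config = config.quantization_config
print(q_config)  # expect bits=4 and desc_act=False for this PR to handle it
```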

Perf: https://github.com/analytics-zoo/nano/issues/738

invalidInputError(q_config["bits"] == 4,
                  "Only 4-bit gptq is supported in bigdl-llm.")
invalidInputError(q_config["desc_act"] is False,
                  "Only desc_act=False is supported in bigdl-llm.")
Contributor:

Also check group_size; it should be a multiple of 64.
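
A minimal sketch of the suggested check, written against the `invalidInputError` / `q_config` names from the diff excerpt above (the 64 comes straight from this comment; the merged code derives the block size via `get_ggml_qk_size` instead):

```python
# suggested alongside the existing bits/desc_act validation;
# q_config is the GPTQ quantization config dict read from the checkpoint
invalidInputError(q_config["group_size"] % 64 == 0,
                  "Only group_size that is a multiple of 64 is supported in bigdl-llm.")
```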

Contributor Author:

fixed

@@ -89,6 +89,18 @@ def from_pretrained(cls,
optimize_model = kwargs.pop("optimize_model", True)

if load_in_4bit or load_in_low_bit:

if config_dict.get("quantization_config", None) is not None:
Contributor:

do we need to add it to Python Doc?

Contributor Author:

Added, and passing quantization_config is no longer required.
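
In other words, after this change the usage from the description can presumably shrink to something like the sketch below (no explicit `GPTQConfig`; a default `GPTQConfig(bits=4, use_exllama=False)` is created internally, as a later diff excerpt shows):

```python
import torch

from bigdl.llm.transformers import AutoModelForCausalLM

# model_path is the repo id or local path of a GPTQ checkpoint;
# the GPTQ settings are detected from its config, so no quantization_config is passed
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,
                                             torch_dtype=torch.float,
                                             trust_remote_code=True)
```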

@@ -0,0 +1,73 @@
# Llama2
Contributor:

  1. move to bigdl/python/llm/example/CPU/HF-Transformers-AutoModels/Advanced-Quantizations/GPTQ
  2. # GPTQ
  3. This example shows how to directly run 4-bit GPTQ models using BigDL-LLM on Intel CPU

Contributor Author:

fixed

mp_group=mp_group,
)

device_type = module.qweight.data.device.type
Contributor:

is it used?

Contributor Author:

fixed

invalidInputError(False,
                  (f"group_size must be divisible by "
                   f"{get_ggml_qk_size(load_in_low_bit)}."))
if user_quantization_config is not None:
Contributor:

do we want the user to pass user_quantization_config?

Contributor Author:

I think not letting the user pass user_quantization_config might be a better choice.

else:
    from transformers import GPTQConfig
    user_quantization_config = GPTQConfig(bits=4, use_exllama=False)
kwargs["quantization_config"] = user_quantization_config
Contributor:

Does save/load low bit work? Do we need to remove quantization_config in save_low_bit?

Contributor Author:

Save/load low bit works. It seems our load_low_bit will ignore quantization_config, but I removed it anyway.
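
For context, a rough sketch of the save/load low-bit round trip being discussed, assuming bigdl-llm's existing `save_low_bit` / `load_low_bit` API (the save directory is a placeholder):

```python
from bigdl.llm.transformers import AutoModelForCausalLM

# persist the already-converted INT4 weights; per the discussion above,
# the GPTQ quantization_config is stripped from the saved config
model.save_low_bit("./llama2-gptq-int4")

# reload later without needing the original GPTQ checkpoint
model = AutoModelForCausalLM.load_low_bit("./llama2-gptq-int4")
```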

@jason-dai (Contributor) left a comment:

LGTM

E.g. on Linux,
```bash
# set BigDL-Nano env variables
source bigdl-nano-init
```
Contributor:

why nano variables? we have bigdl-llm-init now.

Contributor Author:

copied from existing examples. will change that.

Contributor Author:

It seems most of our examples still use bigdl-nano-init. How about we leave it here and change them together in the future?

Contributor:

OK - please open an issue


@cyita mentioned this pull request Nov 14, 2023
@yangw1234 (Contributor Author):

The failed test is irrelevant, and the unit tests on Arc lack the resources to run. I'll merge this PR first to unblock further development.

@yangw1234 merged commit 282b0df into intel-analytics:main on Nov 14, 2023
34 of 36 checks passed
pip install bigdl-llm[all] # install bigdl-llm with 'all' option
pip install transformers==4.34.0
BUILD_CUDA_EXT=0 pip install git+https://github.com/PanQiWei/AutoGPTQ.git@1de9ab6
pip install optimum==0.14.0
Contributor:

0.14.0 → 1.14.0

liu-shaojun pushed a commit that referenced this pull request Mar 25, 2024
* Support directly loading GPTQ models from huggingface

* fix style

* fix tests

* change example structure

* address comments

* fix style

* address comments