Commit 501bb0c

thesues authored and mgoin committed
[bitsandbytes]: support read bnb pre-quantized model (vllm-project#5753)
Co-authored-by: Michael Goin <michael@neuralmagic.com>
Signed-off-by: Alvant <alvasian@yandex.ru>
1 parent a3a5dcb commit 501bb0c

File tree

8 files changed: +143 -39 lines changed


docs/source/index.rst

Lines changed: 1 addition & 0 deletions
@@ -105,6 +105,7 @@ Documentation
 
     quantization/supported_hardware
     quantization/auto_awq
+    quantization/bnb
     quantization/fp8
     quantization/fp8_e5m2_kvcache
     quantization/fp8_e4m3_kvcache

docs/source/quantization/bnb.rst

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+.. _bits_and_bytes:
+
+BitsAndBytes
+==================
+
+vLLM now supports `BitsAndBytes <https://github.com/TimDettmers/bitsandbytes>`_ for more efficient model inference.
+BitsAndBytes quantizes models to reduce memory usage and enhance performance without significantly sacrificing accuracy.
+Compared to other quantization methods, BitsAndBytes eliminates the need for calibrating the quantized model with input data.
+
+Below are the steps to use BitsAndBytes with vLLM.
+
+.. code-block:: console
+
+   $ pip install bitsandbytes>=0.42.0
+
+vLLM reads the model's config file and supports both in-flight quantization and pre-quantized checkpoints.
+
+You can find bitsandbytes quantized models on https://huggingface.co/models?other=bitsandbytes.
+Usually, these repositories have a config.json file that includes a quantization_config section.
+
+Read quantized checkpoint
+-------------------------
+
+.. code-block:: python
+
+   from vllm import LLM
+   import torch
+   # unsloth/tinyllama-bnb-4bit is a pre-quantized checkpoint.
+   model_id = "unsloth/tinyllama-bnb-4bit"
+   llm = LLM(model=model_id, dtype=torch.bfloat16, trust_remote_code=True,
+             quantization="bitsandbytes", load_format="bitsandbytes")
+
+In-flight quantization: load as 4-bit quantization
+----------------------------------------------------
+
+.. code-block:: python
+
+   from vllm import LLM
+   import torch
+   model_id = "huggyllama/llama-7b"
+   llm = LLM(model=model_id, dtype=torch.bfloat16, trust_remote_code=True,
+             quantization="bitsandbytes", load_format="bitsandbytes")
+
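As a usage sketch to round out the new doc page (illustrative only, not part of this commit: the prompt and sampling settings are placeholders), generation with either setup then goes through the standard vLLM API:

    from vllm import LLM, SamplingParams
    import torch

    # Same arguments as the "Read quantized checkpoint" example above.
    llm = LLM(model="unsloth/tinyllama-bnb-4bit",
              dtype=torch.bfloat16,
              quantization="bitsandbytes",
              load_format="bitsandbytes")

    sampling_params = SamplingParams(temperature=0.8, max_tokens=64)
    outputs = llm.generate(["To be or not to be,"], sampling_params)
    print(outputs[0].outputs[0].text)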

tests/quantization/test_bitsandbytes.py

Lines changed: 14 additions & 4 deletions
@@ -8,15 +8,20 @@
 from tests.quantization.utils import is_quant_method_supported
 from vllm import SamplingParams
 
+models_to_test = [
+    ('huggyllama/llama-7b', 'quantize model inflight'),
+    ('lllyasviel/omost-llama-3-8b-4bits', 'read pre-quantized model'),
+]
+
 
 @pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
                     reason='bitsandbytes is not supported on this GPU type.')
-def test_load_bnb_model(vllm_runner) -> None:
-    with vllm_runner('huggyllama/llama-7b',
+@pytest.mark.parametrize("model_name, description", models_to_test)
+def test_load_bnb_model(vllm_runner, model_name, description) -> None:
+    with vllm_runner(model_name,
                      quantization='bitsandbytes',
                      load_format='bitsandbytes',
                      enforce_eager=True) as llm:
-
         model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
 
         # check the weights in MLP & SelfAttention are quantized to torch.uint8
@@ -65,12 +70,17 @@ def test_load_bnb_model(vllm_runner) -> None:
             'To be or not to be, that is the question.'
         ]
         outputs = llm.generate(prompts, sampling_params=sampling_params)
-
         assert len(outputs) == len(prompts)
 
         for index in range(len(outputs)):
             # compare the first line of the output
             actual_output = outputs[index][1][0].split('\n', 1)[0]
             expected_output = expected_outputs[index].split('\n', 1)[0]
+
+            assert len(actual_output) >= len(expected_output), (
+                f'Actual {actual_output} should be larger than or equal to '
+                f'expected {expected_output}')
+            actual_output = actual_output[:len(expected_output)]
+
             assert actual_output == expected_output, (
                 f'Expected: {expected_output}, but got: {actual_output}')

vllm/config.py

Lines changed: 2 additions & 0 deletions
@@ -591,9 +591,11 @@ class LoadConfig:
             mainly for profiling.
             "tensorizer" will use CoreWeave's tensorizer library for
             fast weight loading.
+            "bitsandbytes" will load nf4 type weights.
         ignore_patterns: The list of patterns to ignore when loading the model.
             Default to "original/**/*" to avoid repeated loading of llama's
             checkpoints.
+
     """
 
     load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO

vllm/engine/arg_utils.py

Lines changed: 2 additions & 2 deletions
@@ -676,8 +676,8 @@ def create_engine_config(self, ) -> EngineConfig:
         # bitsandbytes quantization needs a specific model loader
         # so we make sure the quant method and the load format are consistent
         if (self.quantization == "bitsandbytes" or
-            self.qlora_adapter_name_or_path is not None) and \
-            self.load_format != "bitsandbytes":
+                self.qlora_adapter_name_or_path is not None) and \
+                self.load_format != "bitsandbytes":
             raise ValueError(
                 "BitsAndBytes quantization and QLoRA adapter only support "
                 f"'bitsandbytes' load format, but got {self.load_format}")

vllm/model_executor/layers/quantization/bitsandbytes.py

Lines changed: 4 additions & 21 deletions
@@ -15,19 +15,11 @@ class BitsAndBytesConfig(QuantizationConfig):
     Reference: https://arxiv.org/abs/2305.14314
     """
 
-    def __init__(
-        self,
-        adapter_name_or_path: str,
-        target_modules: List[str],
-    ) -> None:
-
-        self.adapter_name_or_path = adapter_name_or_path
-        self.target_modules = target_modules
+    def __init__(self, ) -> None:
+        pass
 
     def __repr__(self) -> str:
-        return (
-            f"BitsAndBytesConfig(adapter_name_or_path={self.adapter_name_or_path}"
-        )
+        return "BitsAndBytesConfig"
 
     @classmethod
     def get_name(self) -> str:
@@ -49,16 +41,7 @@ def get_config_filenames() -> List[str]:
 
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "BitsAndBytesConfig":
-        adapter_name = cls.get_from_keys(config, ["adapter_name_or_path"])
-        default_target_modules = [
-            "gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj",
-            "o_proj"
-        ]
-        if adapter_name == "":
-            target_modules = default_target_modules
-        else:
-            target_modules = cls.get_from_keys(config, ["target_modules"])
-        return cls(adapter_name, target_modules)
+        return cls()
 
     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["BitsAndBytesLinearMethod"]:
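With the adapter-specific fields removed, from_config now accepts any checkpoint whose quantization_config declares quant_method == "bitsandbytes". For orientation, such a section typically looks roughly like the dict below when written by Transformers' BitsAndBytesConfig (an illustrative assumption, not taken from this commit; the exact key set varies by transformers version):

    # Illustrative quantization_config contents for a 4-bit (nf4) checkpoint.
    quantization_config = {
        "quant_method": "bitsandbytes",
        "load_in_4bit": True,
        "load_in_8bit": False,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": "bfloat16",
        "bnb_4bit_use_double_quant": True,
    }

    # The loader below only needs the method name to pick the pre-quantized path.
    is_quantized_checkpoint = (
        quantization_config.get("quant_method") == "bitsandbytes")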

vllm/model_executor/model_loader/loader.py

Lines changed: 76 additions & 12 deletions
@@ -702,8 +702,14 @@ def _prepare_weights(self, model_name_or_path: str,
 
         return hf_weights_files, matched_pattern == "*.safetensors"
 
+    def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
+        if use_safetensors:
+            return safetensors_weights_iterator(hf_weights_files)
+        else:
+            return pt_weights_iterator(hf_weights_files)
+
     def _get_quantized_weights_iterator(
-        self, model_name_or_path: str, revision: Optional[str]
+        self, model_name_or_path: str, revision: Optional[str], pre_quant: bool
     ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str,
                                                                      Any]]:
         """Get an iterator to the model weights with bitsandbytes quantization,
@@ -712,6 +718,7 @@ def _get_quantized_weights_iterator(
         # only load the bitsandbytes module when needed
        try:
             import bitsandbytes
+            from bitsandbytes.functional import QuantState
             if bitsandbytes.__version__ < "0.42.0":
                 raise ImportError("bitsandbytes version is wrong. Please "
                                   "install bitsandbytes>=0.42.0.")
@@ -725,17 +732,63 @@ def _get_quantized_weights_iterator(
             model_name_or_path, revision)
 
         quant_state_dict = {}
-        if use_safetensors:
-            weight_iterator = safetensors_weights_iterator(hf_weights_files)
-        else:
-            weight_iterator = pt_weights_iterator(hf_weights_files)
 
-        def generator():
+        def quantized_checkpoint() -> Generator:
+            # First iterate over all quant state weights
+            weight_iterator = self._hf_weight_iter(hf_weights_files,
+                                                   use_safetensors)
+            temp_state_dict = {}
             for weight_name, weight_tensor in weight_iterator:
+                if weight_name.endswith(".weight"):
+                    continue
+                # TODO: only nf4 quantization is supported for now
+                if weight_name.endswith(".quant_state.bitsandbytes__fp4"):
+                    raise NotImplementedError(
+                        "Only bitsandbytes_nf4 quantization"
+                        f"is supported for now. {weight_name} is fp4 quantized"
+                    )
+                temp_state_dict[weight_name] = weight_tensor
+
+            # Closure to parse quant_state for each prequant weight
+            def _parse_quant_state(param_name: str,
+                                   temp_state_dict: Dict) -> QuantState:
+                quant_state = {}
+                for k in temp_state_dict:
+                    if param_name + "." in k:
+                        quant_state[k] = temp_state_dict[k]
+                # bitsandbytes library requires
+                # weight.quant_state.bitsandbytes__nf4 in CPU
+                quant_state[param_name +
+                            ".quant_state.bitsandbytes__nf4"] = quant_state[
+                                param_name +
+                                ".quant_state.bitsandbytes__nf4"].cpu().data
+                return QuantState.from_dict(quant_state, device="cuda")
+
+            # Second iterate over all prequant and normal weights
+            # pre quantized weights would have a quant_state
+            for weight_name, weight_tensor in self._hf_weight_iter(
+                    hf_weights_files, use_safetensors):
+                # Filter out all weights whose suffix is not ".weight"
+                if not weight_name.endswith(".weight"):
+                    continue
+                if weight_name + ".quant_state.bitsandbytes__nf4" \
+                        in temp_state_dict:
+                    quant_state = _parse_quant_state(weight_name,
+                                                     temp_state_dict)
+                    weight_name = weight_name.replace(".weight", ".qweight")
+                    quant_state_dict[weight_name] = quant_state
+                    yield weight_name.replace(".weight",
+                                              ".qweight"), weight_tensor
+                else:
+                    yield weight_name, weight_tensor
+
+        def generator() -> Generator:
+            for weight_name, weight_tensor in self._hf_weight_iter(
+                    hf_weights_files, use_safetensors):
                 if any(target_module in weight_name
                        for target_module in self.target_modules):
                     weight_name = weight_name.replace(".weight", ".qweight")
-                    # bitsandbytes requires data in GPU
+                    # bitsandbytes requires data in GPU
                     loaded_weight = weight_tensor.cuda().data
                     with set_default_torch_dtype(torch.float32):
                         processed_weight, quant_state = quantize_4bit(
@@ -749,6 +802,8 @@ def generator():
 
                     yield weight_name, processed_weight
 
+        if pre_quant:
+            return quantized_checkpoint(), quant_state_dict
         return generator(), quant_state_dict
 
     def _load_weights(self, model_config: ModelConfig,
@@ -766,12 +821,21 @@ def _load_weights(self, model_config: ModelConfig,
         logger.info("Loading weights with BitsAndBytes quantization. "
                     " May take a while ...")
 
-        qweight_iterator, quant_state_dict = (
-            self._get_quantized_weights_iterator(model_config.model,
-                                                 model_config.revision))
+        is_quantized_checkpoint = False
+        quant_config = getattr(model_config.hf_config, "quantization_config",
+                               None)
+        if quant_config is not None and quant_config.get(
+                'quant_method') == "bitsandbytes":
+            is_quantized_checkpoint = True
+
+        qweight_iterator, quant_state_dict = \
+            self._get_quantized_weights_iterator(
+            model_config.model, model_config.revision, is_quantized_checkpoint)
 
         model.load_weights(qweight_iterator)
 
+        torch.cuda.empty_cache()
+
         param_dict = dict(model.named_parameters())
         stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
         for quant_param_name in quant_state_dict:
@@ -809,9 +873,9 @@ def _load_weights(self, model_config: ModelConfig,
                     f"pack_factor not set for parameter {param_name}.")
 
             num_elements = [0] * len(quant_states)
-            for seq, quant_state in enumerate(quant_states.items()):
+            for seq, quant_state in quant_states.items():
                 num_elements[seq] = math.prod(
-                    quant_state[1].shape) // pack_ratio
+                    quant_state.shape) // pack_ratio
 
             offsets = np.concatenate(([0], np.cumsum(num_elements)))
             set_weight_attrs(param, {"bnb_shard_offsets": offsets})
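For readers unfamiliar with the bitsandbytes primitives used by the in-flight path, here is a standalone sketch (independent of vLLM; the tensor shape is arbitrary) of the quantize_4bit call that generator() builds on, with a round-trip dequantization added purely for comparison:

    import torch
    from bitsandbytes.functional import dequantize_4bit, quantize_4bit

    # bitsandbytes operates on GPU tensors, matching the .cuda() call above.
    weight = torch.randn(4096, 4096, dtype=torch.float32, device="cuda")

    # NF4 quantization, the only quant type the new loader path accepts.
    qweight, quant_state = quantize_4bit(weight, quant_type="nf4")
    print(qweight.dtype, qweight.shape)  # packed torch.uint8, two values per byte

    # Round-trip for illustration; the loader itself never dequantizes eagerly.
    restored = dequantize_4bit(qweight, quant_state, quant_type="nf4")
    print((restored - weight).abs().mean())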

vllm/model_executor/model_loader/weight_utils.py

Lines changed: 1 addition & 0 deletions
@@ -118,6 +118,7 @@ def convert_bin_to_safetensor_file(
 # TODO(woosuk): Move this to other place.
 def get_quant_config(model_config: ModelConfig,
                      load_config: LoadConfig) -> QuantizationConfig:
+
     quant_cls = get_quantization_config(model_config.quantization)
     # Read the quantization config from the HF model config, if available.
     hf_quant_config = getattr(model_config.hf_config, "quantization_config",
