Skip to content

Commit

Permalink
[bitsandbytes]: support read bnb pre-quantized model (vllm-project#5753)
Browse files Browse the repository at this point in the history
Co-authored-by: Michael Goin <michael@neuralmagic.com>
  • Loading branch information
thesues and mgoin authored Jul 23, 2024
1 parent 2f808e6 commit 87525fa
Show file tree
Hide file tree
Showing 8 changed files with 143 additions and 39 deletions.
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ Documentation

quantization/supported_hardware
quantization/auto_awq
quantization/bnb
quantization/fp8
quantization/fp8_e5m2_kvcache
quantization/fp8_e4m3_kvcache
Expand Down
43 changes: 43 additions & 0 deletions docs/source/quantization/bnb.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
.. _bits_and_bytes:

BitsAndBytes
==================

vLLM now supports `BitsAndBytes <https://github.com/TimDettmers/bitsandbytes>`_ for more efficient model inference.
BitsAndBytes quantizes models to reduce memory usage and enhance performance without significantly sacrificing accuracy.
Compared to other quantization methods, BitsAndBytes eliminates the need for calibrating the quantized model with input data.

Below are the steps to utilize BitsAndBytes with vLLM.

.. code-block:: console
$ pip install bitsandbytes>=0.42.0
vLLM reads the model's config file and supports both in-flight quantization and pre-quantized checkpoint.

You can find bitsandbytes quantized models on https://huggingface.co/models?other=bitsandbytes.
And usually, these repositories have a config.json file that includes a quantization_config section.

Read quantized checkpoint.
--------------------------

.. code-block:: python
from vllm import LLM
import torch
# unsloth/tinyllama-bnb-4bit is a pre-quantized checkpoint.
model_id = "unsloth/tinyllama-bnb-4bit"
llm = LLM(model=model_id, dtype=torch.bfloat16, trust_remote_code=True, \
quantization="bitsandbytes", load_format="bitsandbytes")
Inflight quantization: load as 4bit quantization
------------------------------------------------

.. code-block:: python
from vllm import LLM
import torch
model_id = "huggyllama/llama-7b"
llm = LLM(model=model_id, dtype=torch.bfloat16, trust_remote_code=True, \
quantization="bitsandbytes", load_format="bitsandbytes")
18 changes: 14 additions & 4 deletions tests/quantization/test_bitsandbytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,20 @@
from tests.quantization.utils import is_quant_method_supported
from vllm import SamplingParams

models_to_test = [
('huggyllama/llama-7b', 'quantize model inflight'),
('lllyasviel/omost-llama-3-8b-4bits', 'read pre-quantized model'),
]


@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
reason='bitsandbytes is not supported on this GPU type.')
def test_load_bnb_model(vllm_runner) -> None:
with vllm_runner('huggyllama/llama-7b',
@pytest.mark.parametrize("model_name, description", models_to_test)
def test_load_bnb_model(vllm_runner, model_name, description) -> None:
with vllm_runner(model_name,
quantization='bitsandbytes',
load_format='bitsandbytes',
enforce_eager=True) as llm:

model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501

# check the weights in MLP & SelfAttention are quantized to torch.uint8
Expand Down Expand Up @@ -65,12 +70,17 @@ def test_load_bnb_model(vllm_runner) -> None:
'To be or not to be, that is the question.'
]
outputs = llm.generate(prompts, sampling_params=sampling_params)

assert len(outputs) == len(prompts)

for index in range(len(outputs)):
# compare the first line of the output
actual_output = outputs[index][1][0].split('\n', 1)[0]
expected_output = expected_outputs[index].split('\n', 1)[0]

assert len(actual_output) >= len(expected_output), (
f'Actual {actual_output} should be larger than or equal to '
f'expected {expected_output}')
actual_output = actual_output[:len(expected_output)]

assert actual_output == expected_output, (
f'Expected: {expected_output}, but got: {actual_output}')
2 changes: 2 additions & 0 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,9 +591,11 @@ class LoadConfig:
mainly for profiling.
"tensorizer" will use CoreWeave's tensorizer library for
fast weight loading.
"bitsandbytes" will load nf4 type weights.
ignore_patterns: The list of patterns to ignore when loading the model.
Default to "original/**/*" to avoid repeated loading of llama's
checkpoints.
"""

load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
Expand Down
4 changes: 2 additions & 2 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,8 +676,8 @@ def create_engine_config(self, ) -> EngineConfig:
# bitsandbytes quantization needs a specific model loader
# so we make sure the quant method and the load format are consistent
if (self.quantization == "bitsandbytes" or
self.qlora_adapter_name_or_path is not None) and \
self.load_format != "bitsandbytes":
self.qlora_adapter_name_or_path is not None) and \
self.load_format != "bitsandbytes":
raise ValueError(
"BitsAndBytes quantization and QLoRA adapter only support "
f"'bitsandbytes' load format, but got {self.load_format}")
Expand Down
25 changes: 4 additions & 21 deletions vllm/model_executor/layers/quantization/bitsandbytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,11 @@ class BitsAndBytesConfig(QuantizationConfig):
Reference: https://arxiv.org/abs/2305.14314
"""

def __init__(
self,
adapter_name_or_path: str,
target_modules: List[str],
) -> None:

self.adapter_name_or_path = adapter_name_or_path
self.target_modules = target_modules
def __init__(self, ) -> None:
pass

def __repr__(self) -> str:
return (
f"BitsAndBytesConfig(adapter_name_or_path={self.adapter_name_or_path}"
)
return "BitsAndBytesConfig"

@classmethod
def get_name(self) -> str:
Expand All @@ -49,16 +41,7 @@ def get_config_filenames() -> List[str]:

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "BitsAndBytesConfig":
adapter_name = cls.get_from_keys(config, ["adapter_name_or_path"])
default_target_modules = [
"gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj",
"o_proj"
]
if adapter_name == "":
target_modules = default_target_modules
else:
target_modules = cls.get_from_keys(config, ["target_modules"])
return cls(adapter_name, target_modules)
return cls()

def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["BitsAndBytesLinearMethod"]:
Expand Down
88 changes: 76 additions & 12 deletions vllm/model_executor/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,8 +702,14 @@ def _prepare_weights(self, model_name_or_path: str,

return hf_weights_files, matched_pattern == "*.safetensors"

def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
if use_safetensors:
return safetensors_weights_iterator(hf_weights_files)
else:
return pt_weights_iterator(hf_weights_files)

def _get_quantized_weights_iterator(
self, model_name_or_path: str, revision: Optional[str]
self, model_name_or_path: str, revision: Optional[str], pre_quant: bool
) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str,
Any]]:
"""Get an iterator to the model weights with bitsandbytes quantization,
Expand All @@ -712,6 +718,7 @@ def _get_quantized_weights_iterator(
# only load the bitsandbytes module when needed
try:
import bitsandbytes
from bitsandbytes.functional import QuantState
if bitsandbytes.__version__ < "0.42.0":
raise ImportError("bitsandbytes version is wrong. Please "
"install bitsandbytes>=0.42.0.")
Expand All @@ -725,17 +732,63 @@ def _get_quantized_weights_iterator(
model_name_or_path, revision)

quant_state_dict = {}
if use_safetensors:
weight_iterator = safetensors_weights_iterator(hf_weights_files)
else:
weight_iterator = pt_weights_iterator(hf_weights_files)

def generator():
def quantized_checkpoint() -> Generator:
# First iterate over all quant state weights
weight_iterator = self._hf_weight_iter(hf_weights_files,
use_safetensors)
temp_state_dict = {}
for weight_name, weight_tensor in weight_iterator:
if weight_name.endswith(".weight"):
continue
# TODO: only nf4 quantization is supported for now
if weight_name.endswith(".quant_state.bitsandbytes__fp4"):
raise NotImplementedError(
"Only bitsandbytes_nf4 quantization"
f"is supported for now. {weight_name} is fp4 quantized"
)
temp_state_dict[weight_name] = weight_tensor

# Closure to parse quant_state for each prequant weight
def _parse_quant_state(param_name: str,
temp_state_dict: Dict) -> QuantState:
quant_state = {}
for k in temp_state_dict:
if param_name + "." in k:
quant_state[k] = temp_state_dict[k]
# bitsandbytes library requires
# weight.quant_state.bitsandbytes__nf4 in CPU
quant_state[param_name +
".quant_state.bitsandbytes__nf4"] = quant_state[
param_name +
".quant_state.bitsandbytes__nf4"].cpu().data
return QuantState.from_dict(quant_state, device="cuda")

# Second iterate over all prequant and normal weights
# pre quantized weights would have a quant_state
for weight_name, weight_tensor in self._hf_weight_iter(
hf_weights_files, use_safetensors):
# Filter out all weights whose suffix is not ".weight"
if not weight_name.endswith(".weight"):
continue
if weight_name + ".quant_state.bitsandbytes__nf4" \
in temp_state_dict:
quant_state = _parse_quant_state(weight_name,
temp_state_dict)
weight_name = weight_name.replace(".weight", ".qweight")
quant_state_dict[weight_name] = quant_state
yield weight_name.replace(".weight",
".qweight"), weight_tensor
else:
yield weight_name, weight_tensor

def generator() -> Generator:
for weight_name, weight_tensor in self._hf_weight_iter(
hf_weights_files, use_safetensors):
if any(target_module in weight_name
for target_module in self.target_modules):
weight_name = weight_name.replace(".weight", ".qweight")
# bitsandbytes requires data in GPU
# bitsandbytes requires data in GPU
loaded_weight = weight_tensor.cuda().data
with set_default_torch_dtype(torch.float32):
processed_weight, quant_state = quantize_4bit(
Expand All @@ -749,6 +802,8 @@ def generator():

yield weight_name, processed_weight

if pre_quant:
return quantized_checkpoint(), quant_state_dict
return generator(), quant_state_dict

def _load_weights(self, model_config: ModelConfig,
Expand All @@ -766,12 +821,21 @@ def _load_weights(self, model_config: ModelConfig,
logger.info("Loading weights with BitsAndBytes quantization. "
" May take a while ...")

qweight_iterator, quant_state_dict = (
self._get_quantized_weights_iterator(model_config.model,
model_config.revision))
is_quantized_checkpoint = False
quant_config = getattr(model_config.hf_config, "quantization_config",
None)
if quant_config is not None and quant_config.get(
'quant_method') == "bitsandbytes":
is_quantized_checkpoint = True

qweight_iterator, quant_state_dict = \
self._get_quantized_weights_iterator(
model_config.model, model_config.revision, is_quantized_checkpoint)

model.load_weights(qweight_iterator)

torch.cuda.empty_cache()

param_dict = dict(model.named_parameters())
stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
for quant_param_name in quant_state_dict:
Expand Down Expand Up @@ -809,9 +873,9 @@ def _load_weights(self, model_config: ModelConfig,
f"pack_factor not set for parameter {param_name}.")

num_elements = [0] * len(quant_states)
for seq, quant_state in enumerate(quant_states.items()):
for seq, quant_state in quant_states.items():
num_elements[seq] = math.prod(
quant_state[1].shape) // pack_ratio
quant_state.shape) // pack_ratio

offsets = np.concatenate(([0], np.cumsum(num_elements)))
set_weight_attrs(param, {"bnb_shard_offsets": offsets})
Expand Down
1 change: 1 addition & 0 deletions vllm/model_executor/model_loader/weight_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def convert_bin_to_safetensor_file(
# TODO(woosuk): Move this to other place.
def get_quant_config(model_config: ModelConfig,
load_config: LoadConfig) -> QuantizationConfig:

quant_cls = get_quantization_config(model_config.quantization)
# Read the quantization config from the HF model config, if available.
hf_quant_config = getattr(model_config.hf_config, "quantization_config",
Expand Down

0 comments on commit 87525fa

Please sign in to comment.