Skip to content

Commit 7576cd3

Browse files
authored
[Bugfix] Check bnb_4bit_quant_storage for bitsandbytes (vllm-project#10642)
1 parent 9a99273 commit 7576cd3

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

vllm/model_executor/layers/quantization/bitsandbytes.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@ def __init__(
2020
load_in_8bit: bool = False,
2121
load_in_4bit: bool = True,
2222
bnb_4bit_compute_dtype: str = "float32",
23+
bnb_4bit_quant_storage: str = "uint8",
2324
bnb_4bit_quant_type: str = "fp4",
2425
bnb_4bit_use_double_quant: bool = False,
2526
llm_int8_enable_fp32_cpu_offload: bool = False,
@@ -31,17 +32,23 @@ def __init__(
3132
self.load_in_8bit = load_in_8bit
3233
self.load_in_4bit = load_in_4bit
3334
self.bnb_4bit_compute_dtype = bnb_4bit_compute_dtype
35+
self.bnb_4bit_quant_storage = bnb_4bit_quant_storage
3436
self.bnb_4bit_quant_type = bnb_4bit_quant_type
3537
self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant
3638
self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload
3739
self.llm_int8_has_fp16_weight = llm_int8_has_fp16_weight
3840
self.llm_int8_skip_modules = llm_int8_skip_modules or []
3941
self.llm_int8_threshold = llm_int8_threshold
4042

43+
if self.bnb_4bit_quant_storage not in ["uint8"]:
44+
raise ValueError("Unsupported bnb_4bit_quant_storage: "
45+
f"{self.bnb_4bit_quant_storage}")
46+
4147
def __repr__(self) -> str:
4248
return (f"BitsAndBytesConfig(load_in_8bit={self.load_in_8bit}, "
4349
f"load_in_4bit={self.load_in_4bit}, "
4450
f"bnb_4bit_compute_dtype={self.bnb_4bit_compute_dtype}, "
51+
f"bnb_4bit_quant_storage={self.bnb_4bit_quant_storage}, "
4552
f"bnb_4bit_quant_type={self.bnb_4bit_quant_type}, "
4653
f"llm_int8_skip_modules={self.llm_int8_skip_modules})")
4754

@@ -80,6 +87,9 @@ def get_safe_value(config, keys, default_value=None):
8087
bnb_4bit_compute_dtype = get_safe_value(config,
8188
["bnb_4bit_compute_dtype"],
8289
default_value="float32")
90+
bnb_4bit_quant_storage = get_safe_value(config,
91+
["bnb_4bit_quant_storage"],
92+
default_value="uint8")
8393
bnb_4bit_quant_type = get_safe_value(config, ["bnb_4bit_quant_type"],
8494
default_value="fp4")
8595
bnb_4bit_use_double_quant = get_safe_value(
@@ -99,6 +109,7 @@ def get_safe_value(config, keys, default_value=None):
99109
load_in_8bit=load_in_8bit,
100110
load_in_4bit=load_in_4bit,
101111
bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
112+
bnb_4bit_quant_storage=bnb_4bit_quant_storage,
102113
bnb_4bit_quant_type=bnb_4bit_quant_type,
103114
bnb_4bit_use_double_quant=bnb_4bit_use_double_quant,
104115
llm_int8_enable_fp32_cpu_offload=llm_int8_enable_fp32_cpu_offload,

0 commit comments

Comments (0)