@@ -20,6 +20,7 @@ def __init__(
2020 load_in_8bit : bool = False ,
2121 load_in_4bit : bool = True ,
2222 bnb_4bit_compute_dtype : str = "float32" ,
23+ bnb_4bit_quant_storage : str = "uint8" ,
2324 bnb_4bit_quant_type : str = "fp4" ,
2425 bnb_4bit_use_double_quant : bool = False ,
2526 llm_int8_enable_fp32_cpu_offload : bool = False ,
@@ -31,17 +32,23 @@ def __init__(
3132 self .load_in_8bit = load_in_8bit
3233 self .load_in_4bit = load_in_4bit
3334 self .bnb_4bit_compute_dtype = bnb_4bit_compute_dtype
35+ self .bnb_4bit_quant_storage = bnb_4bit_quant_storage
3436 self .bnb_4bit_quant_type = bnb_4bit_quant_type
3537 self .bnb_4bit_use_double_quant = bnb_4bit_use_double_quant
3638 self .llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload
3739 self .llm_int8_has_fp16_weight = llm_int8_has_fp16_weight
3840 self .llm_int8_skip_modules = llm_int8_skip_modules or []
3941 self .llm_int8_threshold = llm_int8_threshold
4042
43+ if self .bnb_4bit_quant_storage not in ["uint8" ]:
44+ raise ValueError ("Unsupported bnb_4bit_quant_storage: "
45+ f"{ self .bnb_4bit_quant_storage } " )
46+
4147 def __repr__ (self ) -> str :
4248 return (f"BitsAndBytesConfig(load_in_8bit={ self .load_in_8bit } , "
4349 f"load_in_4bit={ self .load_in_4bit } , "
4450 f"bnb_4bit_compute_dtype={ self .bnb_4bit_compute_dtype } , "
51+ f"bnb_4bit_quant_storage={ self .bnb_4bit_quant_storage } , "
4552 f"bnb_4bit_quant_type={ self .bnb_4bit_quant_type } , "
4653 f"llm_int8_skip_modules={ self .llm_int8_skip_modules } )" )
4754
@@ -80,6 +87,9 @@ def get_safe_value(config, keys, default_value=None):
8087 bnb_4bit_compute_dtype = get_safe_value (config ,
8188 ["bnb_4bit_compute_dtype" ],
8289 default_value = "float32" )
90+ bnb_4bit_quant_storage = get_safe_value (config ,
91+ ["bnb_4bit_quant_storage" ],
92+ default_value = "uint8" )
8393 bnb_4bit_quant_type = get_safe_value (config , ["bnb_4bit_quant_type" ],
8494 default_value = "fp4" )
8595 bnb_4bit_use_double_quant = get_safe_value (
@@ -99,6 +109,7 @@ def get_safe_value(config, keys, default_value=None):
99109 load_in_8bit = load_in_8bit ,
100110 load_in_4bit = load_in_4bit ,
101111 bnb_4bit_compute_dtype = bnb_4bit_compute_dtype ,
112+ bnb_4bit_quant_storage = bnb_4bit_quant_storage ,
102113 bnb_4bit_quant_type = bnb_4bit_quant_type ,
103114 bnb_4bit_use_double_quant = bnb_4bit_use_double_quant ,
104115 llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload ,
0 commit comments