@@ -20,6 +20,7 @@ def __init__(
20
20
load_in_8bit : bool = False ,
21
21
load_in_4bit : bool = True ,
22
22
bnb_4bit_compute_dtype : str = "float32" ,
23
+ bnb_4bit_quant_storage : str = "uint8" ,
23
24
bnb_4bit_quant_type : str = "fp4" ,
24
25
bnb_4bit_use_double_quant : bool = False ,
25
26
llm_int8_enable_fp32_cpu_offload : bool = False ,
@@ -31,17 +32,23 @@ def __init__(
31
32
self .load_in_8bit = load_in_8bit
32
33
self .load_in_4bit = load_in_4bit
33
34
self .bnb_4bit_compute_dtype = bnb_4bit_compute_dtype
35
+ self .bnb_4bit_quant_storage = bnb_4bit_quant_storage
34
36
self .bnb_4bit_quant_type = bnb_4bit_quant_type
35
37
self .bnb_4bit_use_double_quant = bnb_4bit_use_double_quant
36
38
self .llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload
37
39
self .llm_int8_has_fp16_weight = llm_int8_has_fp16_weight
38
40
self .llm_int8_skip_modules = llm_int8_skip_modules or []
39
41
self .llm_int8_threshold = llm_int8_threshold
40
42
43
+ if self .bnb_4bit_quant_storage not in ["uint8" ]:
44
+ raise ValueError ("Unsupported bnb_4bit_quant_storage: "
45
+ f"{ self .bnb_4bit_quant_storage } " )
46
+
41
47
def __repr__(self) -> str:
    """Return a one-line, human-readable summary of this config.

    Only a subset of the attributes is rendered, in this order:
    ``load_in_8bit``, ``load_in_4bit``, ``bnb_4bit_compute_dtype``,
    ``bnb_4bit_quant_storage``, ``bnb_4bit_quant_type`` and
    ``llm_int8_skip_modules``. Each value is formatted with its default
    ``str()`` conversion, so string settings appear without quotes.
    """
    # Adjacent f-string literals are concatenated at compile time; the
    # pieces are kept one-per-field so adding a field is a one-line diff.
    return (f"BitsAndBytesConfig(load_in_8bit={self.load_in_8bit}, "
            f"load_in_4bit={self.load_in_4bit}, "
            f"bnb_4bit_compute_dtype={self.bnb_4bit_compute_dtype}, "
            f"bnb_4bit_quant_storage={self.bnb_4bit_quant_storage}, "
            f"bnb_4bit_quant_type={self.bnb_4bit_quant_type}, "
            f"llm_int8_skip_modules={self.llm_int8_skip_modules})")
47
54
@@ -80,6 +87,9 @@ def get_safe_value(config, keys, default_value=None):
80
87
bnb_4bit_compute_dtype = get_safe_value (config ,
81
88
["bnb_4bit_compute_dtype" ],
82
89
default_value = "float32" )
90
+ bnb_4bit_quant_storage = get_safe_value (config ,
91
+ ["bnb_4bit_quant_storage" ],
92
+ default_value = "uint8" )
83
93
bnb_4bit_quant_type = get_safe_value (config , ["bnb_4bit_quant_type" ],
84
94
default_value = "fp4" )
85
95
bnb_4bit_use_double_quant = get_safe_value (
@@ -99,6 +109,7 @@ def get_safe_value(config, keys, default_value=None):
99
109
load_in_8bit = load_in_8bit ,
100
110
load_in_4bit = load_in_4bit ,
101
111
bnb_4bit_compute_dtype = bnb_4bit_compute_dtype ,
112
+ bnb_4bit_quant_storage = bnb_4bit_quant_storage ,
102
113
bnb_4bit_quant_type = bnb_4bit_quant_type ,
103
114
bnb_4bit_use_double_quant = bnb_4bit_use_double_quant ,
104
115
llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload ,
0 commit comments