 from auto_round.low_cpu_mem.utils import get_layers_before_block
 from auto_round.mllm.mllm_dataset import get_mllm_dataloader
 from auto_round.mllm.template import Template, get_template
+from auto_round.schemes import QuantizationScheme
 from auto_round.special_model_handler import (
     NOT_SUPPORT_ONLY_TEXT_MODELS,
     SUPPORT_ONLY_TEXT_MODELS,
@@ -126,61 +127,56 @@ class AutoRoundMLLM(AutoRound):
 
     """
 
+    bits: int | None
+    group_size: int | None
+    sym: bool | None
+    data_type: str | None
+    act_bits: int | None
+    act_group_size: int | None
+    act_sym: bool | None
+    act_data_type: str | None
+    act_dynamic: bool | None
+    super_bits: int | None
+    super_group_size: int | None
+
     def __init__(
         self,
         model: Union[torch.nn.Module, str],
         tokenizer=None,
         processor=None,
         image_processor=None,
-        bits: int = 4,
-        group_size: int = 128,
-        sym: bool = True,
-        layer_config: dict = None,
-        batch_size: int = 8,
-        amp: bool = True,
-        device: Union[str, torch.device, int] = 0,
-        lr_scheduler=None,
-        dataset: Union[str, list, tuple, torch.utils.data.DataLoader] = None,
-        extra_data_dir: str = None,
-        template: Union[str, Template] = None,
+        scheme: Union[str, dict, QuantizationScheme] = "W4A16",
+        layer_config: dict[str, Union[str, dict, QuantizationScheme]] = None,
+        dataset: Union[str, list, tuple, torch.utils.data.DataLoader] = "NeelNanda/pile-10k",
         quant_nontext_module: bool = False,
-        enable_quanted_input: bool = True,
-        enable_minmax_tuning: bool = True,
-        lr: float = None,
-        minmax_lr: float = None,
-        low_gpu_mem_usage: bool = False,
-        low_cpu_mem_usage: bool = False,
         iters: int = 200,
-        seqlen: int = None,
+        seqlen: int = 2048,
         nsamples: int = 128,
-        sampler: str = "rand",
-        seed: int = 42,
-        nblocks: int = 1,
+        batch_size: int = 8,
         gradient_accumulate_steps: int = 1,
-        not_use_best_mse: bool = False,
-        dynamic_max_gap: int = -1,
-        data_type: str = "int",
-        scale_dtype: str = "fp16",
-        act_bits: int = 32,
-        act_group_size: int = None,
-        act_sym: bool = None,
-        act_dynamic: bool = True,
-        to_quant_block_names: Union[str, list] = None,
-        enable_norm_bias_tuning: bool = False,
-        truncation: bool = None,
+        low_gpu_mem_usage: bool = False,
+        device_map: Union[str, torch.device, int, dict] = 0,
         enable_torch_compile: bool = False,
-        model_kwargs: dict = None,
+        seed: int = 42,
         **kwargs,
     ):
+        extra_data_dir = kwargs.pop("extra_data_dir", None)
+        template = kwargs.pop("template", None)
+
+        to_quant_block_names: Union[str, list, None] = kwargs.pop("to_quant_block_names", None)
+        if device_map is None:
+            device_map = 0
+        self._set_device(device_map)
+
         if isinstance(model, str):
-            model, processor, tokenizer, image_processor = mllm_load_model(model, device=device)
+            model, processor, tokenizer, image_processor = mllm_load_model(model, device=self.device)
 
+        self.model = model
         quant_nontext_module = self._check_quant_nontext(layer_config, quant_nontext_module)
         all_blocks = get_block_names(model, quant_nontext_module)
         self.quant_block_list = find_matching_blocks(model, all_blocks, to_quant_block_names)
         if to_quant_block_names is None:
             to_quant_block_names = extract_block_names_to_str(self.quant_block_list)
-        self.to_quant_block_names = to_quant_block_names
         self.extra_data_dir = extra_data_dir
         self.quant_nontext_module = quant_nontext_module
         self.processor = processor
@@ -219,7 +215,7 @@ def __init__(
219215 " switching to liuhaotian/llava_conv_58k"
220216 )
221217 dataset = "liuhaotian/llava_conv_58k"
222- elif not _only_text_test (model , tokenizer , device , self .template .model_type ):
218+ elif not _only_text_test (model , tokenizer , self . device , self .template .model_type ):
223219 logger .warning (
224220 f"{ model .config .model_type } does not support for { dataset } ,"
225221 " will use liuhaotian/llava_conv_58k with default config as an alternative."
@@ -248,7 +244,7 @@ def __init__(
             gradient_accumulate_steps = batch_size * gradient_accumulate_steps
             batch_size = 1
         seqlen = 2048 if seqlen is None else seqlen
-        truncation = True if truncation is None else truncation
+        truncation = True
         self.truncation = truncation
 
         if nsamples % batch_size != 0:
@@ -258,40 +254,20 @@ def __init__(
         super(AutoRoundMLLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
-            bits=bits,
-            group_size=group_size,
-            sym=sym,
+            scheme=scheme,
             layer_config=layer_config,
-            batch_size=batch_size,
-            amp=amp,
-            device=device,
-            lr_scheduler=lr_scheduler,
             dataset=dataset,
-            enable_quanted_input=enable_quanted_input,
-            enable_minmax_tuning=enable_minmax_tuning,
-            lr=lr,
-            minmax_lr=minmax_lr,
-            low_gpu_mem_usage=low_gpu_mem_usage,
-            low_cpu_mem_usage=low_cpu_mem_usage,
             iters=iters,
             seqlen=seqlen,
             nsamples=nsamples,
-            sampler=sampler,
-            seed=seed,
-            nblocks=nblocks,
+            batch_size=batch_size,
             gradient_accumulate_steps=gradient_accumulate_steps,
-            not_use_best_mse=not_use_best_mse,
-            dynamic_max_gap=dynamic_max_gap,
-            data_type=data_type,
-            scale_dtype=scale_dtype,
-            act_bits=act_bits,
-            act_group_size=act_group_size,
-            act_sym=act_sym,
-            act_dynamic=act_dynamic,
-            to_quant_block_names=self.to_quant_block_names,
-            enable_norm_bias_tuning=enable_norm_bias_tuning,
+            low_gpu_mem_usage=low_gpu_mem_usage,
+            device_map=device_map,
             enable_torch_compile=enable_torch_compile,
+            seed=seed,
             vlm=True,
+            to_quant_block_names=to_quant_block_names,
             **kwargs,
         )
 
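For reference, a minimal usage sketch of the constructor after this refactor; the model id below is a placeholder, and the final quantize() call is assumed from the base AutoRound API rather than shown in this diff:

from auto_round import AutoRoundMLLM

# Hypothetical example of the new scheme-based API.
# "Qwen/Qwen2-VL-2B-Instruct" is a placeholder model id, not taken from this PR.
autoround = AutoRoundMLLM(
    model="Qwen/Qwen2-VL-2B-Instruct",  # string paths are loaded via mllm_load_model internally
    scheme="W4A16",                     # replaces the old bits/group_size/sym/act_* arguments
    dataset="NeelNanda/pile-10k",       # new default calibration dataset
    iters=200,
    nsamples=128,
    batch_size=8,
)
autoround.quantize()                    # assumed base-class entry point, unchanged by this diff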