@@ -166,7 +166,16 @@ class GPTQuantizer(object):
     url: https://arxiv.org/abs/2210.17323
     """
 
-    def __init__(self, model, weight_config={}, dataloader=None, nsamples=128, use_max_length=True, device=None):
+    def __init__(
+        self,
+        model,
+        weight_config={},
+        dataloader=None,
+        nsamples=128,
+        use_max_length=True,
+        pad_max_length=2048,
+        device=None,
+    ):
         """
         Args:
             model: the fp32 model to quantize
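
For context, a minimal caller-side sketch of the widened constructor. Only the argument names come from the signature above; `fp32_model`, `calib_dataloader`, and `layer_wise_config` are hypothetical placeholders:

    # Hypothetical usage sketch; variable names are placeholders, not from this repo.
    quantizer = GPTQuantizer(
        fp32_model,                       # model to be quantized, kept in fp32
        weight_config=layer_wise_config,  # per-layer quantization settings
        dataloader=calib_dataloader,      # yields tensors, tuples/lists, or dicts
        nsamples=128,                     # number of calibration samples to collect
        use_max_length=True,              # GPTQ-official fixed-length calibration
        pad_max_length=2048,              # new knob replacing the old model.seqlen
    )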
@@ -211,44 +220,29 @@ def __init__(self, model, weight_config={}, dataloader=None, nsamples=128, use_m
 
         # dataloader
         self.use_max_length = use_max_length
+        self.pad_max_length = pad_max_length
         self.dataloader_original = dataloader
         self.dataloader = []
         self.nsamples = nsamples
         self.prepare_dataloader()
 
     def prepare_dataloader(self):
         if self.use_max_length:
-            # (Recommended) only take sequences whose length exceeds model.seqlen,
+            # (Recommended) only take sequences whose length exceeds self.pad_max_length,
             # which guarantees that all calibration tokens are valid.
             # This is the official GPTQ dataloader implementation.
             self.obtain_first_n_samples_fulllength()
-            # initialize buffers which are essential for gptq computation.
-            self.model_hidden_size = 2048
-            self.initialize_inp_buffersize()
-            try:
-                # Since length is unified, we can allocate a contiguous space to store inputs
-                self.inp = torch.zeros(
-                    (len(self.dataloader), self.model.seqlen, self.model_hidden_size),
-                    dtype=self.dtype,
-                    device=self.device,
-                )
-                self.cache = {"i": 0}
-                self.out = torch.zeros_like(self.inp)
-                self.is_ready = True
-            except:
-                logger.warning("GPTQ Quantizer initialization failed!")
-                pass
         else:
             # general selection, no padding; not the original GPTQ implementation.
             self.obtain_first_n_samples()
-            try:
-                self.inp = [torch.zeros(1) for _ in range(len(self.dataloader))]
-                self.cache = {"i": 0}
-                self.out = [torch.zeros(1) for _ in range(len(self.dataloader))]
-                self.is_ready = True
-            except:
-                logger.warning("GPTQ Quantizer initialization failed!")
-                pass
+        try:
+            self.inp = [torch.zeros(1) for _ in range(len(self.dataloader))]
+            self.cache = {"i": 0}
+            self.out = [torch.zeros(1) for _ in range(len(self.dataloader))]
+            self.is_ready = True
+        except:
+            logger.warning("GPTQ Quantizer initialization failed!")
+            pass
 
     def obtain_first_n_samples(self, seed=0):
         """Get first nsample data as the real calibration dataset."""
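
With the `try` block hoisted out of the `else` branch, both selection modes now share the same list-based buffers. A sketch of the resulting invariant after `prepare_dataloader()` returns (illustrative assertions, not code from this diff):

    # Holds for both use_max_length=True and use_max_length=False.
    assert len(quantizer.inp) == len(quantizer.out) == len(quantizer.dataloader)
    assert quantizer.cache == {"i": 0}  # per-forward sample counter
    assert quantizer.is_ready           # list buffers were allocated successfully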
@@ -257,12 +251,13 @@ def obtain_first_n_samples(self, seed=0):
         for batch in self.dataloader_original:
             # process data, depending on its data type.
             if len(self.dataloader) == self.nsamples:
+                logger.info(f"Successfully collected {self.nsamples} calibration samples.")
                 break
             # list, tuple
             if isinstance(batch, list) or isinstance(batch, tuple):
-                if batch[0].shape[-1] > self.model.seqlen:
-                    i = random.randint(0, batch[0].shape[-1] - self.model.seqlen - 1)
-                    j = i + self.model.seqlen
+                if batch[0].shape[-1] > self.pad_max_length:
+                    i = random.randint(0, batch[0].shape[-1] - self.pad_max_length - 1)
+                    j = i + self.pad_max_length
                     batch_final = batch[0][:, i:j]
                 else:
                     batch_final = batch[0]
@@ -274,9 +269,9 @@ def obtain_first_n_samples(self, seed=0):
                     logger.warning("Please make sure your dict-like data contains the key 'input_ids'.")
                     continue
                 batch_final = {}
-                if length > self.model.seqlen:
-                    i = random.randint(0, length - self.model.seqlen - 1)
-                    j = i + self.model.seqlen
+                if length > self.pad_max_length:
+                    i = random.randint(0, length - self.pad_max_length - 1)
+                    j = i + self.pad_max_length
                     # may have to slice every sequence-related entry
                     for key in batch.keys():
                         if isinstance(batch[key], torch.Tensor):
@@ -287,9 +282,9 @@ def obtain_first_n_samples(self, seed=0):
                     batch_final = batch
             # tensor
             else:
-                if batch.shape[-1] > self.model.seqlen:
-                    i = random.randint(0, batch.shape[-1] - self.model.seqlen - 1)
-                    j = i + self.model.seqlen
+                if batch.shape[-1] > self.pad_max_length:
+                    i = random.randint(0, batch.shape[-1] - self.pad_max_length - 1)
+                    j = i + self.pad_max_length
                     batch_final = batch[:, i:j]
                 else:
                     batch_final = batch
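
The cropping pattern repeated across the three branches above can be read in isolation: sequences longer than `pad_max_length` keep a random contiguous window, shorter ones pass through unchanged. A standalone sketch (the helper name is ours, not the repository's):

    import random
    import torch

    def random_window(ids: torch.Tensor, pad_max_length: int = 2048) -> torch.Tensor:
        # ids shape: [batch, seq_len]; mirrors the slicing logic in the diff above.
        length = ids.shape[-1]
        if length <= pad_max_length:
            return ids  # short sequences are kept as-is in this mode
        i = random.randint(0, length - pad_max_length - 1)
        return ids[:, i : i + pad_max_length]  # exactly pad_max_length tokens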
@@ -301,9 +296,10 @@ def obtain_first_n_samples(self, seed=0):
     def obtain_first_n_samples_fulllength(self, seed=0):
         self.dataloader.clear()
         random.seed(seed)
-        unified_length = self.model.seqlen
+        unified_length = self.pad_max_length
         for batch in self.dataloader_original:
             if len(self.dataloader) == self.nsamples:
+                logger.info(f"Successfully collected {self.nsamples} calibration samples.")
                 break
             # list & tuple
             if isinstance(batch, list) or isinstance(batch, tuple):
@@ -325,11 +321,11 @@ def obtain_first_n_samples_fulllength(self, seed=0):
                     logger.warning("Please make sure your dict-like data contains the key 'input_ids'.")
                     continue
                 batch_final = {}
-                if length == self.model.seqlen:
+                if length == self.pad_max_length:
                     batch_final = batch
-                elif length > self.model.seqlen:
-                    i = random.randint(0, length - self.model.seqlen - 1)
-                    j = i + self.model.seqlen
+                elif length > self.pad_max_length:
+                    i = random.randint(0, length - self.pad_max_length - 1)
+                    j = i + self.pad_max_length
                     # may have to slice every sequence-related entry
                     for key in batch.keys():
                         if isinstance(batch[key], torch.Tensor):
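
The full-length policy differs from the general selection in one point: samples shorter than `unified_length` are dropped instead of kept. Condensed into one hypothetical helper (a sketch, not code from this repo):

    import random
    import torch

    def select_fulllength(ids: torch.Tensor, unified_length: int = 2048):
        # Returns a [..., unified_length] crop, or None when the sample is too short.
        length = ids.shape[-1]
        if length == unified_length:
            return ids
        if length > unified_length:
            i = random.randint(0, length - unified_length - 1)
            return ids[..., i : i + unified_length]
        return None  # too short: discarded, since this mode never pads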
@@ -354,53 +350,9 @@ def obtain_first_n_samples_fulllength(self, seed=0):
         if len(self.dataloader) < self.nsamples:  # pragma: no cover
             logger.warning(
                 f"Trying to allocate {self.nsamples} data with fixed length {unified_length}, \
-                but only {len(self.dataloader)} samples satisfy your setting. You may choose a smaller 'model.seqlen' value."
+                but only {len(self.dataloader)} samples were found. Please use a smaller 'self.pad_max_length' value."
             )
 
-    @torch.no_grad()
-    def initialize_inp_buffersize(self):
-        # Run a forward pass to generate a properly sized buffer tensor,
-        # so there is no need to pass the hidden-state dimension from model.config
-        # (e.g. OPT exposes it as model.config.hidden_size,
-        # but MPT exposes it as model.config.d_model).
-        def forward(layer, hidden_states, **kwargs):
-            # inputs[inputs_info['idx']] = input_ids  # TODO solve the problem of batch size != 1
-            logger.info(f"The hidden_states shape along the transformer blocks is {hidden_states.shape}.")
-            self.model_hidden_size = hidden_states.shape[-1]
-            raise ValueError
-
-        # Step 1: fetch the embeddings and other layers before the transformer stack.
-        for embedding_name, embedding_layer in self.gptq_related_blocks["embeddings"].items():
-            embedding_layer = embedding_layer.to(self.device)
-
-        # Step 2: modify the first transformer block's forward function to obtain inputs for calibration
-        self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].to(self.device)
-        forward_cache = self.gptq_related_blocks["transformers"][0].forward
-        self.gptq_related_blocks["transformers"][0].forward = partial(
-            forward, self.gptq_related_blocks["transformers"][0]
-        )
-
-        # Step 3: run forward passes to obtain calibration data
-        logger.info("Collecting calibration inputs...")
-        for batch in self.dataloader:
-            batch = move_input_to_device(batch, self.device)
-            try:
-                if isinstance(batch, tuple) or isinstance(batch, list):
-                    self.model(batch[0])
-                elif isinstance(batch, dict):
-                    self.model(**batch)
-                else:
-                    self.model(batch.to(self.device))
-            except ValueError:
-                break
-
-        # Step 4: restore the original forward function and relocate layers back to the CPU.
-        self.gptq_related_blocks["transformers"][0].forward = forward_cache
-        self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].cpu()
-        for embedding_name, embedding_layer in self.gptq_related_blocks["embeddings"].items():
-            embedding_layer.to(self.device)
-        torch.cuda.empty_cache()
-
     def get_full_layer_name(self, sub_layer_name, block_idx):
         transformer_name = self.gptq_related_blocks["transformers_name"]
         return ".".join([transformer_name, str(block_idx), sub_layer_name])
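
`initialize_inp_buffersize` becomes unnecessary because inputs now live in plain Python lists, which adapt to any hidden size. For reference, the deleted probe patched the first block's forward and aborted with an exception; the same trick can be written more conventionally with a forward pre-hook. A sketch under that assumption, not code from this repo, assuming the block receives hidden_states as its first positional argument:

    import torch

    class _ProbeDone(Exception):
        # sentinel used to cut the forward pass short (the removed code used ValueError)
        pass

    def probe_hidden_size(model: torch.nn.Module, first_block: torch.nn.Module, sample) -> int:
        """Assumed helper: capture hidden_states.shape[-1] entering the first block."""
        captured = {}

        def hook(module, args):
            # args[0] is assumed to be the hidden_states tensor fed to the block
            captured["hidden_size"] = args[0].shape[-1]
            raise _ProbeDone

        handle = first_block.register_forward_pre_hook(hook)
        try:
            with torch.no_grad():
                model(sample)
        except _ProbeDone:
            pass
        finally:
            handle.remove()
        return captured["hidden_size"]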
@@ -459,18 +411,12 @@ def forward(layer, hidden_states, **kwargs):
            self.cache["i"] += 1
            for arg in kwargs:
                # TODO: investigate the included parameters
-                if self.use_max_length:
-                    if isinstance(kwargs[arg], torch.Tensor) or arg == "alibi":
-                        self.cache[arg] = kwargs[arg]
-                    else:
-                        continue
-                else:
-                    # each output can have a different shape, hence a list is used for storage as well
-                    if isinstance(kwargs[arg], torch.Tensor) or arg == "alibi":
-                        if self.cache.get(arg, None) is None:
-                            self.cache[arg] = []
-                        self.cache[arg].append(kwargs[arg])
-                    continue
+                # each output can have a different shape, hence a list is used for storage as well
+                if isinstance(kwargs[arg], torch.Tensor) or arg == "alibi":
+                    if self.cache.get(arg, None) is None:
+                        self.cache[arg] = []
+                    self.cache[arg].append(kwargs[arg])
+                continue
            raise ValueError
 
        # Step 1: fetch the embeddings and other layers before the transformer stack.
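
With the `use_max_length` split removed, every tensor-valued kwarg (plus `"alibi"`, which may be a non-tensor) is appended to a per-argument list, one entry per calibration sample. The caching pattern in isolation (a sketch of the same logic, with an invented function name):

    import torch

    def cache_block_kwargs(cache: dict, **kwargs) -> None:
        # cache["i"] counts captured samples; every other key maps to a list
        # holding that kwarg's value for each sample, in collection order.
        cache["i"] += 1
        for arg, value in kwargs.items():
            if isinstance(value, torch.Tensor) or arg == "alibi":
                cache.setdefault(arg, []).append(value)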
@@ -572,13 +518,9 @@ def tmp(_, inp, out):
                 handles.append(sub_layers[layer_name].register_forward_hook(add_batch(layer_name)))
             idx = self.cache.pop("i")
             for j in range(len(self.dataloader)):
-                if self.use_max_length:
-                    # self.inp[j] shape: [seq_len, hidden_size]
-                    self.out[j] = transformer_block(self.inp[j].unsqueeze(0), **self.cache)[0]
-                else:
-                    # self.inp[j] shape: [1, seq_len, hidden_size] (batch size is 1 by default)
-                    cache_batch = self.gather_single_batch_from_dict(self.cache, j)
-                    self.out[j] = transformer_block(self.inp[j], **cache_batch)[0]
+                # self.inp[j] shape: [1, seq_len, hidden_size] (batch size is 1 by default)
+                cache_batch = self.gather_single_batch_from_dict(self.cache, j)
+                self.out[j] = transformer_block(self.inp[j], **cache_batch)[0]
             self.cache["i"] = idx
             for h in handles:
                 h.remove()
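
`gather_single_batch_from_dict` is not shown in this diff; consistent with the list-per-kwarg cache above, a plausible implementation (assumed, not the repository's) is a simple per-index projection:

    def gather_single_batch_from_dict(cache: dict, idx: int) -> dict:
        # "i" has already been popped by the caller, so every remaining value
        # is a list with one entry per calibration sample.
        return {arg: values[idx] for arg, values in cache.items()}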
@@ -607,13 +549,9 @@ def tmp(_, inp, out):
             # Step 2.5: replace output data with quantized weights
             idx = self.cache.pop("i")
             for j in range(len(self.dataloader)):
-                if self.use_max_length:
-                    # self.inp[j] shape: [seq_len, hidden_size]
-                    self.out[j] = transformer_block(self.inp[j].unsqueeze(0), **self.cache)[0]
-                else:
-                    # self.inp[j] shape: [1, seq_len, hidden_size] (batch size is 1 by default)
-                    cache_batch = self.gather_single_batch_from_dict(self.cache, j)
-                    self.out[j] = transformer_block(self.inp[j], **cache_batch)[0]
+                # self.inp[j] shape: [1, seq_len, hidden_size] (batch size is 1 by default)
+                cache_batch = self.gather_single_batch_from_dict(self.cache, j)
+                self.out[j] = transformer_block(self.inp[j], **cache_batch)[0]
             self.cache["i"] = idx
             self.gptq_related_blocks["transformers"][block_idx] = transformer_block.cpu()
             del gptq_for_this_block