@@ -4519,7 +4519,6 @@ def quantize(self, tune_cfg, model, dataloader, calib_func=None):
45194519 else :
45204520 algorithm = config ['weight' ]['algorithm' ]
45214521 all_algo .add (algorithm )
4522-
45234522 if 'GPTQ' in all_algo :
45244523 q_model ._model = self .gptq_quantize (q_model ._model , tune_cfg , dataloader )
45254524
@@ -4555,10 +4554,26 @@ def rtn_quantize(self, model, tune_cfg):
45554554
def gptq_quantize(self, model, tune_cfg, dataloader):
    """Quantize the weights of ``model`` with the GPTQ algorithm.

    Settings are read from ``self.recipes``; every setting is optional and
    falls back to the default listed below.

    Args:
        model: the model handed to the GPTQ backend.
        tune_cfg: tuning configuration. Accepted for signature parity with
            the sibling ``*_quantize`` methods; not read by this method.
        dataloader: calibration data handed to the GPTQ backend.

    Returns:
        The object returned by the backend ``gptq_quantize`` call.
    """
    logger.debug("quantizing with the GPTQ algorithm")
    # Function-scope import of the backend routine; inside this method the
    # local name refers to the backend function, not to this method.
    from .torch_utils.weight_only import gptq_quantize

    # 'gptq_args' is optional. Resolving it to an empty mapping keeps
    # ``percdamp`` bound (with its default) even when the recipe omits the
    # section, instead of failing later with an unbound local.
    gptq_args = self.recipes.get('gptq_args') or {}
    percdamp = gptq_args.get('percdamp', 0.01)

    # These are looked up on the top-level recipes mapping, as before.
    # NOTE(review): they may have been intended to live under 'gptq_args',
    # and 'scheme' is used directly as the ``sym`` flag -- confirm against
    # the recipe schema and the backend's expectations.
    wbits = self.recipes.get('wbits', 4)
    group_size = self.recipes.get('group_size', 128)
    sym = self.recipes.get('scheme', False)

    weight_config = {
        'wbits': wbits,
        'group_size': group_size,
        'sym': sym,
        'percdamp': percdamp,
    }
    model = gptq_quantize(
        model,
        weight_config,
        dataloader,
        self.device,
    )
    return model
45634578
45644579 def awq_quantize (self , model , tune_cfg , dataloader , calib_func ):
0 commit comments