
Commit d2f995b

support TEQ quantization method (#1093)
1 parent 59172ad commit d2f995b

File tree

7 files changed: +645, -4 lines

neural_compressor/adaptor/pytorch.py

Lines changed: 40 additions & 0 deletions
@@ -4522,6 +4522,9 @@ def quantize(self, tune_cfg, model, dataloader, calib_func=None):
         if 'GPTQ' in all_algo:
             q_model._model = self.gptq_quantize(q_model._model, tune_cfg, dataloader)

+        if 'TEQ' in all_algo:
+            q_model._model = self.teq_quantize(q_model._model, tune_cfg, dataloader, calib_func)
+
         if 'AWQ' in all_algo:  # includes RTN in AWQ
             q_model._model = self.awq_quantize(q_model._model, tune_cfg, dataloader, calib_func)
         elif 'RTN' in all_algo:
@@ -4582,6 +4585,43 @@ def gptq_quantize(self, model, tune_cfg, dataloader):
         )
         return model

+    def teq_quantize(self, model, tune_cfg, dataloader, calib_func):
+        logger.debug("quantizing with the TEQ algorithm")
+        from .torch_utils.weight_only import teq_quantize
+        # get example inputs if not provided.
+        if self.example_inputs is None:
+            if dataloader is None:
+                assert False, "Please provide dataloader or example_inputs for TEQ algorithm."
+            try:
+                for idx, (input, label) in enumerate(dataloader):
+                    self.example_inputs = input
+                    break
+            except:
+                for idx, input in enumerate(dataloader):
+                    self.example_inputs = input
+                    break
+
+        if 'teq_args' in self.recipes:
+            wbits = self.recipes.get('wbits', 4)
+            group_size = self.recipes.get('group_size', 128)
+            sym = self.recipes.get('scheme', False)
+            folding = self.recipes.get('folding', True)
+
+        weight_config = {
+            'wbits': wbits,
+            'group_size': group_size,
+            'sym': sym,
+            'folding': folding
+        }
+        quantizer = teq_quantize(
+            model,
+            weight_config,
+            dataloader,
+            example_inputs=self.example_inputs,
+            calib_func=calib_func
+        )
+        return quantizer.model
+
     def awq_quantize(self, model, tune_cfg, dataloader, calib_func):
         logger.debug("quantizing with the AWQ algorithm")
         from .torch_utils.weight_only import awq_quantize
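
The new teq_quantize adaptor method above pulls a single batch from the calibration dataloader to use as example inputs (first trying (input, label) pairs, then falling back to bare inputs) and builds a flat weight_config from self.recipes, defaulting to 4-bit weights, group size 128, asymmetric scheme, and folding enabled. The following is a minimal standalone sketch of that extraction and config, not part of the commit; calib_loader is a hypothetical stand-in for the framework dataloader.

# Standalone sketch (illustrative only) of the example-input extraction and the
# flat weight_config used by teq_quantize above.
import torch

calib_loader = [(torch.randn(2, 16), torch.zeros(2)) for _ in range(4)]  # (input, label) batches

example_inputs = None
try:
    # dataloaders that yield (input, label) pairs
    for idx, (inp, label) in enumerate(calib_loader):
        example_inputs = inp
        break
except (TypeError, ValueError):
    # dataloaders that yield bare input tensors
    for idx, inp in enumerate(calib_loader):
        example_inputs = inp
        break

# Flat per-model config mirroring the defaults read from self.recipes above.
weight_config = {'wbits': 4, 'group_size': 128, 'sym': False, 'folding': True}
print(example_inputs.shape, weight_config)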

neural_compressor/adaptor/torch_utils/model_wrapper.py

Lines changed: 94 additions & 0 deletions
@@ -21,6 +21,8 @@
 import math
 import torch
 from torch.nn import functional as F
+from torch.autograd import Function
+from .weight_only import quant_weight
 from packaging.version import Version


@@ -355,3 +357,95 @@ def extra_repr(self) -> str:
         return 'in_features={}, out_features={}, bits={}, group_size={}, bias={}'.format(
             self.in_features, self.out_features, self.bits, self.groupsize, self.bias is not None
         )
+
+
+class FakeAffineTensorQuantFunction(Function):
+    """Fake version of affine quantization
+    """
+
+    @staticmethod
+    def forward(ctx, inputs, num_bits=4, group_size=1024):
+        """
+
+        As it will be only applied on activation with per tensor granularity, broadcast is not needed.
+
+        Args:
+            ctx: Pytorch convention.
+            inputs: A Tensor of type float32.
+            min_range: A float.
+            max_range: A float.
+            num_bits: An integer
+
+        Returns:
+            outputs: A Tensor of type output_dtype
+        """
+        return quant_weight(inputs, num_bits, group_size)
+
+    @staticmethod
+    def backward(ctx, grad_outputs):
+        """
+        Args:
+            ctx: Pytorch convention.
+            grad_output: A tensor of gradient of outputs
+
+        Returns:
+            grad_inputs: A tensor of gradient
+        """
+        return grad_outputs, None, None
+
+
+class TEQLinearFakeQuant(torch.nn.Module):
+    """
+    wrapper quantization linear
+    """
+
+    def __init__(self, orig_layer, alpha=None, num_bits=4, group_size=-1):
+        """
+        A forward hook to linear module
+        :param orig_layer: the original module
+        :param alpha: trainable alpha/scale
+        :param num_bits: quantization level
+        :param group_size: for fine-grained quantization
+        """
+        super(TEQLinearFakeQuant, self).__init__()
+        self.orig_layer = orig_layer
+        self.alpha = alpha
+
+        self.num_bits = num_bits
+        self.group_size = group_size
+
+    def forward(self, x):
+        alpha = torch.clip(self.alpha, 1e-5)
+        shape_len = len(x.shape) - 1
+        shape = (1,) * shape_len + (-1,)
+        x = x / alpha.view(shape)
+        weight = self.orig_layer.weight
+        weight = weight * alpha.unsqueeze(dim=0)
+        weight_q = FakeAffineTensorQuantFunction().apply(weight, self.num_bits, self.group_size)
+        return F.linear(x, weight_q, self.orig_layer.bias)
+
+
+class TEQMulLinear(torch.nn.Module):
+    """
+    Trainable Equivalent Transformation (TEQ): linear wrapper to apply scale to input
+    """
+
+    def __init__(self, module, input_scale):
+        """
+        A forward hook to save input max of a module
+        :param module: the linear module
+        :param input_scale: scale for input
+        """
+
+        super().__init__()
+        self.register_buffer('input_scale', input_scale)
+        self.add_module('sq_linear', module)
+
+    @property
+    def weight(self):
+        return self.sq_linear.weight
+
+    def forward(self, X):
+        X = torch.mul(X, self.input_scale)
+        X = self.sq_linear(X)
+        return X
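
The additions above give TEQ two building blocks: TEQLinearFakeQuant fake-quantizes a scaled copy of a linear layer's weight on every forward pass, with FakeAffineTensorQuantFunction supplying a straight-through backward so the per-channel alpha stays trainable, while TEQMulLinear applies a fixed scale to a layer's input. Below is a minimal sketch exercising both wrappers on a toy nn.Linear, assuming they are importable from neural_compressor.adaptor.torch_utils.model_wrapper after this change; the shapes, values, and the 1/alpha folding are illustrative assumptions, not prescribed by this diff.

# Illustrative only; not part of the commit.
import torch
from neural_compressor.adaptor.torch_utils.model_wrapper import TEQLinearFakeQuant, TEQMulLinear

linear = torch.nn.Linear(8, 4)
alpha = torch.ones(8, requires_grad=True)                # trainable per-input-channel scale
wrapped = TEQLinearFakeQuant(linear, alpha=alpha, num_bits=4, group_size=-1)

x = torch.randn(2, 8)
out = wrapped(x)          # x / alpha, weight * alpha, fake-quantize the weight, then F.linear
out.sum().backward()      # straight-through backward lets gradients reach alpha
print(alpha.grad is not None)  # True

# TEQMulLinear bakes a fixed scale into the input path of the original layer;
# using the reciprocal of the learned alpha here is an assumption for illustration.
deploy = TEQMulLinear(linear, input_scale=(1.0 / alpha.detach()))
print(deploy(x).shape)    # torch.Size([2, 4])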
