Skip to content

Commit bed4cb4

Browse files
authored
Add support for dynamic quant in GPTQ (#80)
Summary: These changes were missed in the beginning; add them back. Test Plan: There is some problem with testing the GPTQ quantizer locally, so this will be tested in ExecuTorch instead. Reviewers: Subscribers: Tasks: Tags:
1 parent f44089a commit bed4cb4

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

torchao/quantization/GPTQ.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,7 @@ def configure_quantization_mode(
348348
combine_qparams_list_func,
349349
make_names_and_values_dict_func,
350350
skip_layer_func,
351+
dyn_quant_func = None,
351352
):
352353
# these functions need to already be curried with all inputs other than weight, qparams
353354

@@ -371,6 +372,8 @@ def configure_quantization_mode(
371372
# `make_names_and_values_dict_func`.
372373
self.make_names_and_values_dict_func = make_names_and_values_dict_func # accepts [2d quantized tensor], [qparams], returns a dict of names, values to put in state_dict
373374
# note any final packing for storage should happen here
375+
376+
self.dyn_quant_func = dyn_quant_func
374377
return self
375378

376379
def run(self):
@@ -451,6 +454,8 @@ def tensors_to_cuda(args):
451454
quantize_linear
452455
): # calculate H instead of output (will run the linear eventually with updated weight)
453456
x = cur_args[0].float()
457+
if self.dyn_quant_func is not None:
458+
x = self.dyn_quant_func(x)
454459
shape = x.shape
455460
n = 1 if len(shape) == 2 else shape[0]
456461
H *= total_batches / (total_batches + n)

torchao/quantization/quant_api.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,13 @@ class GPTQQuantizer(Quantizer):
283283
Returns:
284284
weight: A 2d weight tensor with non-integer dtype.
285285
286+
dyn_quant_func (optional):
287+
A function that dynamically quantizes inputs
288+
Args:
289+
input: input Tensor in f32/bf16/f16
290+
Returns:
291+
output: dynamically quantized and dequantized Tensor (with the same dtype as input)
292+
286293
combine_qparams_list_func:
287294
A function that combines several qparams into one qparam.
288295
Args:
@@ -397,6 +404,7 @@ def _create_quantized_state_dict(
397404
self.combine_qparams_list_func, # pyre-ignore[16]
398405
self.make_names_and_values_dict_func, # pyre-ignore[16]
399406
self.skip_layer_func, # pyre-ignore[16]
407+
self.dyn_quant_func if hasattr(self, "dyn_quant_func") else None, # pyre-ignore[16]
400408
)
401409
print("Applying GPTQ to weights")
402410
GPTQ_runner.run()
@@ -747,7 +755,7 @@ def __init__(
747755

748756
self.precision = precision
749757

750-
self.dyn_quant_func = lambda x: per_token_dynamic_quant(x)
758+
self.dyn_quant_func = per_token_dynamic_quant
751759
n_bit = 4
752760

753761
self.get_qparams_func = lambda w: get_group_qparams_symmetric(

0 commit comments

Comments
 (0)