
Commit 0464b30

deduplicate code for get_group_qparams_symmetric
Summary: This just removes the duplicated implementation; follow-up PRs can remove the call altogether once all implementations have been replaced with the new blockwise quant code.
Test Plan: CI
Reviewers:
Subscribers:
Tasks:
Tags:
1 parent f05c215 commit 0464b30

File tree (4 files changed, +81 / -112 lines)

test/integration/test_integration.py
test/quantization/test_quant_primitives.py
torchao/quantization/quant_primitives.py
torchao/quantization/subclass.py
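For orientation, the primitives this commit consolidates onto are choose_qparams_affine, quantize_affine, and dequantize_affine, which describe quantization granularity with an explicit block_size (how many elements share one scale/zero_point pair). Below is a minimal round-trip sketch, assuming only the signatures visible in the diffs of this commit; the tensor shape, group size, and the MappingType import location are illustrative assumptions, not part of the commit.

import torch
from torchao.quantization.quant_primitives import (
    MappingType,  # assumed to be available from this module
    choose_qparams_affine,
    quantize_affine,
    dequantize_affine,
)

# hypothetical weight and group size, purely for illustration
w = torch.randn(4, 8)
block_size = (1, 4)  # every 4 consecutive elements along the last dim share one scale/zero_point

scale, zero_point = choose_qparams_affine(
    w, MappingType.SYMMETRIC, block_size, torch.int8,
    eps=torch.finfo(torch.float32).eps,
)
wq = quantize_affine(w, block_size, scale, zero_point, torch.int8, -128, 127)
w_dq = dequantize_affine(wq, block_size, scale, zero_point, torch.int8, output_dtype=w.dtype)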

test/integration/test_integration.py (9 additions, 7 deletions)

@@ -36,6 +36,7 @@
     quant_int8_dynamic_per_token_linear,
     quantize_activation_per_token_absmax,
     safe_int_mm,
+    dequantize_affine,
 )
 
 from torchao.quantization.smoothquant import (
@@ -385,11 +386,11 @@ def _test_dynamic_quant_per_tensor_numerics_impl(
         # to rounding
         assert torch.max(torch.abs(y_vals - y_ref.int_repr())).item() <= 1
         torch.testing.assert_close(
-            y_scale, torch.tensor([y_ref.q_scale()], device=device, dtype=float_dtype)
+            y_scale, torch.tensor(y_ref.q_scale(), device=device, dtype=float_dtype)
         )
         if y_zero_point is not None:
             assert torch.equal(
-                y_zero_point, torch.tensor([y_ref.q_zero_point()], device=device)
+                y_zero_point, torch.tensor(y_ref.q_zero_point(), device=device)
             )
         else:
             self.assertTrue(y_ref.q_zero_point() == 0)
@@ -558,8 +559,8 @@ def _test_dynamic_quant_per_channel_numerics_impl(
         assert torch.max(torch.abs(y_vals - y_ref.int_repr())) <= 1
 
         # dequantize
-        x_dq = dequantize_per_channel(y_vals, y_scale, y_zero_point)
-        x_ref_dq = y_ref.dequantize()
+        x_dq = dequantize_per_channel(y_vals, y_scale, y_zero_point, out_dtype=float_dtype)
+        x_ref_dq = y_ref.dequantize().to(float_dtype)
         # off-by-one for scale is okay
         torch.testing.assert_close(
             x_dq, x_ref_dq, atol=torch.max(y_scale).item() * 1.01, rtol=0.0001
@@ -582,7 +583,8 @@ def test_dynamic_quant_per_channel_numerics_cuda(self):
     def _test_quantize_per_token_impl(self, device, dtype):
         x = torch.randn(3, 3, 3, device=device, dtype=dtype)
         xq, scales = quantize_activation_per_token_absmax(x)
-        x_dq = dequantize_per_tensor(xq, scales, None).to(x.dtype)
+        block_size = (1, 1, 3)
+        x_dq = dequantize_affine(xq, block_size, scales, None, torch.int8, output_dtype=x.dtype)
         sqnr = compute_error(x, x_dq)
         self.assertTrue(sqnr >= 45.0)
 
@@ -1173,7 +1175,7 @@ def forward(self, x):
         model_qc = torch.compile(model, mode="max-autotune")
         ref_q = model_qc(x).detach()
 
-        assert SQNR(ref_f, ref_q) > min_sqnr
+        assert SQNR(ref_f, ref_q) > min_sqnr, f"got sqnr: {SQNR(ref_f, ref_q)}, expected: {min_sqnr}"
 
         # load model structure
         with torch.device('meta'):
@@ -1190,7 +1192,7 @@ def forward(self, x):
         model_qc = torch.compile(model, mode="max-autotune")
         test = model_qc(x).detach()
 
-        assert SQNR(ref_f, test) > min_sqnr
+        assert SQNR(ref_f, test) > min_sqnr, f"got sqnr: {SQNR(ref_f, ref_q)}, expected: {min_sqnr}"
         self.assertTrue(torch.equal(ref_q, test))
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
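The per-token test above now calls dequantize_affine directly, with block_size = (1, 1, 3) for a (3, 3, 3) activation: each block spans a full last dimension, so there is one scale per token. A rough standalone sketch of that shape relationship, assuming only the dequantize_affine signature shown in the hunk; the absmax quantization below is a hand-rolled stand-in for quantize_activation_per_token_absmax, not its actual implementation.

import torch
from torchao.quantization.quant_primitives import dequantize_affine

x = torch.randn(3, 3, 3)
# stand-in per-token absmax quantization: one scale per slice along the last dim
scales = (x.abs().amax(dim=-1, keepdim=True) / 127.0).clamp(min=1e-8)
xq = torch.clamp(torch.round(x / scales), -128, 127).to(torch.int8)

# block_size (1, 1, 3) -> each block covers one whole "token" (the last dim),
# so the 3 x 3 x 1 scales tensor supplies exactly one scale per block
x_dq = dequantize_affine(xq, (1, 1, 3), scales, None, torch.int8, output_dtype=x.dtype)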

test/quantization/test_quant_primitives.py (4 additions, 2 deletions)

@@ -67,9 +67,11 @@ def test_choose_qparams_group_sym(self):
         mapping_type = MappingType.SYMMETRIC
         dtype = torch.int8
         block_size = (1, 2)
-        scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
+        eps = torch.finfo(torch.float32).eps
+        precision = torch.float32
+        scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=eps, scale_dtype=precision, zero_point_dtype=precision)
 
-        scale_ref, zp_ref = get_group_qparams_symmetric(input, n_bit=8, groupsize=2)
+        scale_ref, zp_ref = get_group_qparams_symmetric(input, n_bit=8, groupsize=2, precision=precision)
 
         self.assertTrue(torch.equal(scale, scale_ref))
         self.assertTrue(torch.equal(zero_point, zp_ref))
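The point of the extra kwargs is dtype alignment: choose_qparams_affine now defaults scale_dtype and zero_point_dtype to input.dtype (see the quant_primitives.py hunk below), so the test pins both code paths to the same precision and the torch.equal checks can compare the two implementations bit for bit.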

torchao/quantization/quant_primitives.py (63 additions, 102 deletions)

@@ -40,6 +40,9 @@
     "groupwise_affine_dequantize_tensor_from_qparams",
     "groupwise_affine_quantize_tensor",
     "groupwise_affine_dequantize_tensor",
+    "choose_qparams_affine",
+    "quantize_affine",
+    "dequantize_affine",
     # TODO: need to clean up above functions
 ] + (_AFTER_TORCH_2_3_ONLY if TORCH_VERSION_AFTER_2_3 else [])
 
@@ -219,10 +222,13 @@ def dequantize_affine(
     if zero_point is not None:
         zero_point = zero_point.view(shape_after_reduction)
 
-    dequant = input.to(torch.float32)
-    scale = scale.to(torch.float32)
+    dequant = input.to(output_dtype)
+    # print("dq_affine: dq size:", dequant.shape)
+    # print("dq_affine: scale size:", scale.shape)
+    # dequant = input.to(output_dtype)
+    # scale = scale.to(output_dtype)
     if zero_point is not None:
-        zero_point = zero_point.to(torch.float32)
+        # zero_point = zero_point.to(output_dtype)
         dequant -= zero_point
     dequant *= scale
     dequant = dequant.view(original_shape)
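Note that the dequantized intermediate is now created directly in output_dtype rather than being forced through float32; this matches the test change above, where dequantize_per_channel is called with an explicit out_dtype=float_dtype and the reference output is cast to the same dtype before comparison.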
@@ -260,9 +266,9 @@ def choose_qparams_affine(
     """
     quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
     if scale_dtype is None:
-        scale_dtype = torch.float32
+        scale_dtype = input.dtype
     if zero_point_dtype is None:
-        zero_point_dtype = torch.float32
+        zero_point_dtype = input.dtype
 
     assert len(block_size) == input.dim()
     shape_for_reduction, reduction_dims = _get_reduction_params(block_size, input.size())
@@ -301,47 +307,18 @@ def dynamically_quantize_per_tensor(
     target_dtype,
     qscheme=torch.per_tensor_affine, # for now, reuse existing qscheme enum
 ):
-    # assumes affine quantization
-
-    # default setup for affine quantization of activations
     eps = torch.finfo(torch.float32).eps
-
-    if qscheme == torch.per_tensor_affine:
-        # get min and max
-        # TODO(future): make torch.aminmax work on cpu-half
-        # min_val, max_val = torch.aminmax(x)
-        min_val = torch.min(x)
-        max_val = torch.max(x)
-
-        # calculate scale and zero point based on min and max
-        # reference: https://fburl.com/code/srbiybme
-        min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
-        max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
-
-        scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
-        # TODO(future): make torch.clamp with scalar work on cpu-half
-        scale = torch.clamp(scale, min=eps).reshape(1)
-        zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int)
-        zero_point = torch.clamp(zero_point, quant_min, quant_max)
-
-        # quantize based on qmin/qmax/scale/zp
-        # reference: torch/ao/quantization/fx/_decomposed.py?lines=63
-        quant = torch.clamp(
-            torch.round(x / scale) + zero_point, quant_min, quant_max
-        ).to(target_dtype)
-
-    else:
-        assert qscheme == torch.per_tensor_symmetric, f"unsupported qscheme {qscheme}"
-        # assert quant_min == -1 * quant_max, "unsupported quant_min/quant_max"
-        amax = torch.max(torch.abs(x))
-        scale = amax / (float(quant_max - quant_min) / 2)
-        scale = torch.clamp(scale, min=eps).reshape(1)
-        quant = torch.clamp(torch.round(x / scale), quant_min, quant_max).to(
-            target_dtype
-        )
-        # do not create a tensor for zero_point as this is expensive
-        zero_point = None
-
+    block_size = x.shape
+    zero_point_dtype = torch.int32
+
+    qscheme_to_mapping_type = {
+        torch.per_tensor_affine: MappingType.ASYMMETRIC,
+        torch.per_tensor_symmetric: MappingType.SYMMETRIC,
+    }
+    assert qscheme in qscheme_to_mapping_type, f"unsupported qscheme {qscheme}"
+    mapping_type = qscheme_to_mapping_type[qscheme]
+    scale, zero_point = choose_qparams_affine(x, mapping_type, block_size, target_dtype=target_dtype, quant_min=quant_min, quant_max=quant_max, eps=eps, zero_point_dtype=zero_point_dtype)
+    quant = quantize_affine(x, block_size, scale, zero_point, target_dtype, quant_min, quant_max)
     return quant, scale, zero_point
 
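After this rewrite, dynamically_quantize_per_tensor is a thin wrapper: qscheme picks the MappingType, and block_size = x.shape makes the whole tensor a single block, so choose_qparams_affine produces one scale (and zero_point) for either qscheme. A hedged usage sketch, assuming the function keeps the signature shown in the hunk header:

import torch
from torchao.quantization.quant_primitives import dynamically_quantize_per_tensor

x = torch.randn(16, 32)

# per-tensor affine (the default qscheme): ASYMMETRIC mapping over one whole-tensor block
quant, scale, zero_point = dynamically_quantize_per_tensor(x, -128, 127, torch.int8)

# per-tensor symmetric: SYMMETRIC mapping; both cases now route through choose_qparams_affine
quant_s, scale_s, zp_s = dynamically_quantize_per_tensor(
    x, -128, 127, torch.int8, qscheme=torch.per_tensor_symmetric
)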

@@ -374,57 +351,46 @@ def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
     # assumes dense memory format
     # TODO(future): relax ^ as needed
 
-    # default setup for affine quantization of activations
-    eps = torch.finfo(torch.float32).eps
+    assert x.dim() == 2, "only support 2d Tensors"
 
-    # get min and max
-    min_val, max_val = torch.aminmax(x, dim=1)
-
-    # calculate scale and zero point based on min and max
-    # reference: https://fburl.com/code/srbiybme
-    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
-    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
-    device = min_val_neg.device
-
-    # reference: https://fburl.com/code/4wll53rk
-    max_val_pos = torch.max(-min_val_neg, max_val_pos)
-    scale = max_val_pos / (float(quant_max - quant_min) / 2)
-    # ensure scale is the same dtype as the original tensor
-    scale = torch.clamp(scale, min=eps).to(x.dtype)
-    zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
-
-    # quantize based on qmin/qmax/scale/zp
-    # reference: torch/ao/quantization/fx/_decomposed.py?lines=63
-    x_div = x.transpose(0, 1) / scale
-    x_round = torch.round(x_div)
-    x_zp = x_round + zero_point
-    x_zp = x_zp.transpose(0, 1)
-    quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype)
+    eps = torch.finfo(torch.float32).eps
+    block_size = (1, x.shape[1])
+    zero_point_dtype = torch.int64
 
+    mapping_type = MappingType.SYMMETRIC
+    scale, zero_point = choose_qparams_affine(x, mapping_type, block_size, target_dtype=target_dtype, quant_min=quant_min, quant_max=quant_max, eps=eps, zero_point_dtype=zero_point_dtype)
+    quant = quantize_affine(x, block_size, scale, zero_point, target_dtype, quant_min, quant_max)
     return quant, scale, zero_point
 
 
 # reference: https://fburl.com/code/vfsygwd0
 
 
 def dequantize_per_tensor(int_repr, scale, zero_point, out_dtype=torch.float32):
-    y = int_repr.to(out_dtype)
-    if zero_point is not None:
-        y -= zero_point
-    return y * scale
+    block_size = int_repr.shape
+    input_dtype = int_repr.dtype
+    assert scale.numel() == 1, f"scale size: {scale.numel()}"
+    dequantized = dequantize_affine(int_repr, block_size, scale, zero_point, input_dtype, output_dtype=out_dtype)
+    return dequantized
 
 
 # reference: https://fburl.com/code/org0fmi3
 
 
 def dequantize_per_channel(int_repr, scales, zero_points, out_dtype=torch.float32):
-    # assumes axis is 0
-    y = int_repr.transpose(0, 1)
-    y = y.to(out_dtype)
-    y = y - zero_points
-    y = y * scales
-    y = y.transpose(0, 1)
-    return y
+    assert int_repr.dim() == 2, "only support 2d Tensors"
+    # channel axis == 0
+    # block_size before transpose should be (1, int_repr.shape[1]) for axis == 0 per channel quant
+    # print("dq per chan: input repr shape:", int_repr.shape)
+    # print("dq per chan: scales shape:", scales.shape)
+
+    int_repr = int_repr.t()
+    # transpose for block_size as well
+    block_size = (int_repr.shape[0], 1)
+    input_dtype = int_repr.dtype
+    dequantized = dequantize_affine(int_repr, block_size, scales, zero_points, input_dtype, output_dtype=out_dtype)
+    dequantized = dequantized.t()
+    return dequantized
 
 
 def quant_int8_dynamic_linear(
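dequantize_per_channel now routes through dequantize_affine on the transposed tensor, with block_size = (int_repr.shape[0], 1) after the transpose so that each original row (channel) forms one block. A practical consequence, which the subclass.py change at the end of this commit reacts to, is that zero_points is expected to be a tensor dequantize_affine can view and broadcast, not the scalar 0. A rough sketch under those assumptions (channel count and shapes are illustrative):

import torch
from torchao.quantization.quant_primitives import dequantize_per_channel

# 4 channels (rows), 8 elements each; one scale per channel
int_repr = torch.randint(-128, 128, (4, 8), dtype=torch.int8)
scales = torch.rand(4) * 0.1 + 0.01
zero_points = torch.zeros_like(scales)  # a tensor, not the scalar 0, per the new code path

x_dq = dequantize_per_channel(int_repr, scales, zero_points, out_dtype=torch.float32)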
@@ -595,7 +561,7 @@ def quant_int8_per_token_matmul(
 
 
 def get_groupwise_affine_qparams(w, n_bit=4, groupsize=128):
-    """ """
+    """This is tinygemm specific, we'll keep this for now"""
     if groupsize > w.shape[-1]:
         groupsize = w.shape[-1]
     assert groupsize > 1
@@ -644,6 +610,7 @@ def groupwise_affine_quantize_tensor_from_qparams(
     n_bit=4,
     groupsize=128,
 ):
+    """This is tinygemm specific, we'll keep this for now"""
     assert groupsize > 1
     # needed for GPTQ single column quantize
     if groupsize > w.shape[-1] and scales.shape[-1] == 1:
@@ -679,6 +646,7 @@ def groupwise_affine_dequantize_tensor_from_qparams(
     n_bit=4,
     groupsize=128,
 ):
+    """This is tinygemm specific, we'll keep this for now"""
     assert groupsize > 1
     # needed for GPTQ single column dequantize
     if groupsize > w_int4x8.shape[-1] and scales.shape[-1] == 1:
@@ -728,26 +696,19 @@ def get_group_qparams_symmetric(w, n_bit=4, groupsize=128, precision=torch.float
     assert groupsize > 1
     assert w.shape[-1] % groupsize == 0
     assert w.dim() == 2
+    assert n_bit <= 8, f"unsupported n_bit: {n_bit}"
 
-    to_quant = w.reshape(-1, groupsize)
-    assert torch.isnan(to_quant).sum() == 0
-
-    max_val = to_quant.amax(dim=1, keepdim=True)
-    min_val = to_quant.amin(dim=1, keepdim=True)
-    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
-    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
-
-    max_val_abs = torch.max(-min_val_neg, max_val_pos)
-    max_int = 2 ** (n_bit - 1) - 1
-    min_int = -(2 ** (n_bit - 1))
-
-    scales = max_val_abs / (float(max_int - min_int) / 2)
-    scales = torch.max(scales, torch.full_like(scales, torch.finfo(torch.float32).eps))
-    # TODO: make sure abs(scales) is not too small?
-    zeros = torch.full_like(scales, 0)
-    return scales.to(precision).reshape(w.shape[0], -1), zeros.to(precision).reshape(
-        w.shape[0], -1
-    )
+    mapping_type = MappingType.SYMMETRIC
+    block_size = (1, groupsize)
+    eps = torch.finfo(torch.float32).eps
+    ranges = {}
+    ranges[1] = (-1, 0)
+    # generating ranges for bit 2 to 8
+    for i in range(2, 9):
+        ranges[i] = (-(2 ** (i - 1)), 2 ** (i - 1) - 1)
+    quant_min, quant_max = ranges[n_bit]
+    scale, zero_point = choose_qparams_affine(w, mapping_type, block_size, target_dtype=torch.int8, quant_min=quant_min, quant_max=quant_max, eps=eps, scale_dtype=precision, zero_point_dtype=precision)
+    return scale.reshape(w.shape[0], -1), zero_point.reshape(w.shape[0], -1)
 
 
 if TORCH_VERSION_AFTER_2_3:
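The ranges dict above simply records the signed integer range for each bit width, with n_bit = 1 special-cased to (-1, 0). Spelled out with the same arithmetic as the loop in the hunk:

# n_bit -> (quant_min, quant_max), matching ranges[i] = (-(2 ** (i - 1)), 2 ** (i - 1) - 1)
# n_bit = 2 -> (-2, 1)
# n_bit = 4 -> (-8, 7)
# n_bit = 8 -> (-128, 127)
for n_bit in range(2, 9):
    print(n_bit, (-(2 ** (n_bit - 1)), 2 ** (n_bit - 1) - 1))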
@@ -796,7 +757,7 @@ def pack_int4_from_int8(int8_data: torch.Tensor) -> torch.Tensor:
 
 @impl(quantized_decomposed_lib, "unpack_int4_to_int8", "CompositeExplicitAutograd")
 def unpack_int4_to_int8(int8_data: torch.Tensor) -> torch.Tensor:
-    """Get the original weight from the normalized float weight format"""
+    """ Get the original weight from the normalized float weight format"""
     # since we are using int8 we will decode 2 entries per byte
     # Shift elements down 4 and select out the bottom 4 bits
     shape = int8_data.shape

torchao/quantization/subclass.py (5 additions, 1 deletion)

@@ -218,8 +218,11 @@ def dequantize(self, dtype=None):
         """
         Obtain the dequantized version of the quantized tensor subclass
         """
+        zero_points = torch.zeros(self.q_scales.shape, device=self.q_scales.device, dtype=self.q_scales.dtype)
+        # zero_points = 0
+        # TODO: fix dtype here? `to(self.dtype)` is not overwritten by `dtype` arg?
         dq_t = dequantize_per_channel(
-            self.int_data.t(), self.q_scales, 0, self.dtype if dtype is None else dtype
+            self.int_data.t(), self.q_scales, zero_points, self.dtype if dtype is None else dtype
         ).to(self.dtype)
         # data was transposed to dequantize so make sure shape is correct
         return dq_t if not self.transposed else dq_t.t()
@@ -292,6 +295,7 @@ def from_float(cls, input_float, qmin=-128, qmax=127):
         Int8DynamicallyQuantizedLinearWeight.from_float(model.lin_mod.weight)
         )
         """
+        # because we call transpose in dequantization
         w_int_repr, w_scales, _ = dynamically_quantize_per_channel(
             input_float, qmin, qmax, torch.int8
         )
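The zero_points change here follows from the new dequantize_per_channel above: zero_points is now passed straight into dequantize_affine, which calls .view(...) on it, so the previous scalar 0 is replaced with an explicit zeros tensor shaped like q_scales. The TODO in the hunk flags a pre-existing question about the trailing .to(self.dtype) overriding the dtype argument; this commit leaves that behavior unchanged.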
