     "groupwise_affine_dequantize_tensor_from_qparams",
     "groupwise_affine_quantize_tensor",
     "groupwise_affine_dequantize_tensor",
+    "choose_qparams_affine",
+    "quantize_affine",
+    "dequantize_affine",
     # TODO: need to clean up above functions
 ] + (_AFTER_TORCH_2_3_ONLY if TORCH_VERSION_AFTER_2_3 else [])
 
@@ -219,10 +222,8 @@ def dequantize_affine(
     if zero_point is not None:
         zero_point = zero_point.view(shape_after_reduction)
 
-    dequant = input.to(torch.float32)
-    scale = scale.to(torch.float32)
+    dequant = input.to(output_dtype)
     if zero_point is not None:
-        zero_point = zero_point.to(torch.float32)
         dequant -= zero_point
     dequant *= scale
     dequant = dequant.view(original_shape)
@@ -260,9 +266,9 @@ def choose_qparams_affine(
     """
     quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
     if scale_dtype is None:
-        scale_dtype = torch.float32
+        scale_dtype = input.dtype
     if zero_point_dtype is None:
-        zero_point_dtype = torch.float32
+        zero_point_dtype = input.dtype
 
     assert len(block_size) == input.dim()
     shape_for_reduction, reduction_dims = _get_reduction_params(block_size, input.size())
@@ -301,47 +307,20 @@ def dynamically_quantize_per_tensor(
     target_dtype,
     qscheme=torch.per_tensor_affine,  # for now, reuse existing qscheme enum
 ):
-    # assumes affine quantization
-
-    # default setup for affine quantization of activations
     eps = torch.finfo(torch.float32).eps
-
-    if qscheme == torch.per_tensor_affine:
-        # get min and max
-        # TODO(future): make torch.aminmax work on cpu-half
-        # min_val, max_val = torch.aminmax(x)
-        min_val = torch.min(x)
-        max_val = torch.max(x)
-
-        # calculate scale and zero point based on min and max
-        # reference: https://fburl.com/code/srbiybme
-        min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
-        max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
-
-        scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
-        # TODO(future): make torch.clamp with scalar work on cpu-half
-        scale = torch.clamp(scale, min=eps).reshape(1)
-        zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int)
-        zero_point = torch.clamp(zero_point, quant_min, quant_max)
-
-        # quantize based on qmin/qmax/scale/zp
-        # reference: torch/ao/quantization/fx/_decomposed.py?lines=63
-        quant = torch.clamp(
-            torch.round(x / scale) + zero_point, quant_min, quant_max
-        ).to(target_dtype)
-
-    else:
-        assert qscheme == torch.per_tensor_symmetric, f"unsupported qscheme {qscheme}"
-        # assert quant_min == -1 * quant_max, "unsupported quant_min/quant_max"
-        amax = torch.max(torch.abs(x))
-        scale = amax / (float(quant_max - quant_min) / 2)
-        scale = torch.clamp(scale, min=eps).reshape(1)
-        quant = torch.clamp(torch.round(x / scale), quant_min, quant_max).to(
-            target_dtype
-        )
-        # do not create a tensor for zero_point as this is expensive
-        zero_point = None
-
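+    # a single block spanning the whole tensor lets the shared affine primitives
+    # reproduce the existing per-tensor behavior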
+    block_size = x.shape
+    zero_point_dtype = torch.int32
+
+    qscheme_to_mapping_type = {
+        torch.per_tensor_affine: MappingType.ASYMMETRIC,
+        torch.per_tensor_symmetric: MappingType.SYMMETRIC,
+    }
+    assert qscheme in qscheme_to_mapping_type, f"unsupported qscheme {qscheme}"
+    mapping_type = qscheme_to_mapping_type[qscheme]
+    scale, zero_point = choose_qparams_affine(x, mapping_type, block_size, target_dtype=target_dtype, quant_min=quant_min, quant_max=quant_max, eps=eps, zero_point_dtype=zero_point_dtype)
+    quant = quantize_affine(x, block_size, scale, zero_point, target_dtype, quant_min, quant_max)
     return quant, scale, zero_point
 
 
@@ -374,57 +351,46 @@ def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
     # assumes dense memory format
     # TODO(future): relax ^ as needed
 
-    # default setup for affine quantization of activations
-    eps = torch.finfo(torch.float32).eps
+    assert x.dim() == 2, "only support 2d Tensors"
 
-    # get min and max
-    min_val, max_val = torch.aminmax(x, dim=1)
-
-    # calculate scale and zero point based on min and max
-    # reference: https://fburl.com/code/srbiybme
-    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
-    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
-    device = min_val_neg.device
-
-    # reference: https://fburl.com/code/4wll53rk
-    max_val_pos = torch.max(-min_val_neg, max_val_pos)
-    scale = max_val_pos / (float(quant_max - quant_min) / 2)
-    # ensure scale is the same dtype as the original tensor
-    scale = torch.clamp(scale, min=eps).to(x.dtype)
-    zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
-
-    # quantize based on qmin/qmax/scale/zp
-    # reference: torch/ao/quantization/fx/_decomposed.py?lines=63
-    x_div = x.transpose(0, 1) / scale
-    x_round = torch.round(x_div)
-    x_zp = x_round + zero_point
-    x_zp = x_zp.transpose(0, 1)
-    quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype)
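+    # per-channel (axis=0) quantization is expressed as one block per row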
+    eps = torch.finfo(torch.float32).eps
+    block_size = (1, x.shape[1])
+    zero_point_dtype = torch.int64
 
+    mapping_type = MappingType.SYMMETRIC
+    scale, zero_point = choose_qparams_affine(x, mapping_type, block_size, target_dtype=target_dtype, quant_min=quant_min, quant_max=quant_max, eps=eps, zero_point_dtype=zero_point_dtype)
+    quant = quantize_affine(x, block_size, scale, zero_point, target_dtype, quant_min, quant_max)
     return quant, scale, zero_point
 
 
 # reference: https://fburl.com/code/vfsygwd0
 
 
 def dequantize_per_tensor(int_repr, scale, zero_point, out_dtype=torch.float32):
-    y = int_repr.to(out_dtype)
-    if zero_point is not None:
-        y -= zero_point
-    return y * scale
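+    # a single block covering the whole tensor gives per-tensor dequantization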
+    block_size = int_repr.shape
+    input_dtype = int_repr.dtype
+    assert scale.numel() == 1, f"scale size: {scale.numel()}"
+    dequantized = dequantize_affine(int_repr, block_size, scale, zero_point, input_dtype, output_dtype=out_dtype)
+    return dequantized
 
 
 # reference: https://fburl.com/code/org0fmi3
 
 
 def dequantize_per_channel(int_repr, scales, zero_points, out_dtype=torch.float32):
-    # assumes axis is 0
-    y = int_repr.transpose(0, 1)
-    y = y.to(out_dtype)
-    y = y - zero_points
-    y = y * scales
-    y = y.transpose(0, 1)
-    return y
+    assert int_repr.dim() == 2, "only support 2d Tensors"
+    # channel axis == 0
+    # before the transpose, block_size would be (1, int_repr.shape[1]) for axis == 0 per-channel quant
+
+    int_repr = int_repr.t()
+    # block_size is transposed along with the tensor
+    block_size = (int_repr.shape[0], 1)
+    input_dtype = int_repr.dtype
+    dequantized = dequantize_affine(int_repr, block_size, scales, zero_points, input_dtype, output_dtype=out_dtype)
+    dequantized = dequantized.t()
+    return dequantized
 
 
 def quant_int8_dynamic_linear(
@@ -595,7 +561,7 @@ def quant_int8_per_token_matmul(
 
 
 def get_groupwise_affine_qparams(w, n_bit=4, groupsize=128):
-    """ """
+    """This is tinygemm-specific; we'll keep it for now."""
     if groupsize > w.shape[-1]:
         groupsize = w.shape[-1]
     assert groupsize > 1
@@ -644,6 +610,7 @@ def groupwise_affine_quantize_tensor_from_qparams(
     n_bit=4,
     groupsize=128,
 ):
+    """This is tinygemm-specific; we'll keep it for now."""
     assert groupsize > 1
     # needed for GPTQ single column quantize
     if groupsize > w.shape[-1] and scales.shape[-1] == 1:
@@ -679,6 +646,7 @@ def groupwise_affine_dequantize_tensor_from_qparams(
     n_bit=4,
     groupsize=128,
 ):
+    """This is tinygemm-specific; we'll keep it for now."""
     assert groupsize > 1
     # needed for GPTQ single column dequantize
    if groupsize > w_int4x8.shape[-1] and scales.shape[-1] == 1:
@@ -728,26 +696,20 @@ def get_group_qparams_symmetric(w, n_bit=4, groupsize=128, precision=torch.float
     assert groupsize > 1
     assert w.shape[-1] % groupsize == 0
     assert w.dim() == 2
+    assert n_bit <= 8, f"unsupported n_bit: {n_bit}"
 
-    to_quant = w.reshape(-1, groupsize)
-    assert torch.isnan(to_quant).sum() == 0
-
-    max_val = to_quant.amax(dim=1, keepdim=True)
-    min_val = to_quant.amin(dim=1, keepdim=True)
-    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
-    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
-
-    max_val_abs = torch.max(-min_val_neg, max_val_pos)
-    max_int = 2 ** (n_bit - 1) - 1
-    min_int = -(2 ** (n_bit - 1))
-
-    scales = max_val_abs / (float(max_int - min_int) / 2)
-    scales = torch.max(scales, torch.full_like(scales, torch.finfo(torch.float32).eps))
-    # TODO: make sure abs(scales) is not too small?
-    zeros = torch.full_like(scales, 0)
-    return scales.to(precision).reshape(w.shape[0], -1), zeros.to(precision).reshape(
-        w.shape[0], -1
-    )
+    mapping_type = MappingType.SYMMETRIC
+    block_size = (1, groupsize)
+    eps = torch.finfo(torch.float32).eps
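+    # signed n-bit ranges, e.g. n_bit=4 -> (quant_min, quant_max) = (-8, 7)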
+    ranges = {}
+    ranges[1] = (-1, 0)
+    # generating ranges for bit 2 to 8
+    for i in range(2, 9):
+        ranges[i] = (-(2 ** (i - 1)), 2 ** (i - 1) - 1)
+    quant_min, quant_max = ranges[n_bit]
+    scale, zero_point = choose_qparams_affine(w, mapping_type, block_size, target_dtype=torch.int8, quant_min=quant_min, quant_max=quant_max, eps=eps, scale_dtype=precision, zero_point_dtype=precision)
+    return scale.reshape(w.shape[0], -1), zero_point.reshape(w.shape[0], -1)
 
 
 if TORCH_VERSION_AFTER_2_3:
@@ -796,7 +757,7 @@ def pack_int4_from_int8(int8_data: torch.Tensor) -> torch.Tensor:
 
 @impl(quantized_decomposed_lib, "unpack_int4_to_int8", "CompositeExplicitAutograd")
 def unpack_int4_to_int8(int8_data: torch.Tensor) -> torch.Tensor:
-    """Get the original weight from the normalized float weight format"""
+    """ Get the original weight from the normalized float weight format"""
     # since we are using int8 we will decode 2 entries per byte
     # Shift elements down 4 and select out the bottom 4 bits
     shape = int8_data.shape
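
A minimal usage sketch of how the three newly exported primitives (`choose_qparams_affine`, `quantize_affine`, `dequantize_affine`) compose, following the same pattern as the refactored per-channel helpers above. The argument order mirrors the calls in this diff; the import path and the concrete `quant_min`/`quant_max` values are assumptions for illustration, not taken from this PR.

```python
import torch
# assumed import path; the primitives live in this quant_primitives module
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    quantize_affine,
    dequantize_affine,
)

x = torch.randn(32, 64)
block_size = (1, x.shape[1])  # one block per row -> per-channel (axis=0) quantization
scale, zero_point = choose_qparams_affine(
    x,
    MappingType.SYMMETRIC,
    block_size,
    target_dtype=torch.int8,
    quant_min=-128,  # illustrative int8 range
    quant_max=127,
    eps=torch.finfo(torch.float32).eps,
    zero_point_dtype=torch.int64,
)
xq = quantize_affine(x, block_size, scale, zero_point, torch.int8, -128, 127)
xdq = dequantize_affine(xq, block_size, scale, zero_point, torch.int8, output_dtype=torch.float32)
```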