@@ -4,7 +4,7 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

-from typing import Any, Optional, Tuple
+from typing import Any, List, Optional, Tuple

import torch
import torch.nn.functional as F
@@ -25,7 +25,10 @@
    ZeroPointDomain,
)
from torchao.quantization.unified import TwoStepQuantizer
-from torchao.quantization.utils import get_group_qparams_symmetric
+from torchao.quantization.utils import (
+    _get_per_token_block_size,
+    get_group_qparams_symmetric,
+)


# =================
@@ -346,8 +349,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
        scales, zero_points = get_groupwise_affine_qparams(
            self.weight, n_bit, self.groupsize, self.scales_precision,
        )
-        w_fq = _Int4WeightOnlyFakeQuantize.apply(
-            self.weight, scales, zero_points, qmin, qmax, self.groupsize,
+        w_fq = fake_quantize_per_channel_group(
+            self.weight,
+            scales,
+            zero_points,
+            qmin,
+            qmax,
+            self.groupsize,
+            ZeroPointDomain.FLOAT,
        )
        return F.linear(x, w_fq)

@@ -370,39 +379,6 @@ def disable_4w_fake_quant(mod: torch.nn.Module):
# | QUANT PRIMITIVES |
# ========================

-class _Int4WeightOnlyFakeQuantize(torch.autograd.Function):
-    """
-    Implementation of int4 grouped per channel weight-only fake quantize
-    intended to match the numerics of the efficient int4 tinygemm kernel.
-    """
-
-    @staticmethod
-    def forward(ctx, input, scales, zero_points, quant_min, quant_max, groupsize):
-        assert groupsize > 1
-        assert input.shape[-1] % groupsize == 0
-        assert input.dim() == 2
-        n_bit = 4
-        block_size = (1, groupsize)
-        quant_min = 0
-        quant_max = 2 ** n_bit - 1
-        (fq, mask) = fake_quantize_affine_cachemask(
-            input,
-            block_size,
-            scales,
-            zero_points,
-            torch.int32,
-            quant_min,
-            quant_max,
-            zero_point_domain=ZeroPointDomain.FLOAT,
-        )
-        ctx.save_for_backward(mask)
-        return fq
-
-    @staticmethod
-    def backward(ctx, gy):
-        (mask,) = ctx.saved_tensors
-        return gy * mask, None, None, None, None, None
-
class _GenericFakeQuantize(torch.autograd.Function):
    """
    Implementation of generic fake quantize with backward STE.
@@ -412,71 +388,73 @@ class _GenericFakeQuantize(torch.autograd.Function):
    """

    @staticmethod
-    def forward(ctx, input, scales, zero_points, quant_min, quant_max):
+    def forward(
+        ctx: torch.autograd.function.FunctionCtx,
+        input: torch.Tensor,
+        scales: torch.Tensor,
+        zero_points: torch.Tensor,
+        quant_min: int,
+        quant_max: int,
+        block_size: List[int],
+        zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
+    ) -> torch.Tensor:
        # Note: for bf16 inputs, casting them to fp32 has the unexpected
        # side effect of reducing memory footprint significantly, presumably
        # because bf16 * fp32 kernels are not as memory efficient
        assert input.dtype == torch.float32
        assert scales.dtype == torch.float32
        assert zero_points.dtype == torch.int32
-        q = input.mul(1.0 / scales).round().add(zero_points)
-        dq = q.clamp(quant_min, quant_max).sub(zero_points).mul(scales)
-        mask = torch.logical_and((q >= quant_min), (q <= quant_max))
+
+        (fq, mask) = fake_quantize_affine_cachemask(
+            input,
+            block_size,
+            scales,
+            zero_points,
+            torch.int32,
+            quant_min,
+            quant_max,
+            zero_point_domain,
+        )
+
        ctx.save_for_backward(mask)
-        return dq
+        return fq

    @staticmethod
    def backward(ctx, gy):
        (mask,) = ctx.saved_tensors
-        return gy * mask, None, None, None, None, None
-
-# TODO: move this to core
-quantized_decomposed_lib.define(
-    "fake_quantize_per_channel_group(Tensor input, Tensor scales, Tensor zero_points, "
-    "int quant_min, int quant_max, int group_size) -> Tensor"
-)
+        return gy * mask, None, None, None, None, None, None

-@impl(quantized_decomposed_lib, "fake_quantize_per_channel_group", "CompositeImplicitAutograd")
def fake_quantize_per_channel_group(
    input: torch.Tensor,
    scales: torch.Tensor,
    zero_points: torch.Tensor,
    quant_min: int,
    quant_max: int,
    group_size: int,
+    zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
) -> torch.Tensor:
    assert group_size > 1
    assert input.shape[-1] % group_size == 0
    assert input.dim() == 2
-    grouped_input = input.reshape(-1, group_size).to(torch.float32)
-    scales = scales.reshape(-1, 1)
-    zero_points = zero_points.reshape(-1, 1)
-    fq = _GenericFakeQuantize.apply(
-        grouped_input, scales, zero_points, quant_min, quant_max,
+    block_size = (1, group_size)
+    return _GenericFakeQuantize.apply(
+        input, scales, zero_points, quant_min, quant_max, block_size, zero_point_domain,
    )
-    return fq.reshape_as(input).to(input.dtype)
-
-# TODO: move this to core
-quantized_decomposed_lib.define(
-    "fake_quantize_per_token(Tensor input, Tensor scales, Tensor zero_points, "
-    "int quant_min, int quant_max) -> Tensor"
-)

-@impl(quantized_decomposed_lib, "fake_quantize_per_token", "CompositeImplicitAutograd")
def fake_quantize_per_token(
    input: torch.Tensor,
    scales: torch.Tensor,
    zero_points: torch.Tensor,
    quant_min: int,
    quant_max: int,
) -> torch.Tensor:
-    # TODO: we won't need this import anymore once we move this to core
    from torch.ao.quantization.fx._decomposed import _per_token_quant_qparam_dim_check

    _per_token_quant_qparam_dim_check(input, scales, zero_points)
+    block_size = _get_per_token_block_size(input)
    fq_input = input.to(torch.float32)
    fq = _GenericFakeQuantize.apply(
-        fq_input, scales, zero_points, quant_min, quant_max,
+        fq_input, scales, zero_points, quant_min, quant_max, block_size,
    )
    return fq.reshape_as(input).to(input.dtype)
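
For context, here is a minimal usage sketch of the updated fake_quantize_per_channel_group entry point, which after this change routes through fake_quantize_affine_cachemask via _GenericFakeQuantize. The import path torchao.quantization.prototype.qat, the tensor shapes, and the int4 bounds below are illustrative assumptions, not part of this commit.

import torch

# Assumed import paths for illustration; adjust to wherever the patched module lives.
from torchao.quantization.prototype.qat import fake_quantize_per_channel_group
from torchao.quantization.quant_primitives import ZeroPointDomain

group_size = 32
weight = torch.randn(64, 128, dtype=torch.float32)  # (out_features, in_features)

# One scale / zero point per group of 32 input channels, with the dtypes the
# asserts in _GenericFakeQuantize.forward expect: fp32 scales, int32 zero points.
num_groups = weight.shape[-1] // group_size
scales = weight.reshape(64, num_groups, group_size).abs().amax(dim=-1) / 7.0
zero_points = torch.zeros(64, num_groups, dtype=torch.int32)

# Illustrative int4 symmetric range; ZeroPointDomain.INT is the new default argument.
w_fq = fake_quantize_per_channel_group(
    weight, scales, zero_points, quant_min=-8, quant_max=7,
    group_size=group_size, zero_point_domain=ZeroPointDomain.INT,
)
assert w_fq.shape == weight.shape  # fake-quantized weight keeps the original shape

fake_quantize_per_token follows the same path, deriving its block_size from _get_per_token_block_size(input) instead of (1, group_size).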