@@ -177,6 +177,11 @@ def _apply_fn_to_data(self, fn):
             fn(self.zero_point),
         )

+    def _change_shape(self, shape):
+        return self.__class__(
+            self.int_data.view(shape), self.scale, self.zero_point
+        )
+
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs):
         kwargs = {} if kwargs is None else kwargs
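A plain-tensor sketch of what the new helper does (all names are local to this example, not the class API): only the int payload is reshaped, while `scale` and `zero_point` ride along unchanged, which is sound as long as the view does not regroup quantization blocks (for example, per-tensor quantization):

```python
import torch

# What _change_shape does to the payload, in plain-tensor terms: reshape the
# int values, keep the quantization parameters as-is. Safe for a per-tensor
# scale; a per-axis scale would need the view to keep that axis intact.
int_data = torch.randint(-128, 128, (4, 6), dtype=torch.int8)
scale = torch.tensor(0.05)    # per-tensor scale
zero_point = torch.tensor(0)
viewed = int_data.view(2, 12) # the reshape inside _change_shape
assert viewed.shape == (2, 12)
```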
@@ -186,6 +191,11 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
                 func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
             )

+        if func is aten.view.default:
+            assert len(args) == 2
+            new = args[0]._change_shape(args[1])
+            return return_and_correct_aliasing(func, args, kwargs, new)
+
         raise NotImplementedError(
             f"PlainAQTLayout dispatch: attempting to run {func}, this is not supported"
         )
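For context, a minimal self-contained sketch (a toy subclass, not the torchao class) of how a `.view()` call reaches `__torch_dispatch__` as `aten.view.default` and is re-dispatched onto the wrapped data; `return_and_correct_aliasing` is elided here to keep the toy runnable:

```python
import torch

aten = torch.ops.aten

class Wrapper(torch.Tensor):
    """Toy wrapper subclass mirroring the dispatch pattern in the diff."""

    @staticmethod
    def __new__(cls, data):
        return torch.Tensor._make_wrapper_subclass(cls, data.shape, dtype=data.dtype)

    def __init__(self, data):
        self.inner = data

    def _change_shape(self, shape):
        # rebuild the wrapper around a viewed payload, as PlainAQTLayout does
        return self.__class__(self.inner.view(shape))

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        kwargs = kwargs or {}
        if func is aten.view.default:
            return args[0]._change_shape(args[1])
        raise NotImplementedError(f"{func} is not supported")

t = Wrapper(torch.randn(4, 4))
print(t.view([2, 8]).shape)  # torch.Size([2, 8])
```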
@@ -245,6 +255,7 @@ def __tensor_unflatten__(
         cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
     ):
         packed_weight, scale_and_zero = tensor_data_dict["packed_weight"], tensor_data_dict["scale_and_zero"]
+        # TODO: fix the unflatten logic
         return cls(packed_weight, scale_and_zero)

     def to(self, *args, **kwargs):
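For context on the TODO: `__tensor_unflatten__` is the inverse of `__tensor_flatten__` in PyTorch's traceable tensor-subclass protocol, and a complete implementation should also honor `outer_size`/`outer_stride` when the wrapper is rebuilt during tracing, which the body above currently ignores. A minimal sketch of the pairing, with a hypothetical class name and the attribute names taken from the diff:

```python
class LayoutSketch:
    """Hypothetical stand-in for the layout tensor subclass in the diff."""

    def __init__(self, packed_weight, scale_and_zero):
        self.packed_weight = packed_weight
        self.scale_and_zero = scale_and_zero

    def __tensor_flatten__(self):
        # names of the inner tensor attributes, plus non-tensor metadata
        return ["packed_weight", "scale_and_zero"], []

    @classmethod
    def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride):
        # the diff's TODO: outer_size/outer_stride are dropped here, but a full
        # implementation should reconstruct the outer wrapper to match them
        return cls(tensor_data_dict["packed_weight"], tensor_data_dict["scale_and_zero"])
```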
@@ -470,6 +481,11 @@ def _apply_fn_to_data(self, fn):
             strides=self.stride(),
         )

+    def _change_shape(self, shape, block_size):
+        return self.__class__(
+            self.layout_tensor.view(shape), block_size, shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()
+        )
+
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs):
         # Note: we only added cpu path here for 8da4w, this is for executorch, in the future
@@ -491,13 +507,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
             f"AffineQuantizedTensor dispatch: attempting to run {func}, this is not supported"
         )

-@implements_aqt_torch_function(torch.nn.functional.linear)
-def functional_linear(*args, **kwargs):
-    input_tensor, weight_qtensor, bias = (
-        args[0],
-        args[1],
-        args[2] if len(args) > 2 else None,
-    )
+def _quantized_linear_op(input_tensor, weight_qtensor, bias):
     is_cuda = weight_qtensor.is_cuda
     is_cpu = weight_qtensor.device == torch.device("cpu")
     if isinstance(weight_qtensor, AffineQuantizedTensor):
@@ -508,14 +518,10 @@ def functional_linear(*args, **kwargs):
             # if input tensor is quantized, either dispatch to the int8 mm kernel
             # or just dequantize the input tensor
             input_is_int8 = _aqt_is_int8_reduced_range(input_tensor)
-            input_tensor_dtype_is_expected = input_tensor.dtype in [
-                torch.float,
-                torch.bfloat16
-            ]
             if (
                 is_cuda and
                 input_is_int8 and
-                input_tensor_dtype_is_expected and
+                input_tensor.dtype == weight_qtensor.dtype and
                 input_tensor.layout == "plain" and
                 weight_qtensor.layout == "plain"
             ):
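An illustration (with assumed dtypes) of what the stricter dtype check rules out: if the activation and the weight's original float dtype differ, the rescaling step silently promotes, so the fused int8 path would not round-trip the module's dtype.

```python
import torch

# Mixed dtypes promote: float32 activation scale times bfloat16 weight scale
# yields float32 output, not the bfloat16 the module started with.
x_int8 = torch.randint(-127, 128, (4, 8), dtype=torch.int8)
w_int8 = torch.randint(-127, 128, (16, 8), dtype=torch.int8)
x_scale = torch.rand(4, 1)                        # float32
w_scale = torch.rand(1, 16, dtype=torch.bfloat16)
y = torch.mm(x_int8.float(), w_int8.float().t()) * x_scale * w_scale
print(y.dtype)  # torch.float32
```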
@@ -576,45 +582,83 @@ def functional_linear(*args, **kwargs):
                 weight_qtensor.block_size[1] == weight_qtensor.shape[1] and
                 weight_qtensor.layout == "plain"
             ):
-                # TODO: enable mps path as well
+                # TODO: enable cpu and mps efficient path
                 # per channel int8 weight only quantized mm
-                return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), weight_qtensor.layout_tensor.int_data, weight_qtensor.layout_tensor.scale)
-            else:
-                weight_tensor = weight_qtensor.dequantize()
-                return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
-    else:
+                w_vals_int8_t = weight_qtensor.layout_tensor.int_data.t().contiguous()
+                orig_dtype = input_tensor.dtype
+                y = (
+                    torch.mm(
+                        input_tensor.reshape(-1, input_tensor.shape[-1]),
+                        w_vals_int8_t.to(input_tensor.dtype),
+                    )
+                    * weight_qtensor.scale
+                )
+                y = y.reshape(*input_tensor.shape[:-1], y.shape[-1])
+                if bias is not None:
+                    y += bias
+                return y.to(orig_dtype)
+
+                # cpu and mps only; currently blocked by an is_contiguous() issue
+                # return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), w_vals_int8_t, weight_qtensor.layout_tensor.scale)
+
+    raise NotImplementedError("No specialized dispatch found for quantized linear op")
+
+
+@implements_aqt_torch_function(torch.nn.functional.linear)
+def functional_linear(*args, **kwargs):
+    input_tensor, weight_tensor, bias = (
+        args[0],
+        args[1],
+        args[2] if len(args) > 2 else None,
+    )
+    # Use try/except so that there is a general fallback when input_tensor/weight_tensor
+    # is not picked up by any of the dispatch paths in `_quantized_linear_op`; this keeps
+    # the branches in `_quantized_linear_op` easier to understand.
+    try:
+        return _quantized_linear_op(input_tensor, weight_tensor, bias)
+    except:
         if isinstance(input_tensor, AffineQuantizedTensor):
             input_tensor = input_tensor.dequantize()
+        if isinstance(weight_tensor, AffineQuantizedTensor):
+            weight_tensor = weight_tensor.dequantize()
         return torch.nn.functional.linear(input_tensor, weight_tensor, bias)

-
 @implements_aqt_aten_ops([aten.mm.default, aten.addmm.default])
 def aten_mm(func, *args, **kwargs):
     if not args[0].is_floating_point():
         raise NotImplementedError(f"{func} is not implemented for non floating point input")

+    # Use try/except so that there is a general fallback when input_tensor/weight_tensor
+    # is not picked up by any of the dispatch paths in `_quantized_linear_op`; this keeps
+    # the branches in `_quantized_linear_op` easier to understand.
     if func == aten.addmm.default:
-        assert args[1].shape[-1] == args[2].shape[0], (
-            f"need mat1 shape: {args[1].shape} final "
-            f"dim to match mat2 shape: {args[2].shape} first dim"
-        )
-        input_tensor, weight_qtensor, bias = (
+        input_tensor, weight_tensor, bias = (
             args[1],
             args[2],
             args[0],
         )
+        try:
+            return _quantized_linear_op(input_tensor, weight_tensor, bias)
+        except:
+            if isinstance(input_tensor, AffineQuantizedTensor):
+                input_tensor = input_tensor.dequantize()
+            if isinstance(weight_tensor, AffineQuantizedTensor):
+                weight_tensor = weight_tensor.dequantize()
+            return func(bias, input_tensor, weight_tensor)
     else:
-        assert args[0].shape[-1] == args[1].shape[0], (
-            f"need mat1 shape: {args[0].shape} final dim "
-            f"to match mat2 shape: {args[1].shape} first dim"
-        )
-        input_tensor, weight_qtensor, bias = (
+        input_tensor, weight_tensor, bias = (
             args[0],
             args[1],
-            None if len(args) == 2 else args[2],
+            None
         )
-        weight_tensor = weight_qtensor.dequantize()
-        return func(input_tensor, weight_tensor, bias)
+        try:
+            return _quantized_linear_op(input_tensor, weight_tensor, bias)
+        except:
+            if isinstance(input_tensor, AffineQuantizedTensor):
+                input_tensor = input_tensor.dequantize()
+            if isinstance(weight_tensor, AffineQuantizedTensor):
+                weight_tensor = weight_tensor.dequantize()
+            return func(input_tensor, weight_tensor)

 @implements_aqt_aten_ops([aten.detach.default])
 def detach(func, *args, **kwargs):
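A numeric sketch of the new int8 weight-only fallback above (all names local to the example, quantization of the weight included for self-containment): the int8 weight is cast into the activation dtype, multiplied, then rescaled per output channel, i.e. `y = (x @ W_int8.T) * scale`.

```python
import torch

# Quantize a float weight per output channel, then reproduce the fallback path.
x = torch.randn(8, 16)
w = torch.randn(32, 16)
scale = w.abs().amax(dim=1, keepdim=True) / 127
w_int8 = torch.clamp(torch.round(w / scale), -127, 127).to(torch.int8)

w_vals_int8_t = w_int8.t().contiguous()             # (in, out), as in the diff
y = torch.mm(x.reshape(-1, x.shape[-1]), w_vals_int8_t.to(x.dtype)) * scale.t()
y = y.reshape(*x.shape[:-1], y.shape[-1])

y_ref = torch.nn.functional.linear(x, w)            # float reference
print((y - y_ref).abs().max())                      # small quantization error
```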
@@ -641,10 +685,10 @@ def _to_copy(func, *args, **kwargs):

 @implements_aqt_aten_ops([aten.t.default])
 def t(func, *args, **kwargs):
-    # TODO: need to implement this
-    # args[0].transposed = not args[0].transposed
-    # new = args[0]._change_shape(args[0].shape[::-1])
-    # return return_and_correct_aliasing(func, args, kwargs, new)
-    raise Exception("transpose not implemented yet")
+    block_size = args[0].block_size
+    assert len(block_size) == 2
+    transposed_block_size = (block_size[1], block_size[0])
+    new = args[0]._change_shape(args[0].shape[::-1], transposed_block_size)
+    return return_and_correct_aliasing(func, args, kwargs, new)

 to_aq = AffineQuantizedTensor.from_float
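A worked example of the block-size bookkeeping in the new `t()` handler (shapes assumed): a per-channel-quantized weight of shape `(out, in) = (32, 16)` carries `block_size = (1, 16)`; transposing the tensor to `(16, 32)` means each quantization group now spans a column, so the block size must flip to `(16, 1)`.

```python
shape = (32, 16)            # (out_features, in_features)
block_size = (1, 16)        # per-output-channel quantization groups

transposed_shape = shape[::-1]
transposed_block_size = (block_size[1], block_size[0])

assert transposed_shape == (16, 32)
assert transposed_block_size == (16, 1)
```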