    _ATEN_OP_OR_TORCH_FN_TABLE,
    _register_layout_cls,
    _get_layout_tensor_constructor,
+    LayoutType,
)
+from typing import ClassVar
+from dataclasses import dataclass

aten = torch.ops.aten

+@dataclass(frozen=True)
+class PlainLayoutType(LayoutType):
+    pass
+
+@dataclass(frozen=True)
+class TensorCoreTiledLayoutType(LayoutType):
+    inner_k_tiles: int = 8
+
+    def pad_input(self, input: torch.Tensor) -> torch.Tensor:
+        orig_out_features, orig_in_features = input.shape
+        in_features = find_multiple(orig_in_features, 1024)
+        out_features = find_multiple(orig_out_features, 8)
+        input = torch.nn.functional.pad(
+            input,
+            (0, in_features - orig_in_features, 0, out_features - orig_out_features),
+        )
+        return input
+
+
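For intuition, here is a minimal, self-contained sketch of what the new `pad_input` hook does for the tensor-core-tiled layout. The local `find_multiple` below is an assumed stand-in for torchao's utility of the same name, so treat this as illustrative rather than the library's exact code: a [257, 1000] weight gets padded up to [264, 1024] so it satisfies the int4 packing shape requirements.

import torch

def find_multiple(n: int, k: int) -> int:
    # assumed helper: round n up to the nearest multiple of k
    return n if n % k == 0 else n + k - (n % k)

w = torch.randn(257, 1000)
in_features = find_multiple(w.shape[1], 1024)   # 1024
out_features = find_multiple(w.shape[0], 8)     # 264
# pad columns first (last dim), then rows, mirroring pad_input above
padded = torch.nn.functional.pad(
    w, (0, in_features - w.shape[1], 0, out_features - w.shape[0])
)
print(padded.shape)  # torch.Size([264, 1024])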
def _aqt_is_int8(aqt):
    """Check if an AffineQuantizedTensor is int8 quantized Tensor"""
    return (
@@ -52,9 +74,6 @@ class AQTLayout(torch.Tensor):
    """
    Base class for the layout tensor for `AffineQuantizedTensor`
    """
-    # this should be set for each layout class during registration
-    extended_layout: Optional[str] = None
-
    def get_plain() -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        pass

@@ -64,6 +83,7 @@ def from_plain(
        int_data: torch.Tensor,
        scale: torch.Tensor,
        zero_point: torch.Tensor,
+        layout_type: LayoutType,
    ):
        pass

@@ -194,30 +214,16 @@ def from_float(
        zero_point_dtype: Optional[torch.dtype] = None,
        preserve_zero: bool = True,
        zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
-        extended_layout: str = "plain",
-        # TODO: this is only for "tensor_core_tiled", need to figure out
-        # the proper API for this arg
-        inner_k_tiles: Optional[int] = None,
+        layout_type: LayoutType = PlainLayoutType(),
    ):
        original_shape = input_float.shape
-        if extended_layout == "tensor_core_tiled":
-            orig_out_features, orig_in_features = input_float.shape
-            in_features = find_multiple(orig_in_features, 1024)
-            out_features = find_multiple(orig_out_features, 8)
-            input_float = torch.nn.functional.pad(
-                input_float,
-                (0, in_features - orig_in_features, 0, out_features - orig_out_features),
-            )
+        input_float = layout_type.pad_input(input_float)

        scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
        int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)

-        layout_cls_ctr = get_layout_tensor_constructor(extended_layout)
-        # TODO: this is temporary, need to come up with the proper UX
-        if extended_layout == "tensor_core_tiled":
-            layout_tensor = layout_cls_ctr(int_data, scale, zero_point, inner_k_tiles)
-        else:
-            layout_tensor = layout_cls_ctr(int_data, scale, zero_point)
+        layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type))
+        layout_tensor = layout_tensor_ctr(int_data, scale, zero_point, layout_type)
        return cls(
            layout_tensor,
            block_size,
@@ -229,8 +235,8 @@ def from_float(
        )

    @property
-    def extended_layout(self) -> str:
-        return self.layout_tensor.extended_layout
+    def layout_type(self) -> str:
+        return self.layout_tensor.layout_type

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
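A hedged sketch of how a `from_float` call site changes with this refactor. The import paths and argument values below are assumptions (module layout differs across torchao versions), and the tinygemm int4 path needs additional kwargs and a CUDA build that are only hinted at in the comments, so this is illustrative, not a drop-in recipe.

import torch
# assumed import paths; they may differ across torchao versions
from torchao.dtypes import AffineQuantizedTensor
from torchao.dtypes.affine_quantized_tensor import PlainLayoutType, TensorCoreTiledLayoutType
from torchao.quantization.quant_primitives import MappingType

w = torch.randn(64, 128)

# per-channel int8 weight with the default plain layout
w_int8 = AffineQuantizedTensor.from_float(
    w, MappingType.SYMMETRIC, (1, w.shape[1]), torch.int8,
    layout_type=PlainLayoutType(),
)

# before this change, the tinygemm int4 path was selected like:
#   from_float(..., extended_layout="tensor_core_tiled", inner_k_tiles=8)
# now the layout-specific knob rides on the LayoutType instance:
#   from_float(..., layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8))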
@@ -308,13 +314,13 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
def implements(aten_ops_or_torch_fn):
    return _implements(AffineQuantizedTensor, aten_ops_or_torch_fn)

-def register_layout_cls(extended_layout: str):
-    return _register_layout_cls(AffineQuantizedTensor, extended_layout)
+def register_layout_cls(layout_type_class: type(LayoutType)):
+    return _register_layout_cls(AffineQuantizedTensor, layout_type_class)

-def get_layout_tensor_constructor(extended_layout: str):
-    return _get_layout_tensor_constructor(AffineQuantizedTensor, extended_layout)
+def get_layout_tensor_constructor(layout_type_class: type(LayoutType)):
+    return _get_layout_tensor_constructor(AffineQuantizedTensor, layout_type_class)

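A minimal, self-contained sketch of the registration pattern these two helpers wrap. The real `_register_layout_cls` / `_get_layout_tensor_constructor` live in torchao's dtype utilities and are not shown in this diff; the table, decorator, and toy classes below are assumed stand-ins meant only to show why lookup is keyed on the LayoutType class.

from dataclasses import dataclass
from typing import Callable, Dict, Type

@dataclass(frozen=True)
class LayoutType:  # toy stand-in for torchao's base class
    pass

@dataclass(frozen=True)
class PlainLayoutType(LayoutType):
    pass

_LAYOUT_CONSTRUCTOR_TABLE: Dict[Type[LayoutType], Callable] = {}

def register_layout_cls(layout_type_class: Type[LayoutType]):
    def decorator(layout_cls):
        # map the LayoutType *class* to the layout tensor's constructor
        _LAYOUT_CONSTRUCTOR_TABLE[layout_type_class] = layout_cls.from_plain
        return layout_cls
    return decorator

def get_layout_tensor_constructor(layout_type_class: Type[LayoutType]) -> Callable:
    return _LAYOUT_CONSTRUCTOR_TABLE[layout_type_class]

@register_layout_cls(PlainLayoutType)
class PlainAQTLayout:
    @classmethod
    def from_plain(cls, int_data, scale, zero_point, layout_type):
        return (int_data, scale, zero_point, layout_type)  # toy constructor

# dispatch is keyed on type(layout_type), matching from_float above
ctor = get_layout_tensor_constructor(type(PlainLayoutType()))
print(ctor(1, 2, 3, PlainLayoutType()))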
-@register_layout_cls("plain")
+@register_layout_cls(PlainLayoutType)
class PlainAQTLayout(AQTLayout):
    """
    Layout storage class for plain layout for affine quantized tensor, it stores int_data, scale, zero_point
@@ -330,6 +336,7 @@ def __new__(
        int_data: torch.Tensor,
        scale: torch.Tensor,
        zero_point: torch.Tensor,
+        layout_type: LayoutType,
    ):
        kwargs = {}
        kwargs["device"] = int_data.device
@@ -346,34 +353,39 @@ def __init__(
        int_data: torch.Tensor,
        scale: torch.Tensor,
        zero_point: torch.Tensor,
+        layout_type: LayoutType,
    ):
        self.int_data = int_data
        self.scale = scale
        self.zero_point = zero_point
+        self.layout_type = layout_type

    def __tensor_flatten__(self):
-        return ["int_data", "scale", "zero_point"], []
+        return ["int_data", "scale", "zero_point"], [self.layout_type]

    @classmethod
    def __tensor_unflatten__(
        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
    ):
        int_data, scale, zero_point = tensor_data_dict["int_data"], tensor_data_dict["scale"], tensor_data_dict["zero_point"]
-        return cls(int_data, scale, zero_point)
+        layout_type, = tensor_attributes
+        return cls(int_data, scale, zero_point, layout_type)

    def to(self, *args, **kwargs):
        kwargs = self._get_to_kwargs(*args, **kwargs)
        return self.__class__(
            self.int_data.to(kwargs["device"]),
            self.scale.to(kwargs["device"]),
            self.zero_point.to(kwargs["device"]),
+            self.layout_type,
        )

    def _apply_fn_to_data(self, fn):
        return self.__class__(
            fn(self.int_data),
            fn(self.scale),
            fn(self.zero_point),
+            self.layout_type,
        )

    @classmethod
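`__tensor_flatten__` / `__tensor_unflatten__` are the hooks torch.compile and related machinery use to decompose a tensor subclass into its inner tensors plus non-tensor metadata and rebuild it later; the new `layout_type` has to ride along in the metadata list or it would be lost in that round trip. Below is a toy sketch of the contract, with a plain object standing in for the real subclass so it stays runnable without constructing a wrapper tensor.

import torch
from dataclasses import dataclass

@dataclass(frozen=True)
class PlainLayoutType:
    pass

class ToyPlainLayout:
    """Mimics only PlainAQTLayout's flatten/unflatten contract."""
    def __init__(self, int_data, scale, zero_point, layout_type):
        self.int_data, self.scale, self.zero_point = int_data, scale, zero_point
        self.layout_type = layout_type

    def __tensor_flatten__(self):
        # inner tensors by attribute name, plus non-tensor metadata
        return ["int_data", "scale", "zero_point"], [self.layout_type]

    @classmethod
    def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride):
        layout_type, = tensor_attributes
        return cls(tensor_data_dict["int_data"], tensor_data_dict["scale"],
                   tensor_data_dict["zero_point"], layout_type)

t = ToyPlainLayout(torch.zeros(4, dtype=torch.int8), torch.ones(1), torch.zeros(1), PlainLayoutType())
names, meta = t.__tensor_flatten__()
rebuilt = ToyPlainLayout.__tensor_unflatten__(
    {name: getattr(t, name) for name in names}, meta, None, None
)
assert isinstance(rebuilt.layout_type, PlainLayoutType)  # metadata survived the round trip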
@@ -407,10 +419,12 @@ def from_plain(
        int_data: torch.Tensor,
        scale: torch.Tensor,
        zero_point: torch.Tensor,
+        layout_type: LayoutType,
    ):
-        return cls(int_data, scale, zero_point)
+        assert isinstance(layout_type, PlainLayoutType)
+        return cls(int_data, scale, zero_point, layout_type)

-@register_layout_cls("tensor_core_tiled")
+@register_layout_cls(TensorCoreTiledLayoutType)
class TensorCoreTiledAQTLayout(AQTLayout):
    """
    Layout storage class for tensor_core_tiled layout for affine quantized tensor, this is for int4 only,
@@ -427,6 +441,7 @@ def __new__(
        packed_weight: torch.Tensor,
        scale_and_zero: torch.Tensor,
        transposed: bool,
+        layout_type: LayoutType,
    ):
        kwargs = {}
        kwargs["device"] = packed_weight.device
@@ -443,29 +458,38 @@ def __init__(
        packed_weight: torch.Tensor,
        scale_and_zero: torch.Tensor,
        transposed: bool,
+        layout_type: LayoutType,
    ):
        self.packed_weight = packed_weight
        self.scale_and_zero = scale_and_zero
        self.transposed = False
+        self.layout_type = layout_type

    def __tensor_flatten__(self):
-        return ["packed_weight", "scale_and_zero"], [self.transposed]
+        return ["packed_weight", "scale_and_zero"], [self.transposed, self.layout_type]

    @classmethod
    def __tensor_unflatten__(
        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
    ):
        packed_weight, scale_and_zero = tensor_data_dict["packed_weight"], tensor_data_dict["scale_and_zero"]
-        transposed, = tensor_attributes
-        return cls(packed_weight, scale_and_zero, transposed)
+        transposed, layout_type, = tensor_attributes
+        return cls(packed_weight, scale_and_zero, transposed, layout_type)

    @classmethod
-    def from_plain(cls, int_data, scale, zero_point, inner_k_tiles=8):
-        packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data.to(torch.int32), inner_k_tiles)
+    def from_plain(
+        cls,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: torch.Tensor,
+        layout_type: LayoutType
+    ):
+        assert isinstance(layout_type, TensorCoreTiledLayoutType)
+        packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data.to(torch.int32), layout_type.inner_k_tiles)
        scale = scale.reshape(int_data.shape[0], -1)
        zero_point = zero_point.reshape(int_data.shape[0], -1)
        scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point)
-        return cls(packed_weight, scale_and_zero, False)
+        return cls(packed_weight, scale_and_zero, False, layout_type)

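The packing parameter that used to arrive as a loose `inner_k_tiles=8` kwarg is now read off the layout object, and the isinstance assert documents which layout each `from_plain` accepts. A small runnable sketch of just that part; the actual packing via `torch.ops.aten._convert_weight_to_int4pack` is not reproduced here because it needs a CUDA build and specific weight shapes, and `pack_for_tinygemm` is a hypothetical helper, not torchao API.

from dataclasses import dataclass

@dataclass(frozen=True)
class LayoutType:
    pass

@dataclass(frozen=True)
class TensorCoreTiledLayoutType(LayoutType):
    inner_k_tiles: int = 8

def pack_for_tinygemm(layout_type: LayoutType) -> int:
    # guard mirrors the assert in from_plain above, then reads the knob off the layout
    assert isinstance(layout_type, TensorCoreTiledLayoutType)
    return layout_type.inner_k_tiles

print(pack_for_tinygemm(TensorCoreTiledLayoutType()))                 # 8 (default)
print(pack_for_tinygemm(TensorCoreTiledLayoutType(inner_k_tiles=4)))  # 4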
    def to(self, *args, **kwargs):
        kwargs = self._get_to_kwargs(*args, **kwargs)
@@ -475,7 +499,8 @@ def to(self, *args, **kwargs):
        return self.__class__(
            self.packed_weight.to(device),
            self.scale_and_zero.to(device),
-            self.transposed
+            self.transposed,
+            self.layout_type,
        )

    def _apply_fn_to_data(self, fn):
@@ -485,7 +510,7 @@ def _apply_fn_to_data(self, fn):

    def __repr__(self):
        int_data, scale, zero_point = self.get_plain()
-        return f"TensorCoreTiledAQTLayout(int_data={int_data}, scale={scale}, zero_point={zero_point})"
+        return f"TensorCoreTiledAQTLayout(int_data={int_data}, scale={scale}, zero_point={zero_point}, {self.layout_type})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
@@ -563,8 +588,8 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
            is_cuda and
            input_is_int8 and
            input_tensor.dtype == weight_qtensor.dtype and
-            input_tensor.extended_layout == "plain" and
-            weight_qtensor.extended_layout == "plain"
+            isinstance(input_tensor.layout_type, PlainLayoutType) and
+            isinstance(weight_qtensor.layout_type, PlainLayoutType)
        ):
            #
            # 1. do the matrix form of dot(X_i, W_j)
@@ -606,7 +631,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
            weight_qtensor.dtype == torch.bfloat16 and
            len(weight_qtensor.shape) == 2 and
            weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT and
-            weight_qtensor.extended_layout == "tensor_core_tiled"
+            isinstance(weight_qtensor.layout_type, TensorCoreTiledLayoutType)
        ):
            assert weight_qtensor.block_size[0] == 1, f"Requires groupwise quantization, got block_size: {block_size}"
            assert input_tensor.shape[-1] == weight_qtensor.shape[1], (
@@ -649,7 +674,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
            weight_qtensor.block_size[0] == 1 and
            weight_qtensor.block_size[1] == weight_qtensor.shape[1] and
            weight_qtensor.zero_point_domain == ZeroPointDomain.INT and
-            weight_qtensor.extended_layout == "plain"
+            isinstance(weight_qtensor.layout_type, PlainLayoutType)
        ):
            # TODO: enable cpu and mps efficient path
            # per channel int8 weight only quantizated mm