def _aqt_is_int8(aqt):
    """Check if an AffineQuantizedTensor is int8 quantized Tensor"""
    return (
-        aqt.int_data.dtype == torch.int8 and
+        aqt.storage_tensor.dtype == torch.int8 and
        aqt.quant_min is None or aqt.quant_min == -128 and
        aqt.quant_max is None or aqt.quant_max == 127
    )

def _aqt_is_int8_reduced_range(aqt):
    return (
-        aqt.int_data.dtype == torch.int8 and
+        aqt.storage_tensor.dtype == torch.int8 and
        aqt.quant_min == -127 and
        aqt.quant_max is None or aqt.quant_max == 127
    )
@@ -34,7 +34,7 @@ def _aqt_is_uint4(aqt):
    """Check if an AffineQuantizedTensor is uint4 quantized Tensor"""
    # TODO: use torch.uint4
    return (
-        aqt.int_data.dtype == torch.int32 and
+        aqt.storage_tensor.dtype == torch.int32 and
        aqt.quant_min is None or aqt.quant_min == 0 and
        aqt.quant_max is None or aqt.quant_max == 15
    )
@@ -69,6 +69,121 @@ def implements_aqt_aten_ops(aten_ops):
def implements_aqt_torch_function(torch_function):
    return implements_torch_function(AffineQuantizedTensor, torch_function)

+_STORAGE_LAYOUT_TO_AQT_STORAGE_CLS: Dict[str, Callable] = {}
+
+def register_aqt_storage_cls(storage_layout: str):
+    def decorator(storage_cls):
+        storage_cls.storage_layout = storage_layout
+        _STORAGE_LAYOUT_TO_AQT_STORAGE_CLS[storage_layout] = storage_cls
+        return storage_cls
+    return decorator
+
+def get_aqt_storage_cls(storage_layout: str) -> Callable:
+    if storage_layout not in _STORAGE_LAYOUT_TO_AQT_STORAGE_CLS:
+        raise ValueError(f"storage layout: {storage_layout} is not supported yet")
+    return _STORAGE_LAYOUT_TO_AQT_STORAGE_CLS.get(storage_layout)
+
+class AQTStorage(torch.Tensor):
+    # this should be set for each storage class during registration
+    storage_layout: Optional[str] = None
+
+    def __init__(
+        self,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: torch.Tensor,
+    ):
+        pass
+
+    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        pass
+
+@register_aqt_storage_cls("plain")
+class PlainAQTStorage(AQTStorage):
+    def __new__(
+        cls,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: torch.Tensor,
+    ):
+        kwargs = {}
+        kwargs["device"] = int_data.device
+        kwargs["layout"] = (
+            kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout
+        )
+        kwargs["dtype"] = int_data.dtype
+        kwargs["requires_grad"] = False
+        shape = int_data.shape
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+    def __init__(
+        self,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: torch.Tensor,
+    ):
+        self.int_data = int_data
+        self.scale = scale
+        self.zero_point = zero_point
+
+    def __tensor_flatten__(self):
+        return ["int_data", "scale", "zero_point"], []
+
+    @classmethod
+    def __tensor_unflatten__(
+        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
+    ):
+        int_data, scale, zero_point = tensor_data_dict["int_data"], tensor_data_dict["scale"], tensor_data_dict["zero_point"]
+        return cls(int_data, scale, zero_point)
+
+    # TODO: dedup
+    def _get_to_kwargs(self, *args, **kwargs):
+        device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
+        device = self.device if device is None else device
+        dtype = self.dtype if dtype is None else dtype
+        memory_format = (
+            memory_format if memory_format is not None else torch.preserve_format
+        )
+        kwargs = {
+            "device": device,
+            "dtype": dtype,
+            "memory_format": memory_format,
+        }
+        return kwargs
+
+    def to(self, *args, **kwargs):
+        kwargs = self._get_to_kwargs(*args, **kwargs)
+        return self.__class__(
+            self.int_data.to(kwargs["device"]),
+            self.scale.to(kwargs["device"]),
+            self.zero_point.to(kwargs["device"]),
+        )
+
+    def _apply_fn_to_data(self, fn):
+        return self.__class__(
+            fn(self.int_data),
+            fn(self.scale),
+            fn(self.zero_point),
+        )
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        kwargs = {} if kwargs is None else kwargs
+
+        if func is aten.detach.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
+            )
+
+        raise NotImplementedError(
+            f"PlainAQTStorage dispatch: attempting to run {func}, this is not supported"
+        )
+
+    __torch_function__ = torch._C._disabled_torch_function_impl
+
+    def get_plain(self):
+        return self.int_data, self.scale, self.zero_point
+

class AffineQuantizedTensor(torch.Tensor):
    """
@@ -103,9 +218,7 @@ class AffineQuantizedTensor(torch.Tensor):
    @staticmethod
    def __new__(
        cls,
-        int_data: torch.Tensor,
-        scale: torch.Tensor,
-        zero_point: torch.Tensor,
+        storage_tensor: AQTStorage,
        block_size: Tuple[int, ...],
        shape: torch.Size,
        quant_min: Optional[int] = None,
@@ -115,9 +228,9 @@ def __new__(
        strides=None,
    ):
        kwargs = {}
-        kwargs["device"] = int_data.device
+        kwargs["device"] = storage_tensor.device
        kwargs["layout"] = (
-            kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout
+            kwargs.get("layout") if kwargs.get("layout", False) else storage_tensor.layout
        )
        if dtype is None:
-            dtype = scale.dtype
+            dtype = storage_tensor.get_plain()[1].dtype  # scale is no longer a parameter here
@@ -129,9 +242,7 @@ def __new__(

    def __init__(
        self,
-        int_data: torch.Tensor,
-        scale: torch.Tensor,
-        zero_point: torch.Tensor,
+        storage_tensor: AQTStorage,
        block_size: Tuple[int, ...],
        shape: torch.Size,
        quant_min: Optional[int] = None,
@@ -140,9 +251,7 @@ def __init__(
        dtype=None,
        strides=None,
    ):
-        self.int_data = int_data
-        self.scale = scale
-        self.zero_point = zero_point
+        self.storage_tensor = storage_tensor
        self.block_size = block_size
        self.quant_min = quant_min
        self.quant_max = quant_max
@@ -157,21 +266,20 @@ def __repr__(self):
    def dequantize(self, output_dtype=None):
        if output_dtype is None:
            output_dtype = self.dtype
-        return dequantize_affine(self.int_data, self.block_size, self.scale, self.zero_point, self.int_data.dtype, self.quant_min, self.quant_max, self.zero_point_domain, output_dtype=output_dtype)
+        int_data, scale, zero_point = self.storage_tensor.get_plain()
+        return dequantize_affine(int_data, self.block_size, scale, zero_point, int_data.dtype, self.quant_min, self.quant_max, self.zero_point_domain, output_dtype=output_dtype)

    def __tensor_flatten__(self):
-        return ["int_data", "scale", "zero_point"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype]
+        return ["storage_tensor"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype]

    @classmethod
    def __tensor_unflatten__(
        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
    ):
-        int_data, scale, zero_point = tensor_data_dict["int_data"], tensor_data_dict["scale"], tensor_data_dict["zero_point"]
+        storage_tensor = tensor_data_dict["storage_tensor"]
        block_size, shape, quant_min, quant_max, zero_point_domain, dtype = tensor_attributes
        return cls(
-            int_data,
-            scale,
-            zero_point,
+            storage_tensor,
            block_size,
            shape if outer_size is None else outer_size,
            quant_min,
@@ -195,13 +303,15 @@ def from_float(
        zero_point_dtype: Optional[torch.dtype] = None,
        preserve_zero: bool = True,
        zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
+        storage_layout: str = "plain",
    ):
        scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
        int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)
+
+        storage_cls = get_aqt_storage_cls(storage_layout)
+        storage_tensor = storage_cls(int_data, scale, zero_point)
        return cls(
-            int_data,
-            scale,
-            zero_point,
+            storage_tensor,
            block_size,
            input_float.shape,
            quant_min,
@@ -210,6 +320,10 @@ def from_float(
            dtype=input_float.dtype
        )

+    @property
+    def storage_layout(self) -> str:
+        return self.storage_tensor.storage_layout
+
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs
@@ -238,9 +352,7 @@ def _get_to_kwargs(self, *args, **kwargs):
    def to(self, *args, **kwargs):
        kwargs = self._get_to_kwargs(*args, **kwargs)
        return self.__class__(
-            self.int_data.to(kwargs["device"]),
-            self.scale.to(kwargs["device"]),
-            self.zero_point.to(kwargs["device"]),
+            self.storage_tensor.to(kwargs["device"]),
            self.block_size,
            self.shape,
            self.quant_min,
@@ -251,9 +363,7 @@ def to(self, *args, **kwargs):

    def _apply_fn_to_data(self, fn):
        return self.__class__(
-            fn(self.int_data),
-            fn(self.scale),
-            fn(self.zero_point),
+            fn(self.storage_tensor),
            self.block_size,
            self.shape,
            self.quant_min,
@@ -308,7 +418,9 @@ def functional_linear(*args, **kwargs):
    if (
        is_cuda and
        input_is_int8 and
-        input_tensor_dtype_is_expected
+        input_tensor_dtype_is_expected and
+        input_tensor.storage_layout == "plain" and
+        weight_qtensor.storage_layout == "plain"
    ):
        #
        # 1. do the matrix form of dot(X_i, W_j)
@@ -321,10 +433,10 @@ def functional_linear(*args, **kwargs):
        # value of a float 16, (which results in a value of inf even if multiplying
        # by the other scale would bring it within the expected range)

-        x_vals_int8 = input_tensor.int_data
-        x_scales = input_tensor.scale
-        w_vals_int8_t = weight_qtensor.int_data.contiguous().t()
-        w_scales = weight_qtensor.scale
+        x_vals_int8 = input_tensor.storage_tensor.int_data
+        x_scales = input_tensor.storage_tensor.scale
+        w_vals_int8_t = weight_qtensor.storage_tensor.int_data.contiguous().t()
+        w_scales = weight_qtensor.storage_tensor.scale
        tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
        y_dot_scaled = int_scaled_matmul(tmp, w_vals_int8_t, x_scales.reshape(-1, 1))
@@ -344,22 +456,22 @@ def functional_linear(*args, **kwargs):
    # weight only quantization
    # TODO: enable cpu and mps path as well
    # TODO: make sure weight dimension matches the expectation of the int4mm kernel
-    # TODO: move this to TinygemmAffineQuantizedTensor
    if (
        is_cuda and
        weight_is_uint4 and
        weight_qtensor.dtype == torch.bfloat16 and
        len(weight_qtensor.shape) == 2 and
        weight_qtensor.block_size[0] == 1 and
-        weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT
+        weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT and
+        weight_qtensor.storage_layout == "plain"
    ):
        # groupwise int4 quantization
        # TODO: currently doing packing on the fly, we'll need to figure out
        # the API to do packing before hand
        # TODO: expose the arg
        innerKTiles = 8
-        packed_weight = torch.ops.aten._convert_weight_to_int4pack(weight_qtensor.int_data.to(torch.int32), innerKTiles)
-        scales_and_zeros = pack_tinygemm_scales_and_zeros(weight_qtensor.scale, weight_qtensor.zero_point)
+        packed_weight = torch.ops.aten._convert_weight_to_int4pack(weight_qtensor.storage_tensor.int_data.to(torch.int32), innerKTiles)
+        scales_and_zeros = pack_tinygemm_scales_and_zeros(weight_qtensor.storage_tensor.scale, weight_qtensor.storage_tensor.zero_point)
        groupsize = weight_qtensor.block_size[-1]
        return torch.ops.aten._weight_int4pack_mm(input_tensor.contiguous(), packed_weight, groupsize, scales_and_zeros)
    elif (
elif (
@@ -368,11 +480,12 @@ def functional_linear(*args, **kwargs):
368
480
len (weight_qtensor .shape ) == 2 and
369
481
len (weight_qtensor .block_size ) == 2 and
370
482
weight_qtensor .block_size [0 ] == 1 and
371
- weight_qtensor .block_size [1 ] == weight_qtensor .shape [1 ]
483
+ weight_qtensor .block_size [1 ] == weight_qtensor .shape [1 ] and
484
+ weight_qtensor .storage_layout == "plain"
372
485
):
373
486
# TODO: enable mps path as well
374
487
# per channel int8 weight only quantizated mm
375
- return torch .ops .aten ._weight_int8pack_mm (input_tensor .contiguous (), weight_qtensor .int_data , weight_qtensor .scale )
488
+ return torch .ops .aten ._weight_int8pack_mm (input_tensor .contiguous (), weight_qtensor .storage_tensor . int_data , weight_qtensor . storage_tensor .scale )
376
489
else :
377
490
weight_tensor = weight_qtensor .dequantize ()
378
491
return torch .nn .functional .linear (input_tensor , weight_tensor , bias )
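
For reference, a minimal usage sketch of the storage-layout API this diff introduces. It is an illustration, not part of the PR: it assumes the module's existing imports are in scope (torch, MappingType from the quant primitives, and the names defined above), that from_float's optional quantization arguments keep their defaults, and that MyPackedAQTStorage is a hypothetical layout registered only to demonstrate the decorator.

import torch

# Hypothetical custom layout: registered under a new name but reusing the plain
# storage behaviour; a real implementation would pack int_data in __init__ and
# unpack it in get_plain().
@register_aqt_storage_cls("my_packed")
class MyPackedAQTStorage(PlainAQTStorage):
    pass

# Quantize a float weight through the default "plain" layout; from_float routes
# the storage_layout string through get_aqt_storage_cls to pick the storage class.
w = torch.randn(64, 64)
wq = AffineQuantizedTensor.from_float(
    w,
    MappingType.SYMMETRIC,
    block_size=(1, 64),       # one scale/zero_point per output channel
    target_dtype=torch.int8,
    storage_layout="plain",
)
assert wq.storage_layout == "plain"
w_approx = wq.dequantize()    # recovers a float tensor via storage_tensor.get_plain()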