@@ -124,6 +124,10 @@ def dynamically_quantize_per_channel(
     return quant, scales, zero_points


+#########################################################################
+### QuantHandler API definition ###
+
+
 class QuantHandler:
     def __init__(self, mod):
         self.mod = mod
@@ -134,8 +138,15 @@ def create_quantized_state_dict(self) -> Dict:  # "StateDict"
     def convert_for_runtime(self) -> nn.Module:
         pass

+    def quantized_model(self) -> nn.Module:
+        model_updated_state_dict = self.create_quantized_state_dict()
+        self.convert_for_runtime()
+        self.mod.load_state_dict(model_updated_state_dict)
+        return self.mod

-##### Weight-only int8 per-channel quantized code ######
+
+#########################################################################
+### Weight-only int8 per-channel quantized code ###


 def replace_linear_weight_only_int8_per_channel(module, node_type):
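Note: the new quantized_model() helper becomes the single entry point callers need; it simply runs the two abstract steps and reloads the weights. A minimal sketch of the equivalent manual flow, assuming handler is any QuantHandler subclass instance wrapping an nn.Module:

    state_dict = handler.create_quantized_state_dict()   # quantize weights into a new state dict
    handler.convert_for_runtime()                         # swap eligible modules for quantized ones
    handler.mod.load_state_dict(state_dict)               # install the quantized tensors
    model = handler.mod                                    # same object quantized_model() returns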
@@ -153,16 +164,17 @@ def replace_linear_weight_only_int8_per_channel(module, node_type):
                 setattr(
                     module,
                     name,
-                    WeightOnlyInt8Linear(child.in_features, child.out_features),
+                    WeightOnlyInt8Linear("cpu", child.in_features, child.out_features),
                 )
         else:
             replace_linear_weight_only_int8_per_channel(child, node_type)


-class WeightOnlyInt8QuantHandler:
+class WeightOnlyInt8QuantHandler(QuantHandler):
     def __init__(
         self,
         mod,
+        device="cpu",
         *,
         node_type: str = "*",
         bitwidth: Optional[int] = None,
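For reference, a hypothetical call site under the new signature (model stands in for the caller's nn.Module and is not part of this diff); the handler now inherits quantized_model() from QuantHandler:

    model = WeightOnlyInt8QuantHandler(model, device="cpu", node_type="*", bitwidth=8).quantized_model()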
@@ -202,7 +214,7 @@ def create_quantized_state_dict(self) -> Dict:
                     )
                 ):
                     print(
-                        f"quantize {self.node_type} {fqn, mod} with groupsize {self.group_size}, bitwidth {self.bitwidth}"
+                        f"quantize {self.node_type} {fqn, mod} with group_size {self.group_size}, bitwidth {self.bitwidth}"
                     )

                     # print(f"initial weight shape {mod.weight.shape}")
@@ -219,7 +231,7 @@ def create_quantized_state_dict(self) -> Dict:
                     )

                     cur_state_dict[f"{fqn}.weight"] = weight
-                    # squeeze makes groupsize=rowsize unidimensional
+                    # squeeze makes group_size=rowsize unidimensional
                     cur_state_dict[f"{fqn}.scales"] = scales.squeeze(dim=-1)

         return cur_state_dict
@@ -243,10 +255,10 @@ class WeightOnlyInt8Linear(torch.nn.Module):

     def __init__(
         self,
+        device,
         in_features: int,
         out_features: int,
         bias: bool = True,
-        device=None,
         dtype=None,
     ) -> None:
         super().__init__()
@@ -262,11 +274,12 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         # return F.linear(input, self.weight.to(dtype=input.dtype)) * se...


-##### embedding table quantization ######
+#########################################################################
+##### embedding table quantization ######


 def replace_embedding_weight_only_grouped_int8_per_channel(
-    module, bitwidth: int = 8, group_size: Optional[int] = None
+    module, device, bitwidth: int = 8, group_size: Optional[int] = None, packed=False
 ):
     for name, child in module.named_children():
         # print(f"name: {name}")
@@ -277,25 +290,41 @@ def replace_embedding_weight_only_grouped_int8_per_channel(
                 module,
                 name,
                 QuantizedGroupEmbedding(
+                    device=device,
                     vocab_size=child.weight.shape[0],
                     embedding_dim=child.weight.shape[1],
                     group_size=group_size,
+                    packed=packed,
                 ),
             )
         else:
             replace_embedding_weight_only_grouped_int8_per_channel(
-                child, bitwidth, group_size
+                child, device, bitwidth, group_size, packed
             )


-class EmbeddingOnlyInt8QuantHandler:
-    def __init__(self, mod, *, bitwidth: int = 8, group_size: Optional[int] = None):
+class EmbeddingQuantHandler(QuantHandler):
+    def __init__(
+        self,
+        mod,
+        device="cpu",
+        *,
+        bitwidth: int = 8,
+        group_size: Optional[int] = None,
+        packed=False,
+    ):
+        if isinstance(packed, str):
+            packed = packed == "True"
         self.mod = mod
+        self.device = device
         self.group_size = group_size
         self.bitwidth = bitwidth
+        self.packed = packed
+        if (bitwidth != 4) and packed:
+            raise RuntimeError("pack only works with bitsize 4")

     @torch.no_grad()
-    def create_quantized_state_dict(self) -> Dict:
+    def create_quantized_state_dict(self, packed=False) -> Dict:
         cur_state_dict = self.mod.state_dict()

         if self.bitwidth == 4:
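Note: the isinstance(packed, str) branch presumably guards against the flag arriving as a command-line string; only the literal "True" enables packing, anything else falls through to False, and packing is rejected unless bitwidth is 4. Illustrative calls (model is a placeholder for the caller's nn.Module, not part of this diff):

    EmbeddingQuantHandler(model, bitwidth=4, group_size=32, packed="True")   # packed 4-bit path
    EmbeddingQuantHandler(model, bitwidth=8, group_size=32, packed="False")  # unpacked int8 path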
@@ -308,18 +337,14 @@ def create_quantized_state_dict(self) -> Dict:
             raise ValueError(f"Unsupported bitwidth {self.bitwidth}")

         for fqn, mod in self.mod.named_modules():
-            if (
-                isinstance(mod, nn.Embedding)
-                or isinstance(mod, fsEmbedding)
-                or isinstance(mod, fsStandardEmbedding)
-            ):
+            if isinstance(mod, nn.Embedding):
                 # print("****")
                 # print(f"Embedding identified: {fqn, mod}")
                 # print(f"weights size: {mod.weight.size()}")
                 # print(f"quantize {fqn}...")

                 print(
-                    f"quantize {fqn, mod} with groupsize {self.group_size}, bitwidth {self.bitwidth}"
+                    f"quantize {fqn, mod} with group_size {self.group_size}, bitwidth {self.bitwidth}"
                 )
                 weight, scales, _ = dynamically_quantize_per_channel(
                     mod.weight.float(),
@@ -330,21 +355,36 @@ def create_quantized_state_dict(self) -> Dict:
                     scales_dtype=mod.weight.dtype,
                 )

+                if packed:
+                    if weight.shape[-1] % 2 != 0:
+                        raise RuntimeError("automatic padding not implemented yet")
+
+                    weight_range_shifted = weight.add(8).view(torch.uint8)
+                    weight_view = weight_range_shifted.view(
+                        weight.shape[0], weight.shape[1] // 2, 2
+                    )
+                    weight_even = weight_view[:, :, 0] * 16  # left shift 4
+                    weight_odd = weight_view[:, :, 1]
+                    weight_packed = weight_even + weight_odd
+                    weight = weight_packed
+
+                weight = weight.to(device=self.device)
+                scales = scales.to(device=self.device)
                 # Update state dict
                 cur_state_dict[f"{fqn}.weight"] = weight
-                # squeeze makes groupsize=rowsize unidimensional
+                # squeeze makes group_size=rowsize unidimensional
                 cur_state_dict[f"{fqn}.scales"] = scales.squeeze(dim=-1)

         return cur_state_dict

     def convert_for_runtime(self) -> nn.Module:
         replace_embedding_weight_only_grouped_int8_per_channel(
-            self.mod, self.bitwidth, self.group_size
+            self.mod, self.device, self.bitwidth, self.group_size, self.packed
         )
         return self.mod

     def quantized_model(self) -> nn.Module:
-        model_updated_state_dict = self.create_quantized_state_dict()
+        model_updated_state_dict = self.create_quantized_state_dict(self.packed)
         self.convert_for_runtime()
         self.mod.load_state_dict(model_updated_state_dict)
         return self.mod
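The packing branch above stores two int4 values per byte: quantized values in [-8, 7] are shifted to [0, 15] and the even-indexed column goes into the high nibble. A self-contained round-trip sketch of that layout (sample values are made up; at runtime the embedding_4bit op consumes the packed buffer directly rather than unpacking in Python):

    import torch

    # assumption: 4-bit quantized values lie in [-8, 7], as produced by the 4-bit branch
    w = torch.tensor([[-8, 7, 0, 3]], dtype=torch.int8)

    # pack: shift to [0, 15], even columns in the high nibble, odd columns in the low nibble
    shifted = w.add(8).view(torch.uint8)
    pairs = shifted.view(w.shape[0], w.shape[1] // 2, 2)
    packed = pairs[:, :, 0] * 16 + pairs[:, :, 1]   # shape [rows, cols // 2], uint8

    # unpack (reference only)
    high = (packed >> 4).to(torch.int8) - 8
    low = (packed & 0x0F).to(torch.int8) - 8
    restored = torch.stack([high, low], dim=-1).view(w.shape)
    assert torch.equal(restored, w)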
@@ -353,39 +393,53 @@ def quantized_model(self) -> nn.Module:
 class QuantizedGroupEmbedding(torch.nn.Module):
     def __init__(
         self,
+        device,
         vocab_size: int,
         embedding_dim: int,
         group_size: Optional[int] = None,
-        device=None,
         dtype=torch.half,
+        packed=False,
     ) -> None:
         super().__init__()
-        if group_size is None:
+        if group_size is None or group_size == 0:
             group_size = embedding_dim
         self.group_size = group_size
         self.dtype = dtype
-        self.register_buffer(
-            "weight", torch.empty((vocab_size, embedding_dim), dtype=torch.int8)
-        )
+        self.packed = packed
+        if not packed:
+            self.register_buffer(
+                "weight",
+                torch.empty(
+                    (vocab_size, embedding_dim), dtype=torch.int8, device=device
+                ),
+            )
+        else:  # packed
+            self.register_buffer(
+                "weight",
+                torch.empty(
+                    (vocab_size, embedding_dim // 2), dtype=torch.uint8, device=device
+                ),
+            )
         groups_per_row = (embedding_dim + group_size - 1) // group_size
         if groups_per_row > 1:
             self.register_buffer(
-                "scales", torch.ones((vocab_size, groups_per_row), dtype=torch.float16)
+                "scales",
+                torch.ones(
+                    (vocab_size, groups_per_row), dtype=torch.float16, device=device
+                ),
             )
         else:
             self.register_buffer(
-                "scales", torch.ones((vocab_size,), dtype=torch.float16)
+                "scales", torch.ones((vocab_size,), dtype=torch.float16, device=device)
             )

     @torch.no_grad()
     def forward(self, indices: torch.Tensor) -> torch.Tensor:
-        return torch.ops.quantized_decomposed.embedding_byte.dtype(
-            self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
-        )
-
-
-        # result_weights = self.weight.index_select(0, indices.view(-1))
-        # result_scales = self.scales.index_select(0, indices.view(-1))
-        #
-        # r = result_weights.to(dtype=result_scales.dtype) * result_scales
-        # return r.view(indices.size() + (-1,))
+        if not self.packed:  # 8bit
+            return torch.ops.quantized_decomposed.embedding_byte.dtype(
+                self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
+            )
+        else:  # 4bit packed
+            return torch.ops.quantized_decomposed.embedding_4bit.dtype(
+                self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
+            )
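For readers without the quantized_decomposed ops handy, here is a rough eager-mode reference of what the unpacked 8-bit lookup computes when there is a single scale group per row, adapted from the commented-out fallback this diff deletes. The function name and toy values are illustrative only; the grouped and packed 4-bit variants are left to the custom ops:

    import torch

    def embedding_byte_reference(weight, scales, indices, dtype=torch.half):
        # gather the int8 rows and their per-row scales, then dequantize
        result_weights = weight.index_select(0, indices.view(-1))
        result_scales = scales.index_select(0, indices.view(-1))
        r = result_weights.to(dtype) * result_scales.unsqueeze(-1).to(dtype)
        return r.view(indices.size() + (-1,))

    # toy check: 4-token vocab, 2-dim embeddings, scale 0.5 per row
    w = torch.tensor([[2, -4], [6, 8], [1, 1], [0, -2]], dtype=torch.int8)
    s = torch.full((4,), 0.5, dtype=torch.float16)
    out = embedding_byte_reference(w, s, torch.tensor([1, 3]))
    # out == [[3.0, 4.0], [0.0, -1.0]]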