from torchao.quantization.quant_api import (
    _replace_with_custom_fn_if_matches_filter,
)
+ from torchao.quantization.quant_primitives import TorchAODType
+ from .api import FakeQuantizeConfig
+ from .fake_quantizer import FakeQuantizer
from .utils import (
    _fake_quantize_per_channel_group,
    _get_qmin_qmax,
)

+ class FakeQuantizedEmbedding(torch.nn.Embedding):
+     """
+     General embedding layer with fake quantized weights.
+
+     Specific target dtypes, granularity, schemes etc. are specified
+     through a separate config for the weights.
+
+     Example usage::
+
+         weight_config = FakeQuantizeConfig(
+             dtype=torch.int4,
+             group_size=8,
+             is_symmetric=True,
+         )
+         fq_embedding = FakeQuantizedEmbedding(5, 10, weight_config=weight_config)
+         fq_embedding(torch.LongTensor([3]))
+     """
+
+     def __init__(
+         self,
+         num_embeddings: int,
+         embedding_dim: int,
+         padding_idx: Optional[int] = None,
+         max_norm: Optional[float] = None,
+         norm_type: float = 2.0,
+         scale_grad_by_freq: bool = False,
+         sparse: bool = False,
+         weight_config: Optional[FakeQuantizeConfig] = None,
+         *args,
+         **kwargs,
+     ) -> None:
+         super().__init__(
+             num_embeddings,
+             embedding_dim,
+             padding_idx,
+             max_norm,
+             norm_type,
+             scale_grad_by_freq,
+             sparse,
+             *args,
+             **kwargs,
+         )
+         if weight_config is not None:
+             self.weight_fake_quantizer = FakeQuantizer(weight_config)
+         else:
+             self.weight_fake_quantizer = None
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         if self.weight_fake_quantizer is not None:
+             w = self.weight_fake_quantizer(self.weight)
+         else:
+             w = self.weight
+         return F.embedding(
+             x, w, self.padding_idx, self.max_norm,
+             self.norm_type, self.scale_grad_by_freq, self.sparse,
+         )
+
+
# ======================================
# | Embedding int4 weight-only QAT |
# ======================================
@@ -37,10 +98,9 @@ def __init__(
        zero_point_precision: torch.dtype = torch.int32,
    ) -> None:
        super().__init__()
-         self.bit_width = 4
        self.group_size: int = group_size
        self.scale_precision: torch.dtype = scale_precision
-         self.zero_point_precision: torch.dtype = zero_point_precision,
+         self.zero_point_precision: torch.dtype = zero_point_precision

    def prepare(
        self,
@@ -56,16 +116,18 @@ def filter_fn(child: torch.nn.Module, cur_fqn:str) -> bool:

        def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
            new_embedding = Int4WeightOnlyQATEmbedding(
-                 group_size=self.group_size,
-
-                 # other nn.Embedding args
+                 # nn.Embedding args
                num_embeddings=child.num_embeddings,
                embedding_dim=child.embedding_dim,
                padding_idx=child.padding_idx,
                max_norm=child.max_norm,
                norm_type=child.norm_type,
                scale_grad_by_freq=child.scale_grad_by_freq,
                sparse=child.sparse,
+                 # quantization args
+                 group_size=self.group_size,
+                 scale_precision=self.scale_precision,
+                 zero_point_precision=self.zero_point_precision,
                device=child.weight.device,
            )
            # In distributed training, the model may be instantiated
@@ -98,28 +160,31 @@ def _convert_helper(self, module: torch.nn.Module):
        from torchao._executorch_ops import _quantized_decomposed_quantize_per_channel_group_wrapper
        for name, child in module.named_children():
            if isinstance(child, Int4WeightOnlyQATEmbedding):
+                 group_size = child.weight_fake_quantizer.config.group_size
+                 scale_precision = child.weight_fake_quantizer.config.scale_precision
+                 zero_point_precision = child.weight_fake_quantizer.config.zero_point_precision
                quantized_embedding = Int4WeightOnlyEmbedding(
-                     group_size=child.group_size,
-                     scale_precision=child.scale_precision,
-                     zero_point_precision=child.zero_point_precision,
-
-                     # other nn.Embedding args
+                     # nn.Embedding args
                    num_embeddings=child.num_embeddings,
                    embedding_dim=child.embedding_dim,
                    padding_idx=child.padding_idx,
                    max_norm=child.max_norm,
                    norm_type=child.norm_type,
                    scale_grad_by_freq=child.scale_grad_by_freq,
                    sparse=child.sparse,
+                     # quantization args
+                     group_size=group_size,
+                     scale_precision=scale_precision,
+                     zero_point_precision=zero_point_precision,
                    device=child.weight.device,
                )
                setattr(module, name, quantized_embedding)

                # Load weights and qparams into quantized embedding
-                 (qmin, qmax) = _get_qmin_qmax(self.bit_width)
-                 (s, zp) = get_group_qparams_symmetric(child.weight, self.bit_width, child.group_size)
+                 (qmin, qmax) = _get_qmin_qmax(4)
+                 (s, zp) = get_group_qparams_symmetric(child.weight, 4, group_size)
                q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
-                     child.weight, s, zp, qmin, qmax, torch.int8, child.group_size,
+                     child.weight, s, zp, qmin, qmax, torch.int8, group_size,
                )
                quantized_embedding.weight = q_weight
                quantized_embedding.scales = s
@@ -128,7 +193,7 @@ def _convert_helper(self, module: torch.nn.Module):
                self._convert_helper(child)


- class Int4WeightOnlyQATEmbedding(torch.nn.Embedding):
+ class Int4WeightOnlyQATEmbedding(FakeQuantizedEmbedding):
    """
    This module implements an embedding layer with int4 fake quantized
    grouped per channel weights.
@@ -141,47 +206,42 @@ class Int4WeightOnlyQATEmbedding(torch.nn.Embedding):

    def __init__(
        self,
+         num_embeddings: int,
+         embedding_dim: int,
+         padding_idx: Optional[int] = None,
+         max_norm: Optional[float] = None,
+         norm_type: float = 2.0,
+         scale_grad_by_freq: bool = False,
+         sparse: bool = False,
        group_size: int = 32,
        scale_precision: torch.dtype = torch.float32,
        zero_point_precision: torch.dtype = torch.int32,
        *args,
        **kwargs,
    ):
-         super().__init__(*args, **kwargs)
-         self.bit_width = 4
-         self.group_size = group_size
-         self.scale_precision = scale_precision
-         self.zero_point_precision = zero_point_precision
-         self._fake_quant_enabled = True
-
-     def forward(self, x):
-         weight = self.weight
-
-         if self._fake_quant_enabled:
-             (weight_scales, weight_zp) = get_group_qparams_symmetric(
-                 self.weight, self.bit_width, self.group_size, self.scale_precision,
-             )
-             # TODO: pass zp dtype to `get_group_qparams_symmetric` instead
-             weight_zp = weight_zp.to(self.zero_point_precision)
-             (weight_qmin, weight_qmax) = _get_qmin_qmax(self.bit_width)
-             w_fq = _fake_quantize_per_channel_group(
-                 self.weight,
-                 weight_scales,
-                 weight_zp,
-                 weight_qmin,
-                 weight_qmax,
-                 self.group_size,
-             )
-         else:
-             w_fq = self.weight
-
-         return F.embedding(
-             x, w_fq, self.padding_idx, self.max_norm,
-             self.norm_type, self.scale_grad_by_freq, self.sparse,
+         weight_config = FakeQuantizeConfig(
+             dtype=TorchAODType.INT4,
+             group_size=group_size,
+             is_symmetric=True,
+             is_dynamic=True,
+             scale_precision=scale_precision,
+             zero_point_precision=zero_point_precision,
+         )
+         super().__init__(
+             num_embeddings,
+             embedding_dim,
+             padding_idx,
+             max_norm,
+             norm_type,
+             scale_grad_by_freq,
+             sparse,
+             weight_config,
+             *args,
+             **kwargs,
        )

    def enable_fake_quant(self, enabled: bool = True):
-         self._fake_quant_enabled = enabled
+         self.weight_fake_quantizer.enabled = enabled

    def disable_fake_quant(self):
        self.enable_fake_quant(False)
@@ -194,25 +254,21 @@ class Int4WeightOnlyEmbedding(torch.nn.Module):
    """
    def __init__(
        self,
-         group_size: int,
-         scale_precision: torch.dtype,
-         zero_point_precision: torch.dtype,
-
-         # nn.Embedding args
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
+         group_size: int = 32,
+         scale_precision: torch.dtype = torch.float32,
+         zero_point_precision: torch.dtype = torch.int32,
        device: torch.device = None,
    ):
        super().__init__()
-         self.bit_width = 4
-         self.group_size = group_size
-         self.scale_precision = scale_precision
-         self.zero_point_precision = zero_point_precision
+
+         # nn.Embedding args
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
@@ -221,6 +277,11 @@ def __init__(
        self.scale_grad_by_freq = scale_grad_by_freq
        self.sparse = sparse

+         # quantization args
+         self.group_size = group_size
+         self.scale_precision = scale_precision
+         self.zero_point_precision = zero_point_precision
+
        # currently storing unpacked int8 weights
        self.register_buffer(
            "weight",
@@ -245,7 +306,7 @@ def __init__(

    def forward(self, x):
        from torchao._executorch_ops import _quantized_decomposed_dequantize_per_channel_group_wrapper
-         qmin, qmax = _get_qmin_qmax(self.bit_width)
+         qmin, qmax = _get_qmin_qmax(4)
        w_dq = _quantized_decomposed_dequantize_per_channel_group_wrapper(
            self.weight,
            self.scale,
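
A minimal usage sketch of the refactored Int4WeightOnlyQATEmbedding. The import path below is an assumption for illustration and is not shown in this diff; the class name, constructor arguments, and the fake-quant toggles are taken from the diff itself.

    import torch

    # Assumed import path for illustration only; adjust to where this
    # module lives in your torchao checkout.
    from torchao.quantization.prototype.qat.embedding import (
        Int4WeightOnlyQATEmbedding,
    )

    # Int4 QAT embedding: weights are fake quantized per channel group.
    fq_embedding = Int4WeightOnlyQATEmbedding(
        num_embeddings=128,
        embedding_dim=64,
        group_size=32,                     # size of each weight quantization group
        scale_precision=torch.float32,     # dtype of the per-group scales
        zero_point_precision=torch.int32,  # dtype of the per-group zero points
    )

    # The forward pass fake quantizes the weights before the embedding lookup.
    tokens = torch.randint(0, 128, (4, 16))
    out = fq_embedding(tokens)

    # Fake quantization can be toggled, e.g. to compare against a float baseline.
    fq_embedding.disable_fake_quant()
    out_float = fq_embedding(tokens)
    fq_embedding.enable_fake_quant()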