 from torchao.quantization.quant_api import (
     _replace_with_custom_fn_if_matches_filter,
 )
+from torchao.quantization.quant_primitives import TorchAODType
+from .api import FakeQuantizeConfig
+from .fake_quantizer import FakeQuantizer
 from .utils import (
     _fake_quantize_per_channel_group,
     _get_qmin_qmax,
 )


+
+class FakeQuantizedEmbedding(torch.nn.Embedding):
+    """
+    General embedding layer with fake quantized weights.
+
+    Specific target dtypes, granularity, schemes etc. are specified
+    through the weight config.
+
+    Example usage::
+
+        weight_config = FakeQuantizeConfig(
+            dtype=torch.int4,
+            group_size=8,
+            symmetric=True,
+        )
+        fq_embedding = FakeQuantizedEmbedding(5, 16, weight_config=weight_config)
+        fq_embedding(torch.LongTensor([3]))
+    """
+
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        padding_idx: Optional[int] = None,
+        max_norm: Optional[float] = None,
+        norm_type: float = 2.0,
+        scale_grad_by_freq: bool = False,
+        sparse: bool = False,
+        weight_config: Optional[FakeQuantizeConfig] = None,
+        *args,
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            num_embeddings,
+            embedding_dim,
+            padding_idx,
+            max_norm,
+            norm_type,
+            scale_grad_by_freq,
+            sparse,
+            *args,
+            **kwargs,
+        )
+        if weight_config is not None:
+            self.weight_fake_quantizer = FakeQuantizer(weight_config)
+        else:
+            self.weight_fake_quantizer = None
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.weight_fake_quantizer is not None:
+            w = self.weight_fake_quantizer(self.weight)
+        else:
+            w = self.weight
+        return F.embedding(
+            x, w, self.padding_idx, self.max_norm,
+            self.norm_type, self.scale_grad_by_freq, self.sparse,
+        )
+
+
 # ======================================
 # |   Embedding int4 weight-only QAT   |
 # ======================================
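
A minimal usage sketch for the new `FakeQuantizedEmbedding`, following the docstring above. The import paths are assumptions (the classes live in the file being patched; adjust to wherever this module is exposed), and the config arguments are the ones shown in the docstring:

```python
import torch

# Assumed import paths -- the classes are defined in the file patched above.
from torchao.quantization.qat.api import FakeQuantizeConfig
from torchao.quantization.qat.embedding import FakeQuantizedEmbedding

# int4 symmetric per-group fake quantization of the embedding weights.
weight_config = FakeQuantizeConfig(
    dtype=torch.int4,
    group_size=8,
    symmetric=True,
)

fq_embedding = FakeQuantizedEmbedding(
    num_embeddings=10,
    embedding_dim=16,   # divisible by group_size
    weight_config=weight_config,
)

# Weights are fake quantized before the lookup, so the output stays float
# and gradients still flow to the original fp32 weight.
out = fq_embedding(torch.LongTensor([[1, 2, 4]]))
print(out.shape)  # torch.Size([1, 3, 16])
```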
@@ -40,7 +101,7 @@ def __init__(
         self.bit_width = 4
         self.group_size: int = group_size
         self.scale_precision: torch.dtype = scale_precision
-        self.zero_point_precision: torch.dtype = zero_point_precision,
+        self.zero_point_precision: torch.dtype = zero_point_precision

     def prepare(
         self,
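
The one-character fix in the hunk above is worth calling out: the stray trailing comma made `self.zero_point_precision` a one-element tuple rather than a `torch.dtype`. A quick illustration of the pitfall:

```python
import torch

zero_point_precision = torch.int32

buggy = zero_point_precision,   # trailing comma builds a tuple
fixed = zero_point_precision

print(type(buggy))  # <class 'tuple'>
print(type(fixed))  # <class 'torch.dtype'>
```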
@@ -56,16 +117,18 @@ def filter_fn(child: torch.nn.Module, cur_fqn:str) -> bool:

         def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
             new_embedding = Int4WeightOnlyQATEmbedding(
-                group_size=self.group_size,
-
-                # other nn.Embedding args
+                # nn.Embedding args
                 num_embeddings=child.num_embeddings,
                 embedding_dim=child.embedding_dim,
                 padding_idx=child.padding_idx,
                 max_norm=child.max_norm,
                 norm_type=child.norm_type,
                 scale_grad_by_freq=child.scale_grad_by_freq,
                 sparse=child.sparse,
+                # quantization args
+                group_size=self.group_size,
+                scale_precision=self.scale_precision,
+                zero_point_precision=self.zero_point_precision,
                 device=child.weight.device,
             )
             # In distributed training, the model may be instantiated
@@ -98,28 +161,31 @@ def _convert_helper(self, module: torch.nn.Module):
         from torchao._executorch_ops import _quantized_decomposed_quantize_per_channel_group_wrapper
         for name, child in module.named_children():
             if isinstance(child, Int4WeightOnlyQATEmbedding):
+                group_size = child.weight_fake_quantizer.config.group_size
+                scale_precision = child.weight_fake_quantizer.config.scale_precision
+                zero_point_precision = child.weight_fake_quantizer.config.zero_point_precision
                 quantized_embedding = Int4WeightOnlyEmbedding(
-                    group_size=child.group_size,
-                    scale_precision=child.scale_precision,
-                    zero_point_precision=child.zero_point_precision,
-
-                    # other nn.Embedding args
+                    # nn.Embedding args
                     num_embeddings=child.num_embeddings,
                     embedding_dim=child.embedding_dim,
                     padding_idx=child.padding_idx,
                     max_norm=child.max_norm,
                     norm_type=child.norm_type,
                     scale_grad_by_freq=child.scale_grad_by_freq,
                     sparse=child.sparse,
+                    # quantization args
+                    group_size=group_size,
+                    scale_precision=scale_precision,
+                    zero_point_precision=zero_point_precision,
                     device=child.weight.device,
                 )
                 setattr(module, name, quantized_embedding)

                 # Load weights and qparams into quantized embedding
                 (qmin, qmax) = _get_qmin_qmax(self.bit_width)
-                (s, zp) = get_group_qparams_symmetric(child.weight, self.bit_width, child.group_size)
+                (s, zp) = get_group_qparams_symmetric(child.weight, self.bit_width, group_size)
                 q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
-                    child.weight, s, zp, qmin, qmax, torch.int8, child.group_size,
+                    child.weight, s, zp, qmin, qmax, torch.int8, group_size,
                 )
                 quantized_embedding.weight = q_weight
                 quantized_embedding.scales = s
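
Because the QAT embedding no longer carries `group_size`, `scale_precision`, or `zero_point_precision` as direct attributes, the convert path above pulls them from `child.weight_fake_quantizer.config`. A sketch of that lookup, with the import path again an assumption:

```python
import torch
from torchao.quantization.qat.embedding import Int4WeightOnlyQATEmbedding  # assumed path

qat_emb = Int4WeightOnlyQATEmbedding(num_embeddings=16, embedding_dim=32, group_size=8)

# The quantization parameters now live on the fake quantizer's config,
# which is exactly where _convert_helper reads them during conversion.
cfg = qat_emb.weight_fake_quantizer.config
print(cfg.group_size)            # 8
print(cfg.scale_precision)       # torch.float32 (constructor default)
print(cfg.zero_point_precision)  # torch.int32 (constructor default)
```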
@@ -128,7 +194,7 @@ def _convert_helper(self, module: torch.nn.Module):
             self._convert_helper(child)


-class Int4WeightOnlyQATEmbedding(torch.nn.Embedding):
+class Int4WeightOnlyQATEmbedding(FakeQuantizedEmbedding):
     """
     This module implements an embedding layer with int4 fake quantized
     grouped per channel weights.
@@ -141,47 +207,42 @@ class Int4WeightOnlyQATEmbedding(torch.nn.Embedding):

     def __init__(
         self,
+        num_embeddings: int,
+        embedding_dim: int,
+        padding_idx: Optional[int] = None,
+        max_norm: Optional[float] = None,
+        norm_type: float = 2.0,
+        scale_grad_by_freq: bool = False,
+        sparse: bool = False,
         group_size: int = 32,
         scale_precision: torch.dtype = torch.float32,
         zero_point_precision: torch.dtype = torch.int32,
         *args,
         **kwargs,
     ):
-        super().__init__(*args, **kwargs)
-        self.bit_width = 4
-        self.group_size = group_size
-        self.scale_precision = scale_precision
-        self.zero_point_precision = zero_point_precision
-        self._fake_quant_enabled = True
-
-    def forward(self, x):
-        weight = self.weight
-
-        if self._fake_quant_enabled:
-            (weight_scales, weight_zp) = get_group_qparams_symmetric(
-                self.weight, self.bit_width, self.group_size, self.scale_precision,
-            )
-            # TODO: pass zp dtype to `get_group_qparams_symmetric` instead
-            weight_zp = weight_zp.to(self.zero_point_precision)
-            (weight_qmin, weight_qmax) = _get_qmin_qmax(self.bit_width)
-            w_fq = _fake_quantize_per_channel_group(
-                self.weight,
-                weight_scales,
-                weight_zp,
-                weight_qmin,
-                weight_qmax,
-                self.group_size,
-            )
-        else:
-            w_fq = self.weight
-
-        return F.embedding(
-            x, w_fq, self.padding_idx, self.max_norm,
-            self.norm_type, self.scale_grad_by_freq, self.sparse,
+        weight_config = FakeQuantizeConfig(
+            dtype=TorchAODType.INT4,
+            group_size=group_size,
+            is_symmetric=True,
+            is_dynamic=True,
+            scale_precision=scale_precision,
+            zero_point_precision=zero_point_precision,
+        )
+        super().__init__(
+            num_embeddings,
+            embedding_dim,
+            padding_idx,
+            max_norm,
+            norm_type,
+            scale_grad_by_freq,
+            sparse,
+            weight_config,
+            *args,
+            **kwargs,
         )

     def enable_fake_quant(self, enabled: bool = True):
-        self._fake_quant_enabled = enabled
+        self.weight_fake_quantizer.enabled = enabled

     def disable_fake_quant(self):
         self.enable_fake_quant(False)
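
With this refactor, `enable_fake_quant` and `disable_fake_quant` toggle the shared `FakeQuantizer` instead of a private `_fake_quant_enabled` flag. A brief sketch, assuming (as the refactor implies) that a disabled `FakeQuantizer` passes the weight through unchanged:

```python
import torch
from torchao.quantization.qat.embedding import Int4WeightOnlyQATEmbedding  # assumed path

qat_emb = Int4WeightOnlyQATEmbedding(num_embeddings=16, embedding_dim=32, group_size=8)
x = torch.LongTensor([[1, 2, 4]])

out_fq = qat_emb(x)                            # forward with fake quantization
qat_emb.disable_fake_quant()                   # sets weight_fake_quantizer.enabled = False
print(qat_emb.weight_fake_quantizer.enabled)   # False
out_float = qat_emb(x)                         # plain float lookup while disabled
qat_emb.enable_fake_quant()
print(qat_emb.weight_fake_quantizer.enabled)   # True
```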
@@ -194,25 +255,21 @@ class Int4WeightOnlyEmbedding(torch.nn.Module):
     """
     def __init__(
         self,
-        group_size: int,
-        scale_precision: torch.dtype,
-        zero_point_precision: torch.dtype,
-
-        # nn.Embedding args
         num_embeddings: int,
         embedding_dim: int,
         padding_idx: Optional[int] = None,
         max_norm: Optional[float] = None,
         norm_type: float = 2.0,
         scale_grad_by_freq: bool = False,
         sparse: bool = False,
+        group_size: int = 32,
+        scale_precision: torch.dtype = torch.float32,
+        zero_point_precision: torch.dtype = torch.int32,
         device: torch.device = None,
     ):
         super().__init__()
-        self.bit_width = 4
-        self.group_size = group_size
-        self.scale_precision = scale_precision
-        self.zero_point_precision = zero_point_precision
+
+        # nn.Embedding args
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
         self.padding_idx = padding_idx
@@ -221,6 +278,12 @@ def __init__(
         self.scale_grad_by_freq = scale_grad_by_freq
         self.sparse = sparse

+        # quantization args
+        self.bit_width = 4
+        self.group_size = group_size
+        self.scale_precision = scale_precision
+        self.zero_point_precision = zero_point_precision
+
         # currently storing unpacked int8 weights
         self.register_buffer(
             "weight",