
Commit e88f00e

Add generic fake quantized embedding for QAT
Summary: This is equivalent to #1020 but for nn.Embedding. This commit adds a generic fake quantized embedding module to replace the uses of the existing, more specific QAT embeddings. For example, `Int4WeightOnlyQATEmbedding` can be expressed as follows:

```
from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
from torchao.quantization.prototype.qat.embedding import FakeQuantizedEmbedding

weight_config = FakeQuantizeConfig(
    dtype=torch.int4,
    group_size=group_size,
    is_symmetric=True,
)
fq_embedding = FakeQuantizedEmbedding(16, 32, weight_config=weight_config)
```

Test Plan:
python test/quantization/test_qat.py -k test_qat_4w_embedding
python test/quantization/test_qat.py -k test_fake_quantized_embedding_4w
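For readers new to QAT, here is a minimal sketch of what the weight fake quantization above amounts to. It mirrors the baseline used in the new test below; the import path for `get_group_qparams_symmetric` is an assumption, while the underscore-prefixed helpers are the same ones this commit's embedding module already imports from `.utils`.

```
import torch
import torch.nn.functional as F

# Assumed import location for the qparam helper; the underscore-prefixed helpers
# correspond to torchao/quantization/prototype/qat/utils.py used in the diff below.
from torchao.quantization.utils import get_group_qparams_symmetric
from torchao.quantization.prototype.qat.utils import (
    _fake_quantize_per_channel_group,
    _get_qmin_qmax,
)

def int4_fake_quantized_embedding(x: torch.Tensor, weight: torch.Tensor, group_size: int = 32) -> torch.Tensor:
    # Per-group symmetric scales / zero points for 4-bit weights
    s, zp = get_group_qparams_symmetric(weight, 4, group_size, torch.float32)
    zp = zp.to(torch.int32)
    qmin, qmax = _get_qmin_qmax(4)
    # "Fake quantize" = quantize then dequantize: the weight stays in float
    # but carries the int4 rounding error seen during QAT
    w_fq = _fake_quantize_per_channel_group(weight, s, zp, qmin, qmax, group_size)
    return F.embedding(x, w_fq)
```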
1 parent 48bc81c commit e88f00e

File tree: 2 files changed, +154 −56 lines

  test/quantization/test_qat.py
  torchao/quantization/prototype/qat/embedding.py

test/quantization/test_qat.py

Lines changed: 37 additions & 0 deletions
@@ -29,6 +29,9 @@
 from torchao.quantization.prototype.qat.fake_quantizer import (
     FakeQuantizer,
 )
+from torchao.quantization.prototype.qat.embedding import (
+    FakeQuantizedEmbedding,
+)
 from torchao.quantization.prototype.qat.linear import (
     FakeQuantizedLinear,
 )
@@ -852,6 +855,40 @@ def linear_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
         baseline_out = linear_forward_4w(x2, fq_linear.weight)
         torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0)

+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
+    def test_fake_quantized_embedding_4w(self):
+        """
+        Test that we can express int4 per group symmetric weight only fake quantization
+        with `FakeQuantizedEmbedding`.
+        """
+        num_embeddings = 64
+        embedding_dim = 128
+        group_size = 32
+        torch.manual_seed(self.SEED)
+        fq_embedding = FakeQuantizedEmbedding(
+            num_embeddings,
+            embedding_dim,
+            weight_config=FakeQuantizeConfig(TorchAODType.INT4, group_size=group_size),
+        )
+
+        def embedding_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
+            """
+            Baseline for int4 per group symmetric weight only fake quantization.
+            """
+            (s, zp) = get_group_qparams_symmetric(weight, 4, group_size, torch.float32)
+            zp = zp.to(torch.int32)
+            (qmin, qmax) = _get_qmin_qmax(4)
+            w_fq = _fake_quantize_per_channel_group(weight, s, zp, qmin, qmax, group_size)
+            return F.embedding(x, w_fq)
+
+        # Compare embedding values
+        torch.manual_seed(self.SEED)
+        x = torch.randint(num_embeddings, (5, 10))
+        x2 = copy.deepcopy(x)
+        fq_out = fq_embedding(x)
+        baseline_out = embedding_forward_4w(x2, fq_embedding.weight)
+        torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0)
+

 if __name__ == "__main__":
     unittest.main()

torchao/quantization/prototype/qat/embedding.py

Lines changed: 117 additions & 56 deletions
@@ -14,12 +14,73 @@
 from torchao.quantization.quant_api import (
     _replace_with_custom_fn_if_matches_filter,
 )
+from torchao.quantization.quant_primitives import TorchAODType
+from .api import FakeQuantizeConfig
+from .fake_quantizer import FakeQuantizer
 from .utils import (
     _fake_quantize_per_channel_group,
     _get_qmin_qmax,
 )


+class FakeQuantizedEmbedding(torch.nn.Embedding):
+    """
+    General embedding layer with fake quantized weights.
+
+    Specific target dtypes, granularity, schemes etc. are specified
+    through separate configs for weights and activations.
+
+    Example usage::
+
+        weight_config = FakeQuantizeConfig(
+            dtype=torch.int4,
+            group_size=8,
+            symmetric=True,
+        )
+        fq_embedding = FakeQuantizedEmbedding(5, 10, weight_config)
+        fq_embedding(torch.LongTensor([3]))
+    """
+
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        padding_idx: Optional[int] = None,
+        max_norm: Optional[float] = None,
+        norm_type: float = 2.0,
+        scale_grad_by_freq: bool = False,
+        sparse: bool = False,
+        weight_config: Optional[FakeQuantizeConfig] = None,
+        *args,
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            num_embeddings,
+            embedding_dim,
+            padding_idx,
+            max_norm,
+            norm_type,
+            scale_grad_by_freq,
+            sparse,
+            *args,
+            **kwargs,
+        )
+        if weight_config is not None:
+            self.weight_fake_quantizer = FakeQuantizer(weight_config)
+        else:
+            self.weight_fake_quantizer = None
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.weight_fake_quantizer is not None:
+            w = self.weight_fake_quantizer(self.weight)
+        else:
+            w = self.weight
+        return F.embedding(
+            x, w, self.padding_idx, self.max_norm,
+            self.norm_type, self.scale_grad_by_freq, self.sparse,
+        )
+
+
 # ======================================
 # |   Embedding int4 weight-only QAT   |
 # ======================================
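Aside (not part of the diff): since `weight_config` defaults to None and `forward` falls back to the raw weight in that case, constructing `FakeQuantizedEmbedding` without a config behaves like a plain `nn.Embedding`. A small sketch assuming the module added above:

```
import torch
from torchao.quantization.prototype.qat.embedding import FakeQuantizedEmbedding

# No weight_config: weight_fake_quantizer is None, so forward uses self.weight as-is
emb = FakeQuantizedEmbedding(16, 32)
out = emb(torch.randint(16, (2, 4)))  # matches a plain nn.Embedding(16, 32) forward
```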
@@ -37,10 +98,9 @@ def __init__(
         zero_point_precision: torch.dtype = torch.int32,
     ) -> None:
         super().__init__()
-        self.bit_width = 4
         self.group_size: int = group_size
         self.scale_precision: torch.dtype = scale_precision
-        self.zero_point_precision: torch.dtype = zero_point_precision,
+        self.zero_point_precision: torch.dtype = zero_point_precision

     def prepare(
         self,
@@ -56,16 +116,18 @@ def filter_fn(child: torch.nn.Module, cur_fqn:str) -> bool:

         def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
             new_embedding = Int4WeightOnlyQATEmbedding(
-                group_size=self.group_size,
-
-                # other nn.Embedding args
+                # nn.Embedding args
                 num_embeddings=child.num_embeddings,
                 embedding_dim=child.embedding_dim,
                 padding_idx=child.padding_idx,
                 max_norm=child.max_norm,
                 norm_type=child.norm_type,
                 scale_grad_by_freq=child.scale_grad_by_freq,
                 sparse=child.sparse,
+                # quantization args
+                group_size=self.group_size,
+                scale_precision=self.scale_precision,
+                zero_point_precision=self.zero_point_precision,
                 device=child.weight.device,
             )
             # In distributed training, the model may be instantiated
@@ -98,28 +160,31 @@ def _convert_helper(self, module: torch.nn.Module):
         from torchao._executorch_ops import _quantized_decomposed_quantize_per_channel_group_wrapper
         for name, child in module.named_children():
             if isinstance(child, Int4WeightOnlyQATEmbedding):
+                group_size = child.weight_fake_quantizer.config.group_size
+                scale_precision = child.weight_fake_quantizer.config.scale_precision
+                zero_point_precision = child.weight_fake_quantizer.config.zero_point_precision
                 quantized_embedding = Int4WeightOnlyEmbedding(
-                    group_size=child.group_size,
-                    scale_precision=child.scale_precision,
-                    zero_point_precision=child.zero_point_precision,
-
-                    # other nn.Embedding args
+                    # nn.Embedding args
                     num_embeddings=child.num_embeddings,
                     embedding_dim=child.embedding_dim,
                     padding_idx=child.padding_idx,
                     max_norm=child.max_norm,
                     norm_type=child.norm_type,
                     scale_grad_by_freq=child.scale_grad_by_freq,
                     sparse=child.sparse,
+                    # quantization args
+                    group_size=group_size,
+                    scale_precision=scale_precision,
+                    zero_point_precision=zero_point_precision,
                     device=child.weight.device,
                 )
                 setattr(module, name, quantized_embedding)

                 # Load weights and qparams into quantized embedding
-                (qmin, qmax) = _get_qmin_qmax(self.bit_width)
-                (s, zp) = get_group_qparams_symmetric(child.weight, self.bit_width, child.group_size)
+                (qmin, qmax) = _get_qmin_qmax(4)
+                (s, zp) = get_group_qparams_symmetric(child.weight, 4, group_size)
                 q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
-                    child.weight, s, zp, qmin, qmax, torch.int8, child.group_size,
+                    child.weight, s, zp, qmin, qmax, torch.int8, group_size,
                 )
                 quantized_embedding.weight = q_weight
                 quantized_embedding.scales = s
@@ -128,7 +193,7 @@ def _convert_helper(self, module: torch.nn.Module):
                 self._convert_helper(child)


-class Int4WeightOnlyQATEmbedding(torch.nn.Embedding):
+class Int4WeightOnlyQATEmbedding(FakeQuantizedEmbedding):
     """
     This module implements a embedding layer with int4 fake quantized
     grouped per channel weights.
@@ -141,47 +206,42 @@ class Int4WeightOnlyQATEmbedding(torch.nn.Embedding):

     def __init__(
         self,
+        num_embeddings: int,
+        embedding_dim: int,
+        padding_idx: Optional[int] = None,
+        max_norm: Optional[float] = None,
+        norm_type: float = 2.0,
+        scale_grad_by_freq: bool = False,
+        sparse: bool = False,
         group_size: int = 32,
         scale_precision: torch.dtype = torch.float32,
         zero_point_precision: torch.dtype = torch.int32,
         *args,
         **kwargs,
     ):
-        super().__init__(*args, **kwargs)
-        self.bit_width = 4
-        self.group_size = group_size
-        self.scale_precision = scale_precision
-        self.zero_point_precision = zero_point_precision
-        self._fake_quant_enabled = True
-
-    def forward(self, x):
-        weight = self.weight
-
-        if self._fake_quant_enabled:
-            (weight_scales, weight_zp) = get_group_qparams_symmetric(
-                self.weight, self.bit_width, self.group_size, self.scale_precision,
-            )
-            # TODO: pass zp dtype to `get_group_qparams_symmetric` instead
-            weight_zp = weight_zp.to(self.zero_point_precision)
-            (weight_qmin, weight_qmax) = _get_qmin_qmax(self.bit_width)
-            w_fq = _fake_quantize_per_channel_group(
-                self.weight,
-                weight_scales,
-                weight_zp,
-                weight_qmin,
-                weight_qmax,
-                self.group_size,
-            )
-        else:
-            w_fq = self.weight
-
-        return F.embedding(
-            x, w_fq, self.padding_idx, self.max_norm,
-            self.norm_type, self.scale_grad_by_freq, self.sparse,
+        weight_config = FakeQuantizeConfig(
+            dtype=TorchAODType.INT4,
+            group_size=group_size,
+            is_symmetric=True,
+            is_dynamic=True,
+            scale_precision=scale_precision,
+            zero_point_precision=zero_point_precision,
+        )
+        super().__init__(
+            num_embeddings,
+            embedding_dim,
+            padding_idx,
+            max_norm,
+            norm_type,
+            scale_grad_by_freq,
+            sparse,
+            weight_config,
+            *args,
+            **kwargs,
         )

     def enable_fake_quant(self, enabled: bool = True):
-        self._fake_quant_enabled = enabled
+        self.weight_fake_quantizer.enabled = enabled

     def disable_fake_quant(self):
         self.enable_fake_quant(False)
@@ -194,25 +254,21 @@ class Int4WeightOnlyEmbedding(torch.nn.Module):
     """
     def __init__(
         self,
-        group_size: int,
-        scale_precision: torch.dtype,
-        zero_point_precision: torch.dtype,
-
-        # nn.Embedding args
         num_embeddings: int,
         embedding_dim: int,
         padding_idx: Optional[int] = None,
         max_norm: Optional[float] = None,
         norm_type: float = 2.0,
         scale_grad_by_freq: bool = False,
         sparse: bool = False,
+        group_size: int = 32,
+        scale_precision: torch.dtype = torch.float32,
+        zero_point_precision: torch.dtype = torch.int32,
         device: torch.device = None,
     ):
         super().__init__()
-        self.bit_width = 4
-        self.group_size = group_size
-        self.scale_precision = scale_precision
-        self.zero_point_precision = zero_point_precision
+
+        # nn.Embedding args
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
         self.padding_idx = padding_idx
@@ -221,6 +277,11 @@ def __init__(
         self.scale_grad_by_freq = scale_grad_by_freq
         self.sparse = sparse

+        # quantization args
+        self.group_size = group_size
+        self.scale_precision = scale_precision
+        self.zero_point_precision = zero_point_precision
+
         # currently storing unpacked int8 weights
         self.register_buffer(
             "weight",
@@ -245,7 +306,7 @@ def __init__(

     def forward(self, x):
         from torchao._executorch_ops import _quantized_decomposed_dequantize_per_channel_group_wrapper
-        qmin, qmax = _get_qmin_qmax(self.bit_width)
+        qmin, qmax = _get_qmin_qmax(4)
         w_dq = _quantized_decomposed_dequantize_per_channel_group_wrapper(
             self.weight,
             self.scale,
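Note on usage (not part of this diff): the quantization args now trail the standard nn.Embedding args on both `Int4WeightOnlyQATEmbedding` and `Int4WeightOnlyEmbedding`, so they are best passed by keyword, as the quantizer code above does. A minimal end-to-end sketch, assuming the quantizer class in this file is torchao's `Int4WeightOnlyEmbeddingQATQuantizer` and that it exposes the usual prepare/convert pair (`prepare` and `_convert_helper` appear in the hunks above; a public `convert` is assumed):

```
import torch
from torchao.quantization.prototype.qat.embedding import (
    Int4WeightOnlyEmbeddingQATQuantizer,  # assumed name of the quantizer in this file
)

model = torch.nn.Sequential(torch.nn.Embedding(64, 128))
quantizer = Int4WeightOnlyEmbeddingQATQuantizer(group_size=32)

# prepare: swaps nn.Embedding -> Int4WeightOnlyQATEmbedding (float weights,
# fake quantized on every forward)
model = quantizer.prepare(model)

# ... QAT fine-tuning loop ...

# convert: swaps Int4WeightOnlyQATEmbedding -> Int4WeightOnlyEmbedding
# (weights quantized to the int4 range and stored as unpacked int8)
model = quantizer.convert(model)
```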
