
Commit 0b71b8d

Add generic fake quantized embedding for QAT (#1085)
Summary: This is equivalent to #1020 but for nn.Embedding. This commit adds a generic fake quantized embedding module to replace the uses of the existing, more specific QAT embeddings. For example, `Int4WeightOnlyQATEmbedding` can be expressed as follows:

```
from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
from torchao.quantization.prototype.qat.embedding import FakeQuantizedEmbedding

weight_config = FakeQuantizeConfig(
    dtype=torch.int4,
    group_size=group_size,
    is_symmetric=True,
)
fq_embedding = FakeQuantizedEmbedding(16, 32, weight_config=weight_config)
```

Test Plan:

python test/quantization/test_qat.py -k test_qat_4w_embedding
python test/quantization/test_qat.py -k test_fake_quantized_embedding_4w
1 parent 7a35695 commit 0b71b8d
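
To make the equivalence concrete, here is a short usage sketch that extends the example from the summary. The construction mirrors the snippet above; the forward call and shape check are illustrative additions, assuming the config keywords shown in the summary are accepted.

```
import torch

from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
from torchao.quantization.prototype.qat.embedding import FakeQuantizedEmbedding

# int4, per-group symmetric, weight-only fake quantization, as in the summary
weight_config = FakeQuantizeConfig(
    dtype=torch.int4,
    group_size=32,
    is_symmetric=True,
)
fq_embedding = FakeQuantizedEmbedding(16, 32, weight_config=weight_config)

# The module behaves like a regular nn.Embedding; only the weight is
# fake quantized on each forward pass.
indices = torch.randint(16, (2, 8))
out = fq_embedding(indices)
print(out.shape)  # torch.Size([2, 8, 32])
```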

File tree

2 files changed: +153 -53 lines changed


test/quantization/test_qat.py

Lines changed: 37 additions & 0 deletions
@@ -29,6 +29,9 @@
 from torchao.quantization.prototype.qat.fake_quantizer import (
     FakeQuantizer,
 )
+from torchao.quantization.prototype.qat.embedding import (
+    FakeQuantizedEmbedding,
+)
 from torchao.quantization.prototype.qat.linear import (
     FakeQuantizedLinear,
 )
@@ -852,6 +855,40 @@ def linear_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
         baseline_out = linear_forward_4w(x2, fq_linear.weight)
         torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0)

+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
+    def test_fake_quantized_embedding_4w(self):
+        """
+        Test that we can express int4 per group symmetric weight only fake quantization
+        with `FakeQuantizedEmbedding`.
+        """
+        num_embeddings = 64
+        embedding_dim = 128
+        group_size = 32
+        torch.manual_seed(self.SEED)
+        fq_embedding = FakeQuantizedEmbedding(
+            num_embeddings,
+            embedding_dim,
+            weight_config=FakeQuantizeConfig(TorchAODType.INT4, group_size=group_size),
+        )
+
+        def embedding_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
+            """
+            Baseline for int4 per group symmetric weight only fake quantization.
+            """
+            (s, zp) = get_group_qparams_symmetric(weight, 4, group_size, torch.float32)
+            zp = zp.to(torch.int32)
+            (qmin, qmax) = _get_qmin_qmax(4)
+            w_fq = _fake_quantize_per_channel_group(weight, s, zp, qmin, qmax, group_size)
+            return F.embedding(x, w_fq)
+
+        # Compare embedding values
+        torch.manual_seed(self.SEED)
+        x = torch.randint(num_embeddings, (5, 10))
+        x2 = copy.deepcopy(x)
+        fq_out = fq_embedding(x)
+        baseline_out = embedding_forward_4w(x2, fq_embedding.weight)
+        torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0)
+

 if __name__ == "__main__":
     unittest.main()

torchao/quantization/prototype/qat/embedding.py

Lines changed: 116 additions & 53 deletions
@@ -14,12 +14,73 @@
 from torchao.quantization.quant_api import (
     _replace_with_custom_fn_if_matches_filter,
 )
+from torchao.quantization.quant_primitives import TorchAODType
+from .api import FakeQuantizeConfig
+from .fake_quantizer import FakeQuantizer
 from .utils import (
     _fake_quantize_per_channel_group,
     _get_qmin_qmax,
 )


+class FakeQuantizedEmbedding(torch.nn.Embedding):
+    """
+    General embedding layer with fake quantized weights.
+
+    Specific target dtypes, granularity, schemes etc. are specified
+    through separate configs for weights and activations.
+
+    Example usage::
+
+        weight_config = FakeQuantizeConfig(
+            dtype=torch.int4,
+            group_size=8,
+            symmetric=True,
+        )
+        fq_embedding = FakeQuantizedEmbedding(5, 10, weight_config)
+        fq_embedding(torch.LongTensor([3]))
+    """
+
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        padding_idx: Optional[int] = None,
+        max_norm: Optional[float] = None,
+        norm_type: float = 2.0,
+        scale_grad_by_freq: bool = False,
+        sparse: bool = False,
+        weight_config: Optional[FakeQuantizeConfig] = None,
+        *args,
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            num_embeddings,
+            embedding_dim,
+            padding_idx,
+            max_norm,
+            norm_type,
+            scale_grad_by_freq,
+            sparse,
+            *args,
+            **kwargs,
+        )
+        if weight_config is not None:
+            self.weight_fake_quantizer = FakeQuantizer(weight_config)
+        else:
+            self.weight_fake_quantizer = None
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.weight_fake_quantizer is not None:
+            w = self.weight_fake_quantizer(self.weight)
+        else:
+            w = self.weight
+        return F.embedding(
+            x, w, self.padding_idx, self.max_norm,
+            self.norm_type, self.scale_grad_by_freq, self.sparse,
+        )
+
+
 # ======================================
 # | Embedding int4 weight-only QAT |
 # ======================================
@@ -40,7 +101,7 @@ def __init__(
         self.bit_width = 4
         self.group_size: int = group_size
         self.scale_precision: torch.dtype = scale_precision
-        self.zero_point_precision: torch.dtype = zero_point_precision,
+        self.zero_point_precision: torch.dtype = zero_point_precision

     def prepare(
         self,
@@ -56,16 +117,18 @@ def filter_fn(child: torch.nn.Module, cur_fqn:str) -> bool:

         def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
             new_embedding = Int4WeightOnlyQATEmbedding(
-                group_size=self.group_size,
-
-                # other nn.Embedding args
+                # nn.Embedding args
                 num_embeddings=child.num_embeddings,
                 embedding_dim=child.embedding_dim,
                 padding_idx=child.padding_idx,
                 max_norm=child.max_norm,
                 norm_type=child.norm_type,
                 scale_grad_by_freq=child.scale_grad_by_freq,
                 sparse=child.sparse,
+                # quantization args
+                group_size=self.group_size,
+                scale_precision=self.scale_precision,
+                zero_point_precision=self.zero_point_precision,
                 device=child.weight.device,
             )
             # In distributed training, the model may be instantiated
@@ -98,28 +161,31 @@ def _convert_helper(self, module: torch.nn.Module):
         from torchao._executorch_ops import _quantized_decomposed_quantize_per_channel_group_wrapper
         for name, child in module.named_children():
             if isinstance(child, Int4WeightOnlyQATEmbedding):
+                group_size = child.weight_fake_quantizer.config.group_size
+                scale_precision = child.weight_fake_quantizer.config.scale_precision
+                zero_point_precision = child.weight_fake_quantizer.config.zero_point_precision
                 quantized_embedding = Int4WeightOnlyEmbedding(
-                    group_size=child.group_size,
-                    scale_precision=child.scale_precision,
-                    zero_point_precision=child.zero_point_precision,
-
-                    # other nn.Embedding args
+                    # nn.Embedding args
                     num_embeddings=child.num_embeddings,
                     embedding_dim=child.embedding_dim,
                     padding_idx=child.padding_idx,
                     max_norm=child.max_norm,
                     norm_type=child.norm_type,
                     scale_grad_by_freq=child.scale_grad_by_freq,
                     sparse=child.sparse,
+                    # quantization args
+                    group_size=group_size,
+                    scale_precision=scale_precision,
+                    zero_point_precision=zero_point_precision,
                     device=child.weight.device,
                 )
                 setattr(module, name, quantized_embedding)

                 # Load weights and qparams into quantized embedding
                 (qmin, qmax) = _get_qmin_qmax(self.bit_width)
-                (s, zp) = get_group_qparams_symmetric(child.weight, self.bit_width, child.group_size)
+                (s, zp) = get_group_qparams_symmetric(child.weight, self.bit_width, group_size)
                 q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
-                    child.weight, s, zp, qmin, qmax, torch.int8, child.group_size,
+                    child.weight, s, zp, qmin, qmax, torch.int8, group_size,
                 )
                 quantized_embedding.weight = q_weight
                 quantized_embedding.scales = s
@@ -128,7 +194,7 @@ def _convert_helper(self, module: torch.nn.Module):
                 self._convert_helper(child)


-class Int4WeightOnlyQATEmbedding(torch.nn.Embedding):
+class Int4WeightOnlyQATEmbedding(FakeQuantizedEmbedding):
     """
     This module implements a embedding layer with int4 fake quantized
     grouped per channel weights.
@@ -141,47 +207,42 @@ class Int4WeightOnlyQATEmbedding(torch.nn.Embedding):

     def __init__(
         self,
+        num_embeddings: int,
+        embedding_dim: int,
+        padding_idx: Optional[int] = None,
+        max_norm: Optional[float] = None,
+        norm_type: float = 2.0,
+        scale_grad_by_freq: bool = False,
+        sparse: bool = False,
         group_size: int = 32,
         scale_precision: torch.dtype = torch.float32,
         zero_point_precision: torch.dtype = torch.int32,
         *args,
         **kwargs,
     ):
-        super().__init__(*args, **kwargs)
-        self.bit_width = 4
-        self.group_size = group_size
-        self.scale_precision = scale_precision
-        self.zero_point_precision = zero_point_precision
-        self._fake_quant_enabled = True
-
-    def forward(self, x):
-        weight = self.weight
-
-        if self._fake_quant_enabled:
-            (weight_scales, weight_zp) = get_group_qparams_symmetric(
-                self.weight, self.bit_width, self.group_size, self.scale_precision,
-            )
-            # TODO: pass zp dtype to `get_group_qparams_symmetric` instead
-            weight_zp = weight_zp.to(self.zero_point_precision)
-            (weight_qmin, weight_qmax) = _get_qmin_qmax(self.bit_width)
-            w_fq = _fake_quantize_per_channel_group(
-                self.weight,
-                weight_scales,
-                weight_zp,
-                weight_qmin,
-                weight_qmax,
-                self.group_size,
-            )
-        else:
-            w_fq = self.weight
-
-        return F.embedding(
-            x, w_fq, self.padding_idx, self.max_norm,
-            self.norm_type, self.scale_grad_by_freq, self.sparse,
+        weight_config = FakeQuantizeConfig(
+            dtype=TorchAODType.INT4,
+            group_size=group_size,
+            is_symmetric=True,
+            is_dynamic=True,
+            scale_precision=scale_precision,
+            zero_point_precision=zero_point_precision,
+        )
+        super().__init__(
+            num_embeddings,
+            embedding_dim,
+            padding_idx,
+            max_norm,
+            norm_type,
+            scale_grad_by_freq,
+            sparse,
+            weight_config,
+            *args,
+            **kwargs,
         )

     def enable_fake_quant(self, enabled: bool = True):
-        self._fake_quant_enabled = enabled
+        self.weight_fake_quantizer.enabled = enabled

     def disable_fake_quant(self):
         self.enable_fake_quant(False)
@@ -194,25 +255,21 @@ class Int4WeightOnlyEmbedding(torch.nn.Module):
     """
     def __init__(
         self,
-        group_size: int,
-        scale_precision: torch.dtype,
-        zero_point_precision: torch.dtype,
-
-        # nn.Embedding args
         num_embeddings: int,
         embedding_dim: int,
         padding_idx: Optional[int] = None,
         max_norm: Optional[float] = None,
         norm_type: float = 2.0,
         scale_grad_by_freq: bool = False,
         sparse: bool = False,
+        group_size: int = 32,
+        scale_precision: torch.dtype = torch.float32,
+        zero_point_precision: torch.dtype = torch.int32,
         device: torch.device = None,
     ):
         super().__init__()
-        self.bit_width = 4
-        self.group_size = group_size
-        self.scale_precision = scale_precision
-        self.zero_point_precision = zero_point_precision
+
+        # nn.Embedding args
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
         self.padding_idx = padding_idx
@@ -221,6 +278,12 @@ def __init__(
         self.scale_grad_by_freq = scale_grad_by_freq
         self.sparse = sparse

+        # quantization args
+        self.bit_width = 4
+        self.group_size = group_size
+        self.scale_precision = scale_precision
+        self.zero_point_precision = zero_point_precision
+
         # currently storing unpacked int8 weights
         self.register_buffer(
             "weight",

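For completeness, a quick sketch of how the refactored `Int4WeightOnlyQATEmbedding` can be exercised directly, including the fake-quant toggle that now delegates to the underlying `FakeQuantizer`. This is illustrative only and assumes, as the new `enable_fake_quant` implementation implies, that the fake quantizer becomes a pass-through once its `enabled` flag is cleared.

```
import torch

from torchao.quantization.prototype.qat.embedding import Int4WeightOnlyQATEmbedding

embedding = Int4WeightOnlyQATEmbedding(
    num_embeddings=64,
    embedding_dim=128,
    group_size=32,
)
x = torch.randint(64, (4, 16))

qat_out = embedding(x)          # lookup against int4 fake quantized weights
embedding.disable_fake_quant()  # sets weight_fake_quantizer.enabled = False
plain_out = embedding(x)        # expected to match a plain embedding lookup
print(torch.equal(plain_out, embedding.weight[x]))
```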