Commit 8fd92bc

Authored by manuelcandales, committed by facebook-github-bot

4b embedding quantizer (#3135)

Summary: Pull Request resolved: #3135. 4b embedding quantizer.
Reviewed By: larryliu0820
Differential Revision: D56229021
fbshipit-source-id: 560911333b173b4d03c3c62769e6db4e2ab54c7b

1 parent 944dd4c, commit 8fd92bc

2 files changed (+94, -40 lines)

examples/models/llama2/export_llama_lib.py

Lines changed: 2 additions & 2 deletions
@@ -37,7 +37,7 @@
 from .builder import DType, LlamaEdgeManager, load_llama_model, WeightType
 from .quant_lib import _get_pt2e_quantization_params, get_pt2e_quantizers

-from .quantize import EmbeddingOnlyInt8QuantHandler, WeightOnlyInt8QuantHandler
+from .quantize import EmbeddingQuantHandler, WeightOnlyInt8QuantHandler


 IS_FBCODE = True  # os.environ.get("FBCODE_PLATFORM", False)
@@ -538,7 +538,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
         group_size = int(group_size)
         bitwidth = int(bitwidth)
         transforms.append(
-            lambda model: EmbeddingOnlyInt8QuantHandler(
+            lambda model: EmbeddingQuantHandler(
                 model, bitwidth=bitwidth, group_size=group_size
            ).quantized_model()
         )
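For reference, the export flow applies the handler as one of the model transforms above. A minimal standalone sketch of the same call pattern follows; the import path and the toy module are assumptions for illustration, and only the conversion is exercised here, since running the quantized forward additionally needs the quantized_decomposed ops that ExecuTorch registers.

import torch.nn as nn

from examples.models.llama2.quantize import EmbeddingQuantHandler  # import path assumed

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.tok_embeddings = nn.Embedding(128, 64)

model = TinyModel()
# 4-bit weights, groups of 32 per row, two values packed per byte.
model = EmbeddingQuantHandler(
    model, bitwidth=4, group_size=32, packed=True
).quantized_model()

print(model.tok_embeddings.weight.shape, model.tok_embeddings.weight.dtype)
# expected: torch.Size([128, 32]) torch.uint8 (embedding_dim halved by packing)
print(model.tok_embeddings.scales.shape)
# expected: torch.Size([128, 2]), one scale per group of 32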

examples/models/llama2/quantize.py

Lines changed: 92 additions & 38 deletions
@@ -124,6 +124,10 @@ def dynamically_quantize_per_channel(
     return quant, scales, zero_points


+#########################################################################
+### QuantHandler API definition ###
+
+
 class QuantHandler:
     def __init__(self, mod):
         self.mod = mod
@@ -134,8 +138,15 @@ def create_quantized_state_dict(self) -> Dict:  # "StateDict"
     def convert_for_runtime(self) -> nn.Module:
         pass

+    def quantized_model(self) -> nn.Module:
+        model_updated_state_dict = self.create_quantized_state_dict()
+        self.convert_for_runtime()
+        self.mod.load_state_dict(model_updated_state_dict)
+        return self.mod

-##### Weight-only int8 per-channel quantized code ######
+
+#########################################################################
+### Weight-only int8 per-channel quantized code ###


 def replace_linear_weight_only_int8_per_channel(module, node_type):
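The new quantized_model() on the base class captures the shared flow: build a quantized state dict, swap modules for their runtime counterparts, then load the state dict back. A do-nothing subclass (purely illustrative, not part of this change) makes the contract of the two hooks explicit:

import torch.nn as nn

class IdentityQuantHandler(QuantHandler):
    # Illustrative only: the two hooks that quantized_model() stitches together.
    def create_quantized_state_dict(self):
        # Real handlers return a state dict whose weights have been replaced by
        # quantized tensors plus auxiliary buffers such as scales.
        return self.mod.state_dict()

    def convert_for_runtime(self) -> nn.Module:
        # Real handlers swap float modules for quantized modules in place here.
        return self.mod

# usage: model = IdentityQuantHandler(model).quantized_model()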
@@ -153,16 +164,17 @@ def replace_linear_weight_only_int8_per_channel(module, node_type):
             setattr(
                 module,
                 name,
-                WeightOnlyInt8Linear(child.in_features, child.out_features),
+                WeightOnlyInt8Linear("cpu", child.in_features, child.out_features),
             )
         else:
             replace_linear_weight_only_int8_per_channel(child, node_type)


-class WeightOnlyInt8QuantHandler:
+class WeightOnlyInt8QuantHandler(QuantHandler):
     def __init__(
         self,
         mod,
+        device="cpu",
         *,
         node_type: str = "*",
         bitwidth: Optional[int] = None,
@@ -202,7 +214,7 @@ def create_quantized_state_dict(self) -> Dict:
                 )
             ):
                 print(
-                    f"quantize {self.node_type} {fqn, mod} with groupsize {self.group_size}, bitwidth {self.bitwidth}"
+                    f"quantize {self.node_type} {fqn, mod} with group_size {self.group_size}, bitwidth {self.bitwidth}"
                 )

                 # print(f"initial weight shape {mod.weight.shape}")
@@ -219,7 +231,7 @@ def create_quantized_state_dict(self) -> Dict:
                 )

                 cur_state_dict[f"{fqn}.weight"] = weight
-                # squeeze makes groupsize=rowsize unidimensional
+                # squeeze makes group_size=rowsize unidimensional
                 cur_state_dict[f"{fqn}.scales"] = scales.squeeze(dim=-1)

         return cur_state_dict
@@ -243,10 +255,10 @@ class WeightOnlyInt8Linear(torch.nn.Module):

     def __init__(
         self,
+        device,
         in_features: int,
         out_features: int,
         bias: bool = True,
-        device=None,
         dtype=None,
     ) -> None:
         super().__init__()
@@ -262,11 +274,12 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
        # return F.linear(input, self.weight.to(dtype=input.dtype)) * se...


-##### embedding table quantization ######
+#########################################################################
+##### embedding table quantization ######


 def replace_embedding_weight_only_grouped_int8_per_channel(
-    module, bitwidth: int = 8, group_size: Optional[int] = None
+    module, device, bitwidth: int = 8, group_size: Optional[int] = None, packed=False
 ):
     for name, child in module.named_children():
         # print(f"name: {name}")
@@ -277,25 +290,41 @@ def replace_embedding_weight_only_grouped_int8_per_channel(
                 module,
                 name,
                 QuantizedGroupEmbedding(
+                    device=device,
                     vocab_size=child.weight.shape[0],
                     embedding_dim=child.weight.shape[1],
                     group_size=group_size,
+                    packed=packed,
                 ),
             )
         else:
             replace_embedding_weight_only_grouped_int8_per_channel(
-                child, bitwidth, group_size
+                child, device, bitwidth, group_size, packed
             )


-class EmbeddingOnlyInt8QuantHandler:
-    def __init__(self, mod, *, bitwidth: int = 8, group_size: Optional[int] = None):
+class EmbeddingQuantHandler(QuantHandler):
+    def __init__(
+        self,
+        mod,
+        device="cpu",
+        *,
+        bitwidth: int = 8,
+        group_size: Optional[int] = None,
+        packed=False,
+    ):
+        if isinstance(packed, str):
+            packed = packed == "True"
         self.mod = mod
+        self.device = device
         self.group_size = group_size
         self.bitwidth = bitwidth
+        self.packed = packed
+        if (bitwidth != 4) and packed:
+            raise RuntimeError("pack only works with bitsize 4")

     @torch.no_grad()
-    def create_quantized_state_dict(self) -> Dict:
+    def create_quantized_state_dict(self, packed=False) -> Dict:
         cur_state_dict = self.mod.state_dict()

         if self.bitwidth == 4:
@@ -308,18 +337,14 @@ def create_quantized_state_dict(self) -> Dict:
             raise ValueError(f"Unsupported bitwidth {self.bitwidth}")

         for fqn, mod in self.mod.named_modules():
-            if (
-                isinstance(mod, nn.Embedding)
-                or isinstance(mod, fsEmbedding)
-                or isinstance(mod, fsStandardEmbedding)
-            ):
+            if isinstance(mod, nn.Embedding):
                 # print("****")
                 # print(f"Embedding identified: {fqn, mod}")
                 # print(f"weights size: {mod.weight.size()}")
                 # print(f"quantize {fqn}...")

                 print(
-                    f"quantize {fqn, mod} with groupsize {self.group_size}, bitwidth {self.bitwidth}"
+                    f"quantize {fqn, mod} with group_size {self.group_size}, bitwidth {self.bitwidth}"
                 )
                 weight, scales, _ = dynamically_quantize_per_channel(
                     mod.weight.float(),
@@ -330,21 +355,36 @@ def create_quantized_state_dict(self) -> Dict:
                     scales_dtype=mod.weight.dtype,
                 )

+                if packed:
+                    if weight.shape[-1] % 2 != 0:
+                        raise RuntimeError("automatic padding not implemented yet")
+
+                    weight_range_shifted = weight.add(8).view(torch.uint8)
+                    weight_view = weight_range_shifted.view(
+                        weight.shape[0], weight.shape[1] // 2, 2
+                    )
+                    weight_even = weight_view[:, :, 0] * 16  # left shift 4
+                    weight_odd = weight_view[:, :, 1]
+                    weight_packed = weight_even + weight_odd
+                    weight = weight_packed
+
+                weight = weight.to(device=self.device)
+                scales = scales.to(device=self.device)
                 # Update state dict
                 cur_state_dict[f"{fqn}.weight"] = weight
-                # squeeze makes groupsize=rowsize unidimensional
+                # squeeze makes group_size=rowsize unidimensional
                 cur_state_dict[f"{fqn}.scales"] = scales.squeeze(dim=-1)

         return cur_state_dict

     def convert_for_runtime(self) -> nn.Module:
         replace_embedding_weight_only_grouped_int8_per_channel(
-            self.mod, self.bitwidth, self.group_size
+            self.mod, self.device, self.bitwidth, self.group_size, self.packed
         )
         return self.mod

     def quantized_model(self) -> nn.Module:
-        model_updated_state_dict = self.create_quantized_state_dict()
+        model_updated_state_dict = self.create_quantized_state_dict(self.packed)
         self.convert_for_runtime()
         self.mod.load_state_dict(model_updated_state_dict)
         return self.mod
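The packed branch above stores two signed 4-bit values per byte: values quantized into [-8, 7] are shifted by +8 into [0, 15], then the value from each even column goes into the high nibble and its odd neighbour into the low nibble. A minimal numeric sketch of that arithmetic (standalone, with made-up values):

import torch

# Two adjacent 4-bit quantized values per row, already in the int8 range [-8, 7].
q = torch.tensor([[-8, 7], [3, -1]], dtype=torch.int8)

# Shift into [0, 15] so each value fits an unsigned nibble.
shifted = q.add(8).view(torch.uint8)         # [[0, 15], [11, 7]]

# Even column -> high nibble, odd column -> low nibble, one uint8 per pair.
packed = shifted[:, 0] * 16 + shifted[:, 1]  # [15, 183]

# Unpacking (the runtime kernel's job) reverses the two steps.
high = (packed // 16).view(torch.int8) - 8   # [-8, 3]
low = (packed % 16).view(torch.int8) - 8     # [7, -1]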
@@ -353,39 +393,53 @@ def quantized_model(self) -> nn.Module:
 class QuantizedGroupEmbedding(torch.nn.Module):
     def __init__(
         self,
+        device,
         vocab_size: int,
         embedding_dim: int,
         group_size: Optional[int] = None,
-        device=None,
         dtype=torch.half,
+        packed=False,
     ) -> None:
         super().__init__()
-        if group_size is None:
+        if group_size is None or group_size == 0:
             group_size = embedding_dim
         self.group_size = group_size
         self.dtype = dtype
-        self.register_buffer(
-            "weight", torch.empty((vocab_size, embedding_dim), dtype=torch.int8)
-        )
+        self.packed = packed
+        if not packed:
+            self.register_buffer(
+                "weight",
+                torch.empty(
+                    (vocab_size, embedding_dim), dtype=torch.int8, device=device
+                ),
+            )
+        else:  # packed
+            self.register_buffer(
+                "weight",
+                torch.empty(
+                    (vocab_size, embedding_dim // 2), dtype=torch.uint8, device=device
+                ),
+            )
         groups_per_row = (embedding_dim + group_size - 1) // group_size
         if groups_per_row > 1:
             self.register_buffer(
-                "scales", torch.ones((vocab_size, groups_per_row), dtype=torch.float16)
+                "scales",
+                torch.ones(
+                    (vocab_size, groups_per_row), dtype=torch.float16, device=device
+                ),
             )
         else:
             self.register_buffer(
-                "scales", torch.ones((vocab_size,), dtype=torch.float16)
+                "scales", torch.ones((vocab_size,), dtype=torch.float16, device=device)
             )

     @torch.no_grad()
     def forward(self, indices: torch.Tensor) -> torch.Tensor:
-        return torch.ops.quantized_decomposed.embedding_byte.dtype(
-            self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
-        )
-
-
-        # result_weights = self.weight.index_select(0, indices.view(-1))
-        # result_scales = self.scales.index_select(0, indices.view(-1))
-        #
-        # r = result_weights.to(dtype=result_scales.dtype) * result_scales
-        # return r.view(indices.size() + (-1,))
+        if not self.packed:  # 8bit
+            return torch.ops.quantized_decomposed.embedding_byte.dtype(
+                self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
+            )
+        else:  # 4bit packed
+            return torch.ops.quantized_decomposed.embedding_4bit.dtype(
+                self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
+            )
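The packed forward path dispatches to torch.ops.quantized_decomposed.embedding_4bit.dtype. A rough eager-mode reference of what that lookup computes is sketched below; it is an assumption for illustration, not the fused kernel or its exact signature, and it assumes 1-D indices, symmetric quantization with zero point 0, and 2-D group-wise scales.

import torch

def embedding_4bit_reference(packed_weight, scales, indices, dtype=torch.half):
    rows = packed_weight[indices]                     # (n, embedding_dim // 2), uint8
    high = (rows // 16).view(torch.int8) - 8          # values from even positions
    low = (rows % 16).view(torch.int8) - 8            # values from odd positions
    q = torch.stack([high, low], dim=-1).flatten(-2)  # (n, embedding_dim), int8
    groups_per_row = scales.shape[-1]
    groups = q.view(q.shape[0], groups_per_row, -1)   # (n, groups_per_row, group_size)
    out = groups.to(dtype) * scales[indices].unsqueeze(-1).to(dtype)
    return out.flatten(-2)                            # (n, embedding_dim)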
