Skip to content

Updates torchao pin to enable shared embedding quantization #9548

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions examples/models/llama/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,10 @@ python -m examples.models.llama.export_llama \
-d fp32
```

A few notes:
- If your model shares embedding/unembedding weights (like Llama1B and Llama3B do), you can add `--use_shared_embedding` to take advantage of this and reduce memory. When this option is enabled, you can specify whether embeddings are quantized with weight zeros or not by specifying a third argument. For example, `-E "torchao:4,32,true"` means that the embedding is quantized to 4-bits with group_size=32 and uses weight zeros (this is the default behavior if you simply use `-E "torchao:4,32"`), whereas `-E "torchao:4,32,false"` means that the embedding is quantized to 4-bits with group_size=32, but is quantized with scales-only. If `--use_shared_embedding` is specified, the unembedding (i.e., the final linear layer) is quantized in the same way, but also uses 8-bit dynamically quantized activations.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not for this PR, but what's the plan for updating our arg selection scheme for quant?

-E "torchao:4,32,true isn't user friendly

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You'd never need to do that. true is the default (and existing behavior), so you could continue to use -E"torchao:4,32".

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd make this a bit more clear that shared is only for torchao kernels, or torchao:

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's under the torchao section of the docs.

- To do channelwise quantization, specify group_size to 0. This works for both linear and embedding layers.

Once the model is exported, we need to build ExecuTorch and the runner with the low-bit kernels.

The first step is to install ExecuTorch (the same as step 3.1 above):
Expand Down
42 changes: 29 additions & 13 deletions examples/models/llama/export_llama_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,11 @@ def build_args_parser() -> argparse.ArgumentParser:
type=str,
help="type of embedding quantization, '<bitwidth>,<groupsize>', e.g., '8,1024'.",
)
parser.add_argument(
"--use_shared_embedding",
action="store_true",
help="Whether the embedding/unembedding weights should be shared. Only available with torchao kernels.",
)
parser.add_argument(
"--pt2e_quantize",
default=None,
Expand Down Expand Up @@ -684,6 +689,15 @@ def _validate_args(args):
if args.num_sharding > 0 and not args.qnn:
raise ValueError("Model shard is only supported with qnn backend now.")

if args.use_shared_embedding:
if not (
args.embedding_quantize is not None
and args.embedding_quantize.startswith("torchao:")
Comment on lines +692 to +695
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if args.use_shared_embedding:
if not (
args.embedding_quantize is not None
and args.embedding_quantize.startswith("torchao:")
if args.use_shared_embedding:
and (
args.embedding_quantize is None
or not args.embedding_quantize.startswith("torchao:")

nit: nested conditionals into an error

):
raise ValueError(
"Shared embedding is only supported with torchao quantization."
)

if (
args.quantization_mode is not None
and args.quantization_mode.startswith("torchao:")
Expand Down Expand Up @@ -1122,6 +1136,21 @@ def _get_source_transforms( # noqa

transforms.append(inject_fast_hadamard_transform_native_for_spin_quant)

if args.embedding_quantize:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did we change the order of the source transform?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shared_embedding must be applied before linear. So I changed order to embedding first, and linear second. I put a code comment to this effect as well.

"""
When this option is selected, it finds all embedding layers and transforms
into quantized embedding equivalent module.

There are cases where the checkpoint is already quantized, for example
on use_spin_quant is enabled. In that case, it will do the appropriate
transformations based on the given checkpoint first. In those cases,
this wil be a no-op.
"""
modelname = f"{modelname}_e"
transforms.append(get_quant_embedding_transform(args, checkpoint_dtype))

# quantization_mode should be applied after embedding_quantize
# to support shared_embedding
if args.quantization_mode:
"""
When this option is selected, it finds all linear layers and transforms
Expand All @@ -1145,19 +1174,6 @@ def _get_source_transforms( # noqa
)
)

if args.embedding_quantize:
"""
When this option is selected, it finds all embedding layers and transforms
into quantized embedding equivalent module.

There are cases where the checkpoint is already quantized, for example
on use_spin_quant is enabled. In that case, it will do the appropriate
transformations based on the given checkpoint first. In those cases,
this wil be a no-op.
"""
modelname = f"{modelname}_e"
transforms.append(get_quant_embedding_transform(args, checkpoint_dtype))

if args.expand_rope_table:
transforms.append(materialze_broadcast_of_rope_freq_cis)

Expand Down
44 changes: 33 additions & 11 deletions examples/models/llama/source_transformation/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,7 @@ def quantize( # noqa C901
model,
Int8DynamicActivationIntxWeightConfig(
weight_dtype=getattr(torch, f"int{bitwidth}"),
granularity=(
PerRow() if group_size in [0, -1] else PerGroup(group_size)
),
granularity=(PerRow() if group_size == 0 else PerGroup(group_size)),
has_weight_zeros=False,
),
)
Expand Down Expand Up @@ -786,19 +784,43 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:

def get_quant_embedding_transform(args, dtype_override: Optional[DType] = None):
if args.embedding_quantize.startswith("torchao:"):
bitwidth, group_size = args.embedding_quantize.split(":")[1].split(",")
from torchao.experimental.quant_api import (
EmbeddingQuantizer,
SharedEmbeddingQuantizer,
)
from torchao.quantization.granularity import PerGroup, PerRow

quant_args = args.embedding_quantize.split(":")[1].split(",")
if len(quant_args) == 2:
bitwidth, group_size = quant_args
has_weight_zeros = True
else:
bitwidth, group_size, has_weight_zeros = quant_args

if group_size in ["none", "None", "0"]:
group_size = 0

group_size = int(group_size)
bitwidth = int(bitwidth)
from torchao.experimental.quant_api import IntxWeightEmbeddingQuantizer
has_weight_zeros = bool(has_weight_zeros)
weight_dtype = getattr(torch, f"int{bitwidth}")
granularity = PerRow() if group_size == 0 else PerGroup(group_size)

def _torchao_embedding_quantizer(model):
with torch.no_grad():
model = IntxWeightEmbeddingQuantizer(
device="cpu",
precision=torch.float32,
bitwidth=bitwidth,
groupsize=group_size,
).quantize(model)
if not args.use_shared_embedding:
EmbeddingQuantizer(
weight_dtype=weight_dtype,
granularity=granularity,
has_weight_zeros=has_weight_zeros,
use_fallback=False,
).quantize(model)
else:
SharedEmbeddingQuantizer(
weight_dtype=weight_dtype,
granularity=granularity,
has_weight_zeros=has_weight_zeros,
).quantize(model)
return model

return _torchao_embedding_quantizer
Expand Down
2 changes: 1 addition & 1 deletion third-party/ao
Submodule ao updated 23 files
+29 −29 .github/workflows/float8nocompile_test.yaml
+1 −1 .github/workflows/torchao_experimental_test.yml
+3 −2 examples/sam2_amg_server/compile_export_utils.py
+3 −2 examples/sam2_vos_example/compile_export_utils.py
+1 −1 torchao/dtypes/uintx/tensor_core_tiled_layout.py
+58 −0 torchao/experimental/kernels/cpu/aarch64/embedding/embedding.h
+22 −75 ...al/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot-impl.h
+19 −148 ...al/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot-impl.h
+19 −149 ...al/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot-impl.h
+467 −0 torchao/experimental/kernels/cpu/aarch64/linear/pack_weights.h
+9 −0 torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt
+1 −0 torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh
+60 −1 torchao/experimental/kernels/cpu/aarch64/tests/test_embedding.cpp
+118 −0 torchao/experimental/kernels/cpu/aarch64/tests/test_weight_packing.cpp
+119 −0 torchao/experimental/ops/embedding_xbit/op_embedding_xbit-impl.h
+35 −19 torchao/experimental/ops/embedding_xbit/op_embedding_xbit_aten.cpp
+28 −0 torchao/experimental/ops/embedding_xbit/op_embedding_xbit_executorch.cpp
+1 −63 torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h
+75 −0 torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_format.h
+2 −1 torchao/experimental/ops/mps/test/test_lowbit.py
+26 −16 torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py
+370 −72 torchao/experimental/quant_api.py
+129 −28 torchao/experimental/tests/test_embedding_xbit_quantizer.py
Loading