Transform embedding from SpinQuant checkpoint (#5552)
Summary:
Pull Request resolved: #5552

This diff updates the llama export path so that it can load a SpinQuant checkpoint in which all linear layers and the embedding table are quantized.

Reviewed By: mergennachin

Differential Revision: D62665632

fbshipit-source-id: 3e8cda37ac16b65543e3123ea59526352ac6a70c
Lunwen He authored and facebook-github-bot committed Sep 24, 2024
1 parent 72245c3 commit 3e79ea4
Showing 4 changed files with 275 additions and 59 deletions.
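Note (illustrative, not part of the commit): a minimal, standalone argparse mirror of the new SpinQuant flags added in export_llama_lib.py below. The choices for --use_spin_quant ("cuda"/"native") and --spin_qmode ("8da4w") are taken from the help strings in the diff; the real parser lives in build_args_parser().

import argparse

parser = argparse.ArgumentParser("export_llama (SpinQuant flags only)")
parser.add_argument("--use_spin_quant", choices=["cuda", "native"], default=None)
parser.add_argument("--spin_qmode", choices=["8da4w"], default=None)
parser.add_argument("--spin_group_size", type=int, default=32)
parser.add_argument("--spin_embedding_quantize", type=str, default="8,0")

args = parser.parse_args(
    ["--use_spin_quant", "native",
     "--spin_qmode", "8da4w",
     "--spin_group_size", "32",
     "--spin_embedding_quantize", "8,0"]
)
print(args)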
112 changes: 80 additions & 32 deletions examples/models/llama2/export_llama_lib.py
@@ -370,6 +370,28 @@ def build_args_parser() -> argparse.ArgumentParser:
help="Use SpinQuant for better quantization performance. Only support cuda and native.",
)

parser.add_argument(
"--spin_qmode",
type=str,
default=None,
choices=["8da4w"],
help="Quantization mode for SpinQuant. Only support 8da4w right now.",
)

parser.add_argument(
"--spin_group_size",
type=int,
default=32,
help="group_size for SpinQuant weight quantization",
)

parser.add_argument(
"--spin_embedding_quantize",
default="8,0",
type=str,
help="type of embedding quantization for SpinQuant, '<bitwidth>,<groupsize>', e.g., '8,1024'.",
)
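For reference (not part of the diff), a minimal sketch of how the '<bitwidth>,<groupsize>' string is interpreted. The helper name is hypothetical; the logic mirrors the parsing added to model.py further down.

from typing import Optional, Tuple

def parse_embedding_quantize(spec: str) -> Tuple[int, Optional[int]]:
    """Split '<bitwidth>,<groupsize>' (e.g. '8,1024'); 'none', 'None', or '0' means ungrouped scales."""
    bit_width_str, group_size_str = spec.split(",")
    group_size = None if group_size_str in ("none", "None", "0") else int(group_size_str)
    return int(bit_width_str), group_size

assert parse_embedding_quantize("8,0") == (8, None)  # the default: 8-bit weights, ungrouped scales
assert parse_embedding_quantize("4,32") == (4, 32)   # 4-bit weights, scale groups of 32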

parser.add_argument(
"--output_prune_map",
default=None,
@@ -466,10 +488,10 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
max_seq_len=args.max_seq_length,
output_prune_map_path=args.output_prune_map,
metadata_str=args.metadata,
dtype_override=dtype_override,
args=args,
)
.set_output_dir(output_dir_path)
.to_dtype(dtype_override)
.source_transform(_get_source_transforms(modelname, dtype_override, args))
)

@@ -691,6 +713,7 @@ def _load_llama_model(
max_seq_len: int = 128,
output_prune_map_path: Optional[str] = None,
metadata_str: Optional[str] = None,
dtype_override: Optional[DType] = None,
args,
) -> "LLMEdgeManager":
"""
@@ -720,23 +743,32 @@ def _load_llama_model(
output_prune_map_path=output_prune_map_path,
args=args,
)
state_dict = model.state_dict()
dtype = state_dict[next(iter(state_dict))].dtype
assert dtype in [
torch.bfloat16,
torch.float16,
torch.float32,
], f"Only support bfloat16, fp16 or fp32 got {dtype}"
logging.info(f"Loaded model with dtype={dtype}")

if dtype == torch.bfloat16:
dtype = DType.bf16
elif dtype == torch.float16:
dtype = DType.fp16
elif dtype == torch.float32:
dtype = DType.fp32
if dtype_override:
assert isinstance(
dtype_override, DType
), "Override dtype needs to be of type <DType>"
torch_dtype = dtype_override.to_torch_dtype()
logging.info(f"model.to {torch_dtype}")
model = model.to(dtype=torch_dtype)
dtype = dtype_override
else:
raise ValueError(f"Unsupported dtype {dtype}")
state_dict = model.state_dict()
dtype = state_dict[next(iter(state_dict))].dtype
assert dtype in [
torch.bfloat16,
torch.float16,
torch.float32,
], f"Only support bfloat16, fp16 or fp32 got {dtype}"
logging.info(f"Loaded model with dtype={dtype}")

if dtype == torch.bfloat16:
dtype = DType.bf16
elif dtype == torch.float16:
dtype = DType.fp16
elif dtype == torch.float32:
dtype = DType.fp32
else:
raise ValueError(f"Unsupported dtype {dtype}")

return LLMEdgeManager(
model=model,
@@ -769,21 +801,9 @@ def _get_source_transforms( # noqa
modelname: str, dtype_override: Optional[DType], args
) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
transforms = []
if args.quantization_mode:
modelname = f"{modelname}_q"
if args.use_spin_quant is None:
transforms.append(
get_quant_weight_transform(args, dtype_override, verbose_export())
)
# For SpinQuant, the checkpoints are already quantized
# aka the weights have corresponding scales value,
# So that means, we don't need to apply quantization
# transform. However, we will still need to apply
# transformations that change the model structure to
# match the checkpoint format.
# transform_for_spinquant() will apply these transformations
# later in model.py file.
elif args.use_spin_quant == "cuda":

if args.use_spin_quant:
if args.use_spin_quant == "cuda":
from .source_transformation.spin_quant import (
inject_fast_hadamard_transform_cuda_for_spin_quant,
)
@@ -796,7 +816,35 @@

transforms.append(inject_fast_hadamard_transform_native_for_spin_quant)

if args.quantization_mode:
"""
When this option is selected, it finds all linear layers and transforms
into quantized linear equivalent module.
There are cases where the checkpoint is already quantized, for example
on use_spin_quant is enabled. In that case, it will do the appropriate
transformations based on the given checkpoint first. In those cases,
if quantization_mode is enabled, it will quantize any remaining linear
ops that is not quantized.
There are cases where this may be a no-op, namely, if all linears are
quantized in the checkpoint.
"""
modelname = f"{modelname}_q"
transforms.append(
get_quant_weight_transform(args, dtype_override, verbose_export())
)

if args.embedding_quantize:
"""
When this option is selected, it finds all embedding layers and transforms
into quantized embedding equivalent module.
There are cases where the checkpoint is already quantized, for example
on use_spin_quant is enabled. In that case, it will do the appropriate
transformations based on the given checkpoint first. In those cases,
this wil be a no-op.
"""
modelname = f"{modelname}_e"
transforms.append(get_quant_embedding_transform(args))
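Aside (assumption, not part of the diff): each entry in the returned list is a Callable[[torch.nn.Module], torch.nn.Module], and the list is assumed to be applied in append order by the .source_transform(...) call shown earlier. A minimal driver for intuition:

from typing import Callable, List

import torch

def apply_transforms(
    model: torch.nn.Module,
    transforms: List[Callable[[torch.nn.Module], torch.nn.Module]],
) -> torch.nn.Module:
    # Apply each transform in the order it was appended: SpinQuant Hadamard
    # injection first, then linear quantization, then embedding quantization.
    for transform in transforms:
        model = transform(model)
    return model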

46 changes: 37 additions & 9 deletions examples/models/llama2/model.py
@@ -191,16 +191,16 @@ def __init__(self, **kwargs):
)
elif hasattr(self.args, "use_spin_quant") and self.args.use_spin_quant:
print("Using SPIN quantization.")
assert hasattr(self.args, "group_size"), "group_size must be specified"
assert hasattr(self.args, "spin_qmode"), "spin_qmode must be specified"
assert hasattr(
self.args, "quantization_mode"
), "quantization_mode must be specified"
self.args, "spin_group_size"
), "spin_group_size must be specified"
assert hasattr(
self.args, "dtype_override"
), "dtype_override must be specified"
from .source_transformation.spin_quant import (
sanitize_checkpoint_from_spinquant,
transform_for_spinquant,
transform_linear_for_spinquant,
)

mapping = {
@@ -209,17 +209,45 @@
"bf16": torch.bfloat16,
}

self.model_ = transform_for_spinquant(
self.model_ = transform_linear_for_spinquant(
self.model_,
checkpoint,
self.args.group_size,
self.args.quantization_mode,
self.args.spin_group_size,
self.args.spin_qmode,
mapping[self.args.dtype_override],
)

embedding_bit_width, embedding_group_size = None, None
if hasattr(self.args, "spin_embedding_quantize"):
embedding_bit_width, embedding_group_size = (
self.args.spin_embedding_quantize.split(",")
)
from .source_transformation.spin_quant import (
transform_embedding_for_spinquant,
)

if (
embedding_group_size == "none"
or embedding_group_size == "None"
or embedding_group_size == "0"
):
embedding_group_size = None
else:
embedding_group_size = int(embedding_group_size)

self.model_ = transform_embedding_for_spinquant(
self.model_,
checkpoint,
mapping[self.args.dtype_override],
int(embedding_bit_width),
embedding_group_size,
)

sanitize_checkpoint_from_spinquant(
checkpoint,
self.args.group_size,
module=self.model_,
checkpoint=checkpoint,
linear_group_size=self.args.spin_group_size,
embedding_group_size=embedding_group_size,
)

# assign=True: load params/buffers by assignment instead of performing an in-place copy.
89 changes: 85 additions & 4 deletions examples/models/llama2/source_transformation/spin_quant.py
@@ -9,7 +9,7 @@
# Helper functions for transforming the model to be able to run SpinQuant.
# See https://github.com/facebookresearch/SpinQuant for more details about SpinQuant.

from typing import Any
from typing import Any, Optional

import torch

@@ -20,6 +20,8 @@
from torchao.quantization.GPTQ import _check_linear_int4_k, Int8DynActInt4WeightLinear
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter

from .quantize import QuantizedGroupEmbedding


def _inject_fast_hadamard_transform_cuda_for_spin_quant(module: torch.nn.Module):
"""
@@ -123,7 +125,7 @@ def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
_replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)


def transform_for_spinquant(
def transform_linear_for_spinquant(
module: torch.nn.Module,
checkpoint: Any,
group_size: int,
@@ -151,9 +153,64 @@ def transform_for_spinquant(
return module


def _replace_embedding_with_quantized_group_embedding_for_spinquant(
module: torch.nn.Module,
checkpoint: Any,
dtype: torch.dtype,
bit_width: int,
group_size: Optional[int] = None,
):
def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
# Only replace embedding layers where the checkpoint contains explicit scales
scales_key = f"{cur_fqn}.scale"
if isinstance(child, nn.Embedding) and scales_key in checkpoint:
assert checkpoint[f"{cur_fqn}.weight"].dtype == torch.int8
assert checkpoint[scales_key].dtype == torch.float32
return True
return False

def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
new_embedding = QuantizedGroupEmbedding(
device=child.weight.device,
vocab_size=child.weight.shape[0],
embedding_dim=child.weight.shape[1],
group_size=group_size,
dtype=dtype,
packed=False, # TODO(lunwenh): support packed embedding for SpinQuant
)
return new_embedding

_replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)


def transform_embedding_for_spinquant(
module: torch.nn.Module,
checkpoint: Any,
dtype: torch.dtype,
bit_width: int,
group_size: Optional[int] = None,
) -> torch.nn.Module:
"""
Transform the model to be able to load SpinQuant checkpoints that
are quantized with the given bit_width and group size for embedding.
"""
if group_size is not None and group_size not in [0, 32, 64, 128, 256]:
raise ValueError(f"Group size {group_size} is not supported for SpinQuant.")
_replace_embedding_with_quantized_group_embedding_for_spinquant(
module,
checkpoint,
dtype,
bit_width,
group_size,
)
return module
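As a concrete illustration (toy module, key names, and shapes are hypothetical, not from the repo), this is the checkpoint layout the embedding filter above looks for: an nn.Embedding whose fully-qualified name has a matching ".scale" entry, with an int8 weight table and float32 scales.

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.tok_embeddings = nn.Embedding(num_embeddings=16, embedding_dim=8)

checkpoint = {
    "tok_embeddings.weight": torch.zeros(16, 8, dtype=torch.int8),   # quantized table
    "tok_embeddings.scale": torch.ones(16, 1, dtype=torch.float32),  # per-row (ungrouped) scales
}
# Calling transform_embedding_for_spinquant(ToyModel(), checkpoint, torch.float32, 8, None)
# would then swap tok_embeddings for a QuantizedGroupEmbedding before load_state_dict.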


def sanitize_checkpoint_from_spinquant(
module: torch.nn.Module,
checkpoint: Any,
group_size: int,
linear_group_size: int,
embedding_group_size: Optional[int] = None,
):
"""
Sanitize the SpinQuant checkpoint.
@@ -173,7 +230,31 @@ def sanitize_checkpoint_from_spinquant(

for old_key, new_key in keys_to_rename:
old_val = checkpoint.pop(old_key)
checkpoint[new_key] = old_val if group_size == -1 else old_val[:, ::group_size]
module_name = new_key[0 : new_key.rfind(".")]
sub_module = module.get_submodule(module_name)
assert sub_module is not None
assert isinstance(sub_module, Int8DynActInt4WeightLinear) or isinstance(
sub_module, QuantizedGroupEmbedding
)
# Checkpoints produced by SpinQuant can come with two formats for scales:
# 1. scales are grouped by group size
# 2. scales are not grouped by group size
# We need to handle both cases here.
# TODO(lunwenh): remove this once we have a unified format for scales.
if isinstance(sub_module, Int8DynActInt4WeightLinear):
checkpoint[new_key] = (
old_val if linear_group_size == -1 else old_val[:, ::linear_group_size]
)
elif isinstance(sub_module, QuantizedGroupEmbedding):
if (
embedding_group_size is None or embedding_group_size == 0
): # Scales are not grouped
checkpoint[new_key] = old_val[:, 0]
elif embedding_group_size == -1: # Scales are grouped by group size
checkpoint[new_key] = old_val
else:
checkpoint[new_key] = old_val[:, ::embedding_group_size]

for k in keys_to_remove:
checkpoint.pop(k)
for k, v in checkpoint.items():
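A worked illustration (made-up tensors, not from the commit) of the column-stride handling above, for the case where the checkpoint stores one scale per input column: taking every group_size-th column keeps the leading scale of each group, leaving one scale per group.

import torch

linear_group_size = 4
grouped_scales = torch.tensor(
    [[0.1, 0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.2],
     [0.3, 0.3, 0.3, 0.3, 0.4, 0.4, 0.4, 0.4]]
)
per_group_scales = grouped_scales[:, ::linear_group_size]
print(per_group_scales.shape)  # torch.Size([2, 2]) -> one scale per group of 4 input columns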
