diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py
index 04ccbcdea0..6bf019de23 100644
--- a/examples/models/llama2/export_llama_lib.py
+++ b/examples/models/llama2/export_llama_lib.py
@@ -370,6 +370,28 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="Use SpinQuant for better quantization performance. Only support cuda and native.",
     )
 
+    parser.add_argument(
+        "--spin_qmode",
+        type=str,
+        default=None,
+        choices=["8da4w"],
+        help="Quantization mode for SpinQuant. Only support 8da4w right now.",
+    )
+
+    parser.add_argument(
+        "--spin_group_size",
+        type=int,
+        default=32,
+        help="group_size for SpinQuant weight quantization",
+    )
+
+    parser.add_argument(
+        "--spin_embedding_quantize",
+        default="8,0",
+        type=str,
+        help="type of embedding quantization for SpinQuant, '<bitwidth>,<groupsize>', e.g., '8,1024'.",
+    )
+
     parser.add_argument(
         "--output_prune_map",
         default=None,
@@ -466,10 +488,10 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
             max_seq_len=args.max_seq_length,
             output_prune_map_path=args.output_prune_map,
             metadata_str=args.metadata,
+            dtype_override=dtype_override,
             args=args,
         )
         .set_output_dir(output_dir_path)
-        .to_dtype(dtype_override)
         .source_transform(_get_source_transforms(modelname, dtype_override, args))
     )
 
@@ -691,6 +713,7 @@ def _load_llama_model(
     max_seq_len: int = 128,
     output_prune_map_path: Optional[str] = None,
     metadata_str: Optional[str] = None,
+    dtype_override: Optional[DType] = None,
     args,
 ) -> "LLMEdgeManager":
     """
@@ -720,23 +743,32 @@ def _load_llama_model(
         output_prune_map_path=output_prune_map_path,
         args=args,
     )
-    state_dict = model.state_dict()
-    dtype = state_dict[next(iter(state_dict))].dtype
-    assert dtype in [
-        torch.bfloat16,
-        torch.float16,
-        torch.float32,
-    ], f"Only support bfloat16, fp16 or fp32 got {dtype}"
-    logging.info(f"Loaded model with dtype={dtype}")
-
-    if dtype == torch.bfloat16:
-        dtype = DType.bf16
-    elif dtype == torch.float16:
-        dtype = DType.fp16
-    elif dtype == torch.float32:
-        dtype = DType.fp32
+    if dtype_override:
+        assert isinstance(
+            dtype_override, DType
+        ), "Override dtype needs to be of type <DType>"
+        torch_dtype = dtype_override.to_torch_dtype()
+        logging.info(f"model.to {torch_dtype}")
+        model = model.to(dtype=torch_dtype)
+        dtype = dtype_override
     else:
-        raise ValueError(f"Unsupported dtype {dtype}")
+        state_dict = model.state_dict()
+        dtype = state_dict[next(iter(state_dict))].dtype
+        assert dtype in [
+            torch.bfloat16,
+            torch.float16,
+            torch.float32,
+        ], f"Only support bfloat16, fp16 or fp32 got {dtype}"
+        logging.info(f"Loaded model with dtype={dtype}")
+
+        if dtype == torch.bfloat16:
+            dtype = DType.bf16
+        elif dtype == torch.float16:
+            dtype = DType.fp16
+        elif dtype == torch.float32:
+            dtype = DType.fp32
+        else:
+            raise ValueError(f"Unsupported dtype {dtype}")
 
     return LLMEdgeManager(
         model=model,
@@ -769,21 +801,9 @@ def _get_source_transforms(  # noqa
     modelname: str, dtype_override: Optional[DType], args
 ) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
     transforms = []
-    if args.quantization_mode:
-        modelname = f"{modelname}_q"
-        if args.use_spin_quant is None:
-            transforms.append(
-                get_quant_weight_transform(args, dtype_override, verbose_export())
-            )
-        # For SpinQuant, the checkpoints are already quantized
-        # aka the weights have corresponding scales value,
-        # So that means, we don't need to apply quantization
-        # transform. However, we will still need to apply
-        # transformations that change the model structure to
-        # match the checkpoint format.
-        # transform_for_spinquant() will apply these transformations
-        # later in model.py file.
-        elif args.use_spin_quant == "cuda":
+
+    if args.use_spin_quant:
+        if args.use_spin_quant == "cuda":
             from .source_transformation.spin_quant import (
                 inject_fast_hadamard_transform_cuda_for_spin_quant,
             )
@@ -796,7 +816,35 @@
 
             transforms.append(inject_fast_hadamard_transform_native_for_spin_quant)
 
+    if args.quantization_mode:
+        """
+        When this option is selected, it finds all linear layers and transforms
+        them into quantized linear equivalent modules.
+
+        There are cases where the checkpoint is already quantized, for example
+        when use_spin_quant is enabled. In that case, it will do the appropriate
+        transformations based on the given checkpoint first. In those cases,
+        if quantization_mode is enabled, it will quantize any remaining linear
+        ops that are not quantized.
+
+        There are cases where this may be a no-op, namely, if all linears are
+        quantized in the checkpoint.
+        """
+        modelname = f"{modelname}_q"
+        transforms.append(
+            get_quant_weight_transform(args, dtype_override, verbose_export())
+        )
+
     if args.embedding_quantize:
+        """
+        When this option is selected, it finds all embedding layers and transforms
+        them into quantized embedding equivalent modules.
+
+        There are cases where the checkpoint is already quantized, for example
+        when use_spin_quant is enabled. In that case, it will do the appropriate
+        transformations based on the given checkpoint first. In those cases,
+        this will be a no-op.
+        """
         modelname = f"{modelname}_e"
         transforms.append(get_quant_embedding_transform(args))
diff --git a/examples/models/llama2/model.py b/examples/models/llama2/model.py
index 21714a9c15..08effca2eb 100644
--- a/examples/models/llama2/model.py
+++ b/examples/models/llama2/model.py
@@ -191,16 +191,16 @@ def __init__(self, **kwargs):
             )
         elif hasattr(self.args, "use_spin_quant") and self.args.use_spin_quant:
             print("Using SPIN quantization.")
-            assert hasattr(self.args, "group_size"), "group_size must be specified"
+            assert hasattr(self.args, "spin_qmode"), "spin_qmode must be specified"
             assert hasattr(
-                self.args, "quantization_mode"
-            ), "quantization_mode must be specified"
+                self.args, "spin_group_size"
+            ), "spin_group_size must be specified"
             assert hasattr(
                 self.args, "dtype_override"
             ), "dtype_override must be specified"
             from .source_transformation.spin_quant import (
                 sanitize_checkpoint_from_spinquant,
-                transform_for_spinquant,
+                transform_linear_for_spinquant,
             )
 
             mapping = {
@@ -209,17 +209,45 @@ def __init__(self, **kwargs):
                 "bf16": torch.bfloat16,
             }
 
-            self.model_ = transform_for_spinquant(
+            self.model_ = transform_linear_for_spinquant(
                 self.model_,
                 checkpoint,
-                self.args.group_size,
-                self.args.quantization_mode,
+                self.args.spin_group_size,
+                self.args.spin_qmode,
                 mapping[self.args.dtype_override],
             )
 
+            embedding_bit_width, embedding_group_size = None, None
+            if hasattr(self.args, "spin_embedding_quantize"):
+                embedding_bit_width, embedding_group_size = (
+                    self.args.spin_embedding_quantize.split(",")
+                )
+                from .source_transformation.spin_quant import (
+                    transform_embedding_for_spinquant,
+                )
+
+                if (
+                    embedding_group_size == "none"
+                    or embedding_group_size == "None"
+                    or embedding_group_size == "0"
+                ):
+                    embedding_group_size = None
+                else:
+                    embedding_group_size = int(embedding_group_size)
+
+                self.model_ = transform_embedding_for_spinquant(
+                    self.model_,
+                    checkpoint,
+                    mapping[self.args.dtype_override],
+                    int(embedding_bit_width),
+                    embedding_group_size,
+                )
+
             sanitize_checkpoint_from_spinquant(
-                checkpoint,
-                self.args.group_size,
+                module=self.model_,
+                checkpoint=checkpoint,
+                linear_group_size=self.args.spin_group_size,
+                embedding_group_size=embedding_group_size,
             )
 
         # assign=True: load params/buffers by assignment instead of performing an in-place copy.
diff --git a/examples/models/llama2/source_transformation/spin_quant.py b/examples/models/llama2/source_transformation/spin_quant.py
index 1dbf878dc6..b6107492d2 100644
--- a/examples/models/llama2/source_transformation/spin_quant.py
+++ b/examples/models/llama2/source_transformation/spin_quant.py
@@ -9,7 +9,7 @@
 # Helper functions for tranforming the model to be able to run SpinQuant.
 # See https://github.com/facebookresearch/SpinQuant for more details about SpinQuant.
 
-from typing import Any
+from typing import Any, Optional
 
 import torch
 
@@ -20,6 +20,8 @@
 from torchao.quantization.GPTQ import _check_linear_int4_k, Int8DynActInt4WeightLinear
 from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
 
+from .quantize import QuantizedGroupEmbedding
+
 
 def _inject_fast_hadamard_transform_cuda_for_spin_quant(module: torch.nn.Module):
     """
@@ -123,7 +125,7 @@ def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
     _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)
 
 
-def transform_for_spinquant(
+def transform_linear_for_spinquant(
     module: torch.nn.Module,
     checkpoint: Any,
     group_size: int,
@@ -151,9 +153,64 @@
     return module
 
 
+def _replace_embedding_with_quantized_group_embedding_for_spinquant(
+    module: torch.nn.Module,
+    checkpoint: Any,
+    dtype: torch.dtype,
+    bit_width: int,
+    group_size: Optional[int] = None,
+):
+    def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
+        # Only replace embedding layers where the checkpoint contains explicit scales
+        scales_key = f"{cur_fqn}.scale"
+        if isinstance(child, nn.Embedding) and scales_key in checkpoint:
+            assert checkpoint[f"{cur_fqn}.weight"].dtype == torch.int8
+            assert checkpoint[scales_key].dtype == torch.float32
+            return True
+        return False
+
+    def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
+        new_embedding = QuantizedGroupEmbedding(
+            device=child.weight.device,
+            vocab_size=child.weight.shape[0],
+            embedding_dim=child.weight.shape[1],
+            group_size=group_size,
+            dtype=dtype,
+            packed=False,  # TODO(lunwenh): support packed embedding for SpinQuant
+        )
+        return new_embedding
+
+    _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)
+
+
+def transform_embedding_for_spinquant(
+    module: torch.nn.Module,
+    checkpoint: Any,
+    dtype: torch.dtype,
+    bit_width: int,
+    group_size: Optional[int] = None,
+) -> torch.nn.Module:
+    """
+    Transform the model to be able to load SpinQuant checkpoints that
+    are quantized with the given bit_width and group size for embedding.
+    """
+    if group_size is not None and group_size not in [0, 32, 64, 128, 256]:
+        raise ValueError(f"Group size {group_size} is not supported for SpinQuant.")
+    _replace_embedding_with_quantized_group_embedding_for_spinquant(
+        module,
+        checkpoint,
+        dtype,
+        bit_width,
+        group_size,
+    )
+    return module
+
+
 def sanitize_checkpoint_from_spinquant(
+    module: torch.nn.Module,
     checkpoint: Any,
-    group_size: int,
+    linear_group_size: int,
+    embedding_group_size: Optional[int] = None,
 ):
     """
     Sanitize the SpinQuant checkpoint.
@@ -173,7 +230,31 @@ def sanitize_checkpoint_from_spinquant(
     for old_key, new_key in keys_to_rename:
         old_val = checkpoint.pop(old_key)
-        checkpoint[new_key] = old_val if group_size == -1 else old_val[:, ::group_size]
+        module_name = new_key[0 : new_key.rfind(".")]
+        sub_module = module.get_submodule(module_name)
+        assert sub_module is not None
+        assert isinstance(sub_module, Int8DynActInt4WeightLinear) or isinstance(
+            sub_module, QuantizedGroupEmbedding
+        )
+        # Checkpoints with SpinQuant could come with two formats for scales:
+        # 1. scales is grouped by group size
+        # 2. scales is not grouped by group size
+        # We need to handle both cases here.
+        # TODO(lunwenh): remove this once we have a unified format for scales.
+        if isinstance(sub_module, Int8DynActInt4WeightLinear):
+            checkpoint[new_key] = (
+                old_val if linear_group_size == -1 else old_val[:, ::linear_group_size]
+            )
+        elif isinstance(sub_module, QuantizedGroupEmbedding):
+            if (
+                embedding_group_size is None or embedding_group_size == 0
+            ):  # Scales are not grouped
+                checkpoint[new_key] = old_val[:, 0]
+            elif embedding_group_size == -1:  # Scales are grouped by group size
+                checkpoint[new_key] = old_val
+            else:
+                checkpoint[new_key] = old_val[:, ::embedding_group_size]
+
 
     for k in keys_to_remove:
         checkpoint.pop(k)
     for k, v in checkpoint.items():
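A quick aside on the scale handling above (an illustrative sketch, not part of the patch): when a SpinQuant checkpoint stores one scale per weight column, the stride slice old_val[:, ::group_size] used in sanitize_checkpoint_from_spinquant keeps the first scale of every group, which is the one-scale-per-group layout the quantized modules load. The shapes below are assumptions chosen only to keep the example small.

import torch

# Illustrative shapes only: 4 output rows, 8 input columns, groups of 4 columns.
out_features, in_features, group_size = 4, 8, 4

# One scale per group -- the layout the quantized module expects to load.
per_group_scales = torch.rand(out_features, in_features // group_size)

# A checkpoint variant that repeats each scale across its whole group (one scale per column).
per_column_scales = per_group_scales.repeat_interleave(group_size, dim=1)

# Taking every group_size-th column recovers the per-group layout.
assert torch.equal(per_column_scales[:, ::group_size], per_group_scales)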
diff --git a/examples/models/llama2/tests/test_spinquant_transforms.py b/examples/models/llama2/tests/test_spinquant_transforms.py
index bd56632c5f..745bd6f46a 100644
--- a/examples/models/llama2/tests/test_spinquant_transforms.py
+++ b/examples/models/llama2/tests/test_spinquant_transforms.py
@@ -10,24 +10,15 @@
 from executorch.examples.models.llama2.llama_transformer import ModelArgs, Transformer
 from executorch.examples.models.llama2.source_transformation.spin_quant import (
     sanitize_checkpoint_from_spinquant,
-    transform_for_spinquant,
+    transform_embedding_for_spinquant,
+    transform_linear_for_spinquant,
 )
 from torchao.quantization.utils import group_quantize_tensor_symmetric
 
 
 class SpinQuantTests(unittest.TestCase):
-    def test_transforms_for_spinquant(self):
-
-        # Step 1: Create llama class with dummy weights
-        params = {
-            "dim": 768,
-            "multiple_of": 32,
-            "n_heads": 12,
-            "n_layers": 12,
-            "norm_eps": 1e-05,
-            "vocab_size": 32000,
-        }
+    def _prepare_dummy_model(self) -> Transformer:
         model_args = ModelArgs(
             max_seq_len=2048,
             max_batch_size=1,
@@ -35,10 +26,22 @@ def test_transforms_for_spinquant(self):
             use_sdpa_with_kv_cache_op=False,
             generate_full_logits=False,
             enable_dynamic_shape=True,
-            **params,
+            dim=768,
+            multiple_of=32,
+            n_heads=12,
+            n_layers=12,
+            norm_eps=1e-05,
+            vocab_size=32000,
         )
         model = Transformer(model_args)
+
+        return model
+
+    def test_transform_linear_for_spinquant(self):
+
+        # Step 1: Create llama class with dummy weights
+        model = self._prepare_dummy_model()
         checkpoint = model.state_dict()
 
         # Step 2:
@@ -63,7 +66,7 @@
 
         # Step 3:
         # Transform the model so that it is compatible with the new checkpoint
-        transform_for_spinquant(
+        transform_linear_for_spinquant(
             model,
             checkpoint,
             32,
@@ -71,6 +74,7 @@
             torch.float32,
         )
         sanitize_checkpoint_from_spinquant(
+            model,
             checkpoint,
             -1,
         )
@@ -87,3 +91,58 @@
             # The new_checkpoint contains zeros so
             # have to iterate over the keys.
             self.assertTrue(torch.allclose(new_checkpoint[k], v))
+
+    def test_transform_embedding_for_spinquant(self):
+
+        # Step 1: Create llama class with dummy weights
+        model = self._prepare_dummy_model()
+        checkpoint = model.state_dict()
+
+        # Step 2:
+        # Do group-wise quantization and amend the checkpoints with
+        # int8 weight and fp32 scales
+        group_size = 32
+        n_bit = 4
+        scales_precision = torch.float32
+        for fqn, mod in model.named_modules():
+            # Quantize the embedding layers
+            if isinstance(mod, torch.nn.Embedding):
+                weight = mod.weight.data
+                (
+                    weight_int8,
+                    scales,
+                    zeros,
+                ) = group_quantize_tensor_symmetric(
+                    weight.to(torch.float32), n_bit, group_size, scales_precision
+                )
+                checkpoint[f"{fqn}.weight"] = weight_int8.to("cpu")
+                checkpoint[f"{fqn}.scale"] = scales.to("cpu")
+
+        # Step 3:
+        # Transform the model so that it is compatible with the new checkpoint
+        transform_embedding_for_spinquant(
+            model,
+            checkpoint,
+            torch.float32,
+            n_bit,
+            group_size,
+        )
+        sanitize_checkpoint_from_spinquant(
+            module=model,
+            checkpoint=checkpoint,
+            linear_group_size=-1,
+            embedding_group_size=-1,
+        )
+
+        model.load_state_dict(
+            checkpoint,
+            strict=False,
+            assign=True,
+        )
+
+        new_checkpoint = model.state_dict()
+
+        for k, v in checkpoint.items():
+            # The new_checkpoint contains zeros so
+            # have to iterate over the keys.
+            self.assertTrue(torch.allclose(new_checkpoint[k], v))
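For reference, a small self-contained sketch of how the new SpinQuant flags are intended to combine and how --spin_embedding_quantize is decomposed. The parser below is only a stand-in for build_args_parser(); the flag names and defaults come from this diff, while the example values and everything else are assumptions.

import argparse

# Stand-in parser mirroring just the SpinQuant flags added in build_args_parser(); not the real CLI.
parser = argparse.ArgumentParser()
parser.add_argument("--use_spin_quant", type=str, default=None)
parser.add_argument("--spin_qmode", type=str, default=None, choices=["8da4w"])
parser.add_argument("--spin_group_size", type=int, default=32)
parser.add_argument("--spin_embedding_quantize", default="8,0", type=str)

args = parser.parse_args(
    ["--use_spin_quant", "native", "--spin_qmode", "8da4w", "--spin_embedding_quantize", "8,0"]
)

# Same normalization as the model.py hunk: '<bitwidth>,<groupsize>', where a group size of
# "none"/"None"/"0" means the embedding scales are not grouped.
bit_width, group_size = args.spin_embedding_quantize.split(",")
group_size = None if group_size in ("none", "None", "0") else int(group_size)
print(int(bit_width), group_size, args.spin_qmode, args.spin_group_size)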