8 changes: 8 additions & 0 deletions examples/models/llama/export_llama_lib.py
@@ -551,6 +551,13 @@ def build_args_parser() -> argparse.ArgumentParser:
help="path to the input pruning token mapping file (token_map.json)",
)

parser.add_argument(
"--nncf_compression",
default=False,
action="store_true",
help="If true, stops right after torch.export() and saves the exported model.",
)

parser.add_argument(
"--export_only",
default=False,
@@ -1207,6 +1214,7 @@ def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager":
use_legacy_export=llm_config.backend.qnn.enabled,
save_exported_program=llm_config.export.export_only,
verbose=llm_config.debug.verbose,
nncf_compression=llm_config.nncf_compression,
metadata=_load_llama_model_metadata(
WeightType.FAIRSEQ2 if llm_config.base.fairseq2 else WeightType.LLAMA,
llm_config.model.use_kv_cache,
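For illustration only, a minimal sketch of exercising the new flag through the parser added above. The module path is inferred from the diff header, and a real export run would also need the usual model, checkpoint, and tokenizer arguments that this snippet omits:

from examples.models.llama.export_llama_lib import build_args_parser

parser = build_args_parser()
args = parser.parse_args(["--nncf_compression"])  # hypothetical invocation
assert args.nncf_compression  # store_true flips the False default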
32 changes: 32 additions & 0 deletions extension/llm/export/builder.py
@@ -16,6 +16,7 @@
from typing import Any, Callable, Dict, List, Optional, Tuple
from unittest.mock import patch

import nncf
import torch
from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
DuplicateDynamicQuantChainPass,
@@ -40,6 +41,7 @@
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from torchao.quantization.pt2e.quantizer import ComposableQuantizer, Quantizer
from torchao.utils import unwrap_tensor_subclass
from functools import partial

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)
@@ -98,6 +100,7 @@ def __init__(
dynamic_shapes: Optional[Any] = None,
use_legacy_export: bool = False,
save_exported_program: bool = False,
nncf_compression: bool = False,
):
# Store necessary constructor arguments.
self.model = model
@@ -119,6 +122,7 @@
self.dynamic_shapes = dynamic_shapes
self.use_legacy_export = use_legacy_export
self.save_exported_program = save_exported_program
self.nncf_compression = nncf_compression

# Note: treat this as the source of truth for the result of
# torch.export'ing a model. If the overall ExportedProgram is needed,
@@ -428,6 +432,34 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
DuplicateDynamicQuantChainPass()(m)
self.pre_autograd_graph_module = m
return self
elif self.nncf_compression:
tokenizer = get_tokenizer(self.tokenizer_path)

def transform_fn(prompts: str, tokenizer):
    # Tokenize one calibration prompt and pack it into the
    # (tokens, {"input_pos": ...}) form the exported module is called with.
    tokenized_text = tokenizer.encode(prompts, bos=False, eos=False)
    logging.debug(tokenized_text)

    return (
        torch.tensor(tokenized_text).unsqueeze(0),
        {"input_pos": torch.tensor([0])},
    )

if isinstance(self.calibration_data, str):
    self.calibration_data = [self.calibration_data]
if not self.dynamic_shapes:
    # Static shapes: split prompts into single words so each calibration
    # sample matches the fixed sequence length of the exported graph.
    self.calibration_data = [word for prompt in self.calibration_data for word in prompt.split()]

self.pre_autograd_graph_module = nncf.compress_weights(
self.pre_autograd_graph_module,
dataset=nncf.Dataset(self.calibration_data, transform_func=partial(transform_fn, tokenizer=tokenizer)),
mode=nncf.CompressWeightsMode.INT4_SYM,
ratio=0.8,
sensitivity_metric=nncf.SensitivityMetric.HESSIAN_INPUT_ACTIVATION,
)
return self
else:
logging.info("No quantizer provided, passing...")
return self
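For context, a self-contained sketch of the nncf.compress_weights API used above, applied to a toy module instead of a Llama graph. This is a sketch under the assumption that the installed NNCF version supports data-aware INT4 weight compression on torch.export graphs (the Torch FX backend); the toy model and calibration tensors are illustrative only:

import nncf
import torch

# Toy stand-in for pre_autograd_graph_module: a torch.export'ed graph module.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU())
fx_module = torch.export.export(model, (torch.randn(1, 64),)).module()

# Calibration samples already in model-input form, so no transform_func is
# needed here (the diff's transform_fn plays that role for string prompts).
calibration = [torch.randn(1, 64) for _ in range(8)]

compressed = nncf.compress_weights(
    fx_module,
    mode=nncf.CompressWeightsMode.INT4_SYM,
    ratio=0.8,  # ~80% of eligible weights go to INT4, the rest stay 8-bit
    dataset=nncf.Dataset(calibration),
    sensitivity_metric=nncf.SensitivityMetric.HESSIAN_INPUT_ACTIVATION,
)

The Hessian-based sensitivity metric is what makes the dataset mandatory: NNCF ranks layers by how much their quantization would perturb outputs on the calibration samples, and the ratio then decides which layers keep higher precision.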