Skip to content

Move get_quantizer_and_quant_params to quantizer_lib #11056

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions examples/models/llama/eval_llama_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,9 @@
import torch

from datasets import load_dataset
from executorch.examples.models.llama.export_llama_lib import (
get_quantizer_and_quant_params,
)

from executorch.extension.llm.export.builder import LLMEdgeManager
from executorch.extension.llm.export.quantizer_lib import get_quantizer_and_quant_params
from lm_eval.evaluator import simple_evaluate
from pytorch_tokenizers import get_tokenizer
from pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizer
Expand Down
34 changes: 1 addition & 33 deletions examples/models/llama/export_llama_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,7 @@
get_xnnpack_partitioner,
)

from executorch.extension.llm.export.quantizer_lib import (
get_coreml_quantizer,
get_pt2e_quantization_params,
get_pt2e_quantizers,
get_qnn_quantizer,
get_vulkan_quantizer,
)
from executorch.extension.llm.export.quantizer_lib import get_quantizer_and_quant_params
from executorch.util.activation_memory_profiler import generate_memory_trace

from ..model_factory import EagerModelFactory
Expand Down Expand Up @@ -726,32 +720,6 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
return edge_manager


def get_quantizer_and_quant_params(args):
pt2e_quant_params = get_pt2e_quantization_params(
args.pt2e_quantize, args.quantization_mode
)
quantizers = get_pt2e_quantizers(pt2e_quant_params, args.so_library)
quant_dtype = None
if args.qnn and args.pt2e_quantize:
assert len(quantizers) == 0, "Should not enable both xnnpack and qnn"
qnn_quantizer, quant_dtype = get_qnn_quantizer(
args.pt2e_quantize, args.quantization_mode
)
quantizers.append(qnn_quantizer)
if args.coreml and args.pt2e_quantize:
assert len(quantizers) == 0, "Should not enable both xnnpack / qnn and coreml"
coreml_quantizer = get_coreml_quantizer(args.pt2e_quantize)
quantizers.append(coreml_quantizer)
if args.vulkan and args.pt2e_quantize:
assert (
len(quantizers) == 0
), "Should not enable both vulkan and other quantizers"
vulkan_quantizer = get_vulkan_quantizer(args.pt2e_quantize)
quantizers.append(vulkan_quantizer)
logging.info(f"Applying quantizers: {quantizers}")
return pt2e_quant_params, quantizers, quant_dtype


def _qmode_type(value):
choices = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"]
patterns = [r"torchao:8da(\d+)w", r"torchao:fpa(\d+)w"]
Expand Down
6 changes: 2 additions & 4 deletions examples/models/llava/export_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,7 @@
get_symmetric_quantization_config,
XNNPACKQuantizer,
)
from executorch.examples.models.llama.export_llama_lib import (
build_args_parser,
get_quantizer_and_quant_params,
)
from executorch.examples.models.llama.export_llama_lib import build_args_parser
from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
replace_kv_cache_with_custom_kv_cache,
)
Expand All @@ -44,6 +41,7 @@
HintBasedSymShapeEvalPass,
)
from executorch.extension.llm.export.builder import DType, LLMEdgeManager
from executorch.extension.llm.export.quantizer_lib import get_quantizer_and_quant_params
from executorch.util.activation_memory_profiler import generate_memory_trace
from pytorch_tokenizers.llama2c import Llama2cTokenizer as Tokenizer
from torch.export import Dim
Expand Down
26 changes: 26 additions & 0 deletions extension/llm/export/quantizer_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,3 +277,29 @@ def get_vulkan_quantizer(pt2e_quantize: str):

quantizer = VulkanQuantizer().set_global(config)
return quantizer


def get_quantizer_and_quant_params(args):
pt2e_quant_params = get_pt2e_quantization_params(
args.pt2e_quantize, args.quantization_mode
)
quantizers = get_pt2e_quantizers(pt2e_quant_params, args.so_library)
quant_dtype = None
if args.qnn and args.pt2e_quantize:
assert len(quantizers) == 0, "Should not enable both xnnpack and qnn"
qnn_quantizer, quant_dtype = get_qnn_quantizer(
args.pt2e_quantize, args.quantization_mode
)
quantizers.append(qnn_quantizer)
if args.coreml and args.pt2e_quantize:
assert len(quantizers) == 0, "Should not enable both xnnpack / qnn and coreml"
coreml_quantizer = get_coreml_quantizer(args.pt2e_quantize)
quantizers.append(coreml_quantizer)
if args.vulkan and args.pt2e_quantize:
assert (
len(quantizers) == 0
), "Should not enable both vulkan and other quantizers"
vulkan_quantizer = get_vulkan_quantizer(args.pt2e_quantize)
quantizers.append(vulkan_quantizer)
logging.info(f"Applying quantizers: {quantizers}")
return pt2e_quant_params, quantizers, quant_dtype
Loading