From caadd81e65bf1240479250124814bc625a13b50c Mon Sep 17 00:00:00 2001
From: Nathanael See
Date: Wed, 14 Aug 2024 23:22:35 -0700
Subject: [PATCH] VulkanQuantizer for weight-only quantization on linear

Differential Revision: D61243540

Pull Request resolved: https://github.com/pytorch/executorch/pull/4707
---
 backends/vulkan/quantizer/TARGETS             |  13 ++
 backends/vulkan/quantizer/vulkan_quantizer.py | 120 ++++++++++++++++++
 2 files changed, 133 insertions(+)
 create mode 100644 backends/vulkan/quantizer/TARGETS
 create mode 100644 backends/vulkan/quantizer/vulkan_quantizer.py

diff --git a/backends/vulkan/quantizer/TARGETS b/backends/vulkan/quantizer/TARGETS
new file mode 100644
index 0000000000..7cc5b79eb2
--- /dev/null
+++ b/backends/vulkan/quantizer/TARGETS
@@ -0,0 +1,13 @@
+load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
+
+oncall("executorch")
+
+python_library(
+    name = "vulkan_quantizer",
+    srcs = [
+        "vulkan_quantizer.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+    ],
+)
diff --git a/backends/vulkan/quantizer/vulkan_quantizer.py b/backends/vulkan/quantizer/vulkan_quantizer.py
new file mode 100644
index 0000000000..451f18977e
--- /dev/null
+++ b/backends/vulkan/quantizer/vulkan_quantizer.py
@@ -0,0 +1,120 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from __future__ import annotations
+
+import functools
+from typing import Any, Callable, Dict, Optional
+
+import torch
+from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
+from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
+from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
+from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
+    _convert_scalars_to_attrs,
+    OP_TO_ANNOTATOR,
+    propagate_annotation,
+    QuantizationConfig,
+)
+from torch.fx import Node
+
+
+__all__ = [
+    "VulkanQuantizer",
+    "get_weight_quantization_config",
+]
+
+
+@functools.lru_cache
+def get_weight_quantization_config(
+    is_per_channel: bool = True,
+    weight_qmin: int = -128,
+    weight_qmax: int = 127,
+) -> QuantizationConfig:
+
+    weight_qscheme = (
+        torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric
+    )
+    weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
+        PerChannelMinMaxObserver if is_per_channel else MinMaxObserver
+    )
+    extra_args: Dict[str, Any] = {"eps": 2**-12}
+
+    weight_quantization_spec = QuantizationSpec(
+        dtype=torch.int8,
+        quant_min=weight_qmin,
+        quant_max=weight_qmax,
+        qscheme=weight_qscheme,
+        ch_axis=0,
+        is_dynamic=False,
+        observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
+            **extra_args
+        ),
+    )
+
+    quantization_config = QuantizationConfig(
+        input_activation=None,
+        output_activation=None,
+        weight=weight_quantization_spec,
+        bias=None,
+        is_qat=False,
+    )
+    return quantization_config
+
+
+_SUPPORTED_OPS = [
+    "linear",
+]
+
+
+class VulkanQuantizer(Quantizer):
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.global_config: Optional[QuantizationConfig] = None
+
+    def set_global(self, quantization_config: QuantizationConfig) -> VulkanQuantizer:
+        self.global_config = quantization_config
+        return self
+
+    def transform_for_annotation(
+        self, model: torch.fx.GraphModule
+    ) -> torch.fx.GraphModule:
+        """Transforms scalar values to tensor attributes"""
+        return _convert_scalars_to_attrs(model)
+
+    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+        # currently only support static quant on Vulkan
+        model = self._annotate_for_static_quantization_config(model)
+        propagate_annotation(model)
+        return model
+
+    def _annotate_all_static_patterns(
+        self,
+        model: torch.fx.GraphModule,
+        quantization_config: Optional[QuantizationConfig],
+        filter_fn: Optional[Callable[[Node], bool]] = None,
+    ) -> torch.fx.GraphModule:
+        if quantization_config is None:
+            return model
+
+        for op in _SUPPORTED_OPS:
+            OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
+        return model
+
+    def _annotate_for_static_quantization_config(
+        self, model: torch.fx.GraphModule
+    ) -> torch.fx.GraphModule:
+        self._annotate_all_static_patterns(
+            model,
+            self.global_config,
+        )
+        return model
+
+    def validate(self, model: torch.fx.GraphModule) -> None:
+        pass
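
Usage sketch (not part of the patch above): the new quantizer is meant to plug into
the standard PT2E flow via prepare_pt2e/convert_pt2e from
torch.ao.quantization.quantize_pt2e. The toy model, shapes, import path, and the
capture call below are illustrative assumptions, not something this diff pins down.

    import torch
    from executorch.backends.vulkan.quantizer.vulkan_quantizer import (
        # assumed import path, mirroring where this patch places the file
        get_weight_quantization_config,
        VulkanQuantizer,
    )
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

    # Toy model: linear is the only op the quantizer currently annotates.
    model = torch.nn.Sequential(torch.nn.Linear(32, 16)).eval()
    example_inputs = (torch.randn(1, 32),)

    # Capture an FX graph for PT2E quantization. On PyTorch builds from around
    # this patch's date, torch._export.capture_pre_autograd_graph(model,
    # example_inputs) plays the same role.
    graph_module = torch.export.export_for_training(model, example_inputs).module()

    quantizer = VulkanQuantizer().set_global(
        get_weight_quantization_config(is_per_channel=True)
    )

    prepared = prepare_pt2e(graph_module, quantizer)
    prepared(*example_inputs)  # one forward pass populates the weight observers
    quantized = convert_pt2e(prepared)

Because the config sets input_activation and output_activation to None, only the
linear weights are annotated; convert_pt2e then materializes the observed
scales/zero-points as quantize/dequantize ops around those weights.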