Skip to content

Commit

Permalink
VulkanQuantizer for weight-only quantization on linear
Browse files Browse the repository at this point in the history
Differential Revision: D61243540

Pull Request resolved: pytorch#4707
  • Loading branch information
nathanaelsee authored Aug 15, 2024
1 parent 35da5bf commit caadd81
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 0 deletions.
13 changes: 13 additions & 0 deletions backends/vulkan/quantizer/TARGETS
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")

oncall("executorch")

python_library(
name = "vulkan_quantizer",
srcs = [
"vulkan_quantizer.py",
],
deps = [
"//caffe2:torch",
],
)
120 changes: 120 additions & 0 deletions backends/vulkan/quantizer/vulkan_quantizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from __future__ import annotations

import functools
from typing import Any, Callable, Dict, Optional

import torch
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
_convert_scalars_to_attrs,
OP_TO_ANNOTATOR,
propagate_annotation,
QuantizationConfig,
)
from torch.fx import Node


__all__ = [
"VulkanQuantizer",
"get_weight_quantization_config",
]


@functools.lru_cache
def get_weight_quantization_config(
is_per_channel: bool = True,
weight_qmin: int = -128,
weight_qmax: int = 127,
) -> QuantizationConfig:

weight_qscheme = (
torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric
)
weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
PerChannelMinMaxObserver if is_per_channel else MinMaxObserver
)
extra_args: Dict[str, Any] = {"eps": 2**-12}

weight_quantization_spec = QuantizationSpec(
dtype=torch.int8,
quant_min=weight_qmin,
quant_max=weight_qmax,
qscheme=weight_qscheme,
ch_axis=0,
is_dynamic=False,
observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
**extra_args
),
)

quantization_config = QuantizationConfig(
input_activation=None,
output_activation=None,
weight=weight_quantization_spec,
bias=None,
is_qat=False,
)
return quantization_config


_SUPPORTED_OPS = [
"linear",
]


class VulkanQuantizer(Quantizer):

def __init__(self) -> None:
super().__init__()
self.global_config: Optional[QuantizationConfig] = None

def set_global(self, quantization_config: QuantizationConfig) -> VulkanQuantizer:
self.global_config = quantization_config
return self

def transform_for_annotation(
self, model: torch.fx.GraphModule
) -> torch.fx.GraphModule:
"""Transforms scalar values to tensor attributes"""
return _convert_scalars_to_attrs(model)

def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
# currently only support static quant on Vulkan
model = self._annotate_for_static_quantization_config(model)
propagate_annotation(model)
return model

def _annotate_all_static_patterns(
self,
model: torch.fx.GraphModule,
quantization_config: Optional[QuantizationConfig],
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> torch.fx.GraphModule:
if quantization_config is None:
return model

for op in _SUPPORTED_OPS:
OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
return model

def _annotate_for_static_quantization_config(
self, model: torch.fx.GraphModule
) -> torch.fx.GraphModule:
self._annotate_all_static_patterns(
model,
self.global_config,
)
return model

def validate(self, model: torch.fx.GraphModule) -> None:
pass

0 comments on commit caadd81

Please sign in to comment.