quantization: align observer dtype with reference model spec (pytorch#85345)

Summary:

Before this PR, the `dtype` attribute of observers was not clearly
defined. It originally meant `interface_dtype` in the eager mode
workflow, and that is how the codebase used it prior to this change.

In the new reference model spec, the `dtype` attribute of an observer
represents the `dtype` value that needs to be passed to the `quantize`
function of the reference model. This PR aligns the codebase with that
definition of `dtype`. In detail:
1. change utility functions to interpret `dtype` using the reference model definition
2. change `prepare` to interpret `dtype` using the reference model definition
3. change observers for dynamic quantization to interpret `dtype` using the reference
   model definition.

A future PR (left out of this one to keep LOC small) will deprecate the
`compute_dtype` field and instead expose `is_dynamic` on observers.
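
To illustrate the new meaning of `dtype`, here is a minimal sketch (using the observer configurations from this diff; the printed values are only for inspection) of how a dynamic-quantization observer is configured before and after this change:

```
import torch
from torch.ao.quantization.observer import PlaceholderObserver

# Before this PR: `dtype` was the interface dtype (fp32 in/out for dynamic
# quantization), and the quantized dtype lived only in `compute_dtype`.
old_style = PlaceholderObserver.with_args(dtype=torch.float, compute_dtype=torch.quint8)

# After this PR: `dtype` is the dtype argument of the `quantize` node in the
# reference model spec, so the dynamic observer reports torch.quint8 directly.
new_style = PlaceholderObserver.with_args(dtype=torch.quint8, compute_dtype=torch.quint8)

obs = new_style()
print(obs.dtype)          # torch.quint8
print(obs.compute_dtype)  # torch.quint8 (to be replaced by `is_dynamic` later)
```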
"

Test plan:

```
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps
```

Differential Revision: [D39675209](https://our.internmc.facebook.com/intern/diff/D39675209)
Pull Request resolved: pytorch#85345
Approved by: https://github.com/z-a-f, https://github.com/jerryzh168
vkuzo authored and pytorchmergebot committed Sep 21, 2022
1 parent 08f413b commit 0996595
Showing 6 changed files with 182 additions and 119 deletions.
230 changes: 137 additions & 93 deletions torch/ao/quantization/fx/prepare.py

Large diffs are not rendered by default.

8 changes: 5 additions & 3 deletions torch/ao/quantization/fx/qconfig_mapping_utils.py
@@ -245,9 +245,11 @@ def is_qconfig_supported_by_dtype_configs(qconfig: QConfig, dtype_configs: List[
qconfig_activation_dtype, qconfig_weight_dtype, qconfig_compute_dtype = \
get_qconfig_dtypes(qconfig)
qconfig_bias_dtype = torch.float16 \
if qconfig_activation_dtype == torch.float16 and \
qconfig_weight_dtype == torch.float16 \
else torch.float
if (
qconfig_activation_dtype == torch.float16
and qconfig_weight_dtype == torch.float16
and not is_dynamic
) else torch.float

if is_dynamic:
is_match = input_dtype == qconfig_compute_dtype and \
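As an illustration of the bias-dtype rule in this hunk (a restatement only, not the library API): bias stays in fp16 solely for the static fp16/fp16 case, and dynamic configurations now fall back to fp32.

```
import torch

def bias_dtype_for(activation_dtype, weight_dtype, is_dynamic):
    # Mirrors the condition added in this hunk: fp16 bias only when both
    # activation and weight are fp16 and the qconfig is not dynamic.
    if (
        activation_dtype == torch.float16
        and weight_dtype == torch.float16
        and not is_dynamic
    ):
        return torch.float16
    return torch.float
```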
14 changes: 8 additions & 6 deletions torch/ao/quantization/fx/utils.py
@@ -158,7 +158,8 @@ def get_quantize_node_info(activation_post_process: Callable) -> Optional[Tuple[
if hasattr(activation_post_process, "compute_dtype"):
compute_dtype = activation_post_process.compute_dtype # type: ignore[attr-defined]
quantize_op : Optional[Union[Callable, str]] = None
if dtype in [torch.quint8, torch.qint8]:
if dtype in [torch.quint8, torch.qint8] and \
not hasattr(activation_post_process, 'compute_dtype'):
node_type = "call_function"
scale, zero_point = activation_post_process.calculate_qparams() # type: ignore[attr-defined]
if is_per_channel(activation_post_process.qscheme): # type: ignore[attr-defined]
@@ -170,18 +171,19 @@ def get_quantize_node_info(activation_post_process: Callable) -> Optional[Tuple[
zero_point = int(zero_point)
qparams = {"_scale_": scale, "_zero_point_": zero_point, "_dtype_": dtype}
quantize_op = torch.quantize_per_tensor
elif dtype == torch.float16:
node_type = "call_method"
quantize_op = "to"
qparams = {"_dtype_": dtype}
elif dtype == torch.float32 and compute_dtype in [torch.quint8, torch.qint8, torch.float16]:
elif compute_dtype in [torch.quint8, torch.qint8, torch.float16]:
# TODO(future PR): switch compute_dtype to is_dynamic
# dynamic quantization
node_type = "call_function"
quantize_op = torch.quantize_per_tensor_dynamic
# TODO: get reduce range from observer
# reduce_range = activation_post_process.reduce_range
reduce_range = torch.backends.quantized.engine == "fbgemm"
qparams = {"_dtype_": compute_dtype, "_reduce_range_": reduce_range}
elif dtype == torch.float16:
node_type = "call_method"
quantize_op = "to"
qparams = {"_dtype_": dtype}
else:
warnings.warn(f"Unsupported activation_post_process in get_quantize_node_info: {activation_post_process}")
return None
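A condensed sketch of the branch order above (illustrative only; `choose_quantize_op` is a hypothetical helper, not the actual `get_quantize_node_info` body): dynamic quantization is now detected via `compute_dtype` before the plain fp16 case, and the static int8 branch is skipped for observers that carry a `compute_dtype`.

```
import torch

def choose_quantize_op(observer):
    # Illustrative restatement of the branch order after this PR.
    dtype = observer.dtype
    compute_dtype = getattr(observer, "compute_dtype", None)
    if dtype in (torch.quint8, torch.qint8) and compute_dtype is None:
        # static quantization: scale/zero_point come from the observer
        return "call_function", torch.quantize_per_tensor
    elif compute_dtype in (torch.quint8, torch.qint8, torch.float16):
        # dynamic quantization: qparams are computed at runtime
        return "call_function", torch.quantize_per_tensor_dynamic
    elif dtype == torch.float16:
        # fp16 "quantization" is just a cast
        return "call_method", "to"
    return None
```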
28 changes: 20 additions & 8 deletions torch/ao/quantization/observer.py
@@ -131,7 +131,8 @@ class ObserverBase(ABC, nn.Module):
the collected statistics.
Args:
dtype: Quantized data type
dtype: dtype argument to the `quantize` node needed to implement the
reference model spec.
"""

def __init__(self, dtype):
Expand All @@ -155,7 +156,8 @@ class UniformQuantizationObserverBase(ObserverBase):
scale and zero_point.
Args:
dtype: Quantized data type.
dtype: dtype argument to the `quantize` node needed to implement the
reference model spec.
qscheme: Quantization scheme to be used.
reduce_range: Reduces the range of the quantized data type by 1 bit.
This is sometimes required to avoid instruction overflow.
@@ -382,7 +384,8 @@ class MinMaxObserver(UniformQuantizationObserverBase):
tensors, and uses this statistic to compute the quantization parameters.
Args:
dtype: Quantized data type
dtype: dtype argument to the `quantize` node needed to implement the
reference model spec.
qscheme: Quantization scheme to be used
reduce_range: Reduces the range of the quantized data type by 1 bit
quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup.
@@ -520,7 +523,8 @@ class MovingAverageMinMaxObserver(MinMaxObserver):
Args:
averaging_constant: Averaging constant for min/max.
dtype: Quantized data type
dtype: dtype argument to the `quantize` node needed to implement the
reference model spec.
qscheme: Quantization scheme to be used
reduce_range: Reduces the range of the quantized data type by 1 bit
quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup.
@@ -610,7 +614,8 @@ class PerChannelMinMaxObserver(UniformQuantizationObserverBase):
Args:
ch_axis: Channel axis
dtype: Quantized data type
dtype: dtype argument to the `quantize` node needed to implement the
reference model spec.
qscheme: Quantization scheme to be used
reduce_range: Reduces the range of the quantized data type by 1 bit
quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup.
@@ -884,7 +889,8 @@ class HistogramObserver(UniformQuantizationObserverBase):
bins: Number of bins to use for the histogram
upsample_rate: Factor by which the histograms are upsampled, this is
used to interpolate histograms with varying ranges across observations
dtype: Quantized data type
dtype: dtype argument to the `quantize` node needed to implement the
reference model spec
qscheme: Quantization scheme to be used
reduce_range: Reduces the range of the quantized data type by 1 bit
eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`.
@@ -1308,9 +1314,14 @@ class PlaceholderObserver(ObserverBase):
ranges.
Args:
dtype: Quantized data type
dtype: dtype argument to the `quantize` node needed to implement the
reference model spec.
custom_op_name: (temporary) specify this observer for an operator that doesn't require any observation
(Can be used in Graph Mode Passes for special case ops).
compute_dtype: if set, marks the future quantize function to use
dynamic quantization instead of static quantization.
Note: this field will be removed in the near future and
replaced with `is_dynamic`.
"""

def __init__(
@@ -1322,6 +1333,7 @@ def __init__(
self.dtype = dtype
self.custom_op = custom_op_name
# used for configuration of computation type for dynamic quantization
# TODO(future PR): replace this with `is_dynamic`
if compute_dtype:
self.compute_dtype = compute_dtype

@@ -1539,7 +1551,7 @@ def load_observer_state_dict(mod, obs_dict):
"""

default_dynamic_quant_observer = PlaceholderObserver.with_args(
dtype=torch.float, compute_dtype=torch.quint8
dtype=torch.quint8, compute_dtype=torch.quint8
)
"""
Default observer for dynamic quantization.
2 changes: 1 addition & 1 deletion torch/ao/quantization/qconfig.py
@@ -157,7 +157,7 @@ def __new__(cls, activation=torch.nn.Identity, weight=torch.nn.Identity):
Default dynamic qconfig.
"""

float16_dynamic_qconfig = QConfig(activation=PlaceholderObserver.with_args(dtype=torch.float32, compute_dtype=torch.float16),
float16_dynamic_qconfig = QConfig(activation=PlaceholderObserver.with_args(dtype=torch.float16, compute_dtype=torch.float16),
weight=PlaceholderObserver.with_args(dtype=torch.float16))
"""
Dynamic qconfig with weights quantized to `torch.float16`.
19 changes: 11 additions & 8 deletions torch/ao/quantization/utils.py
@@ -187,7 +187,10 @@ def activation_is_statically_quantized(qconfig):
""" Given a qconfig, decide if the activation needs to be
quantized or not, this includes quantizing to quint8, qint8 and float16
"""
return activation_dtype(qconfig) in [torch.quint8, torch.qint8, torch.float16]
return (
activation_dtype(qconfig) in [torch.quint8, torch.qint8, torch.float16]
and (not activation_is_dynamically_quantized(qconfig))
)

def activation_is_dynamically_quantized(qconfig):
""" Given a qconfig, decide if the activation needs to be
@@ -196,8 +199,7 @@ def activation_is_dynamically_quantized(qconfig):
"""
activation_dtype, _, activation_compute_dtype = \
get_qconfig_dtypes(qconfig)
return activation_dtype == torch.float and \
activation_compute_dtype in [torch.quint8, torch.qint8, torch.float16]
return activation_compute_dtype in [torch.quint8, torch.qint8, torch.float16]

def activation_is_int8_quantized(qconfig):
""" Given a qconfig, decide if the activation needs to be
@@ -230,10 +232,11 @@ def op_is_int8_dynamically_quantized(qconfig) -> bool:
activation_dtype, weight_dtype, activation_compute_dtype = \
get_qconfig_dtypes(qconfig)
return (
activation_dtype is torch.float and
activation_dtype is torch.quint8 and
# for now, the lines below assume fbgemm or qnnpack
weight_dtype is torch.qint8 and
activation_compute_dtype is torch.quint8
# TODO(future PR): add is_dynamic
)

def get_qconfig_dtypes(qconfig):
@@ -252,15 +255,15 @@ def get_quant_type(qconfig):
weight = qconfig.weight()
static_dtypes = [torch.quint8, torch.qint8, torch.quint4x2]
if weight.dtype in static_dtypes:
if activation.dtype in static_dtypes:
return QuantType.STATIC
elif hasattr(activation, 'compute_dtype') and activation.compute_dtype in static_dtypes:
if hasattr(activation, 'compute_dtype') and activation.compute_dtype in static_dtypes:
return QuantType.DYNAMIC
elif activation.dtype in static_dtypes:
return QuantType.STATIC
else:
return QuantType.WEIGHT_ONLY

if weight.dtype == torch.float16:
if activation.dtype == torch.float:
if hasattr(activation, 'compute_dtype') and activation.compute_dtype in static_dtypes:
return QuantType.DYNAMIC
elif activation.dtype == torch.float16:
return QuantType.STATIC
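Taken together, the utility changes mean that dynamism is now inferred from `compute_dtype` rather than from `dtype == torch.float`. A minimal sketch of the expected behavior for the default dynamic qconfig (assuming it is built from the observers shown above):

```
import torch
from torch.ao.quantization.qconfig import default_dynamic_qconfig
from torch.ao.quantization.utils import (
    activation_is_dynamically_quantized,
    activation_is_statically_quantized,
    get_quant_type,
)

qconfig = default_dynamic_qconfig  # activation observer now reports dtype=torch.quint8
print(activation_is_dynamically_quantized(qconfig))  # True  (keyed off compute_dtype)
print(activation_is_statically_quantized(qconfig))   # False (dynamic is now excluded)
print(get_quant_type(qconfig))                        # QuantType.DYNAMIC
```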
