diff --git a/torch/ao/quantization/fx/prepare.py b/torch/ao/quantization/fx/prepare.py
index bfb6f2df78d9e..33758453d4794 100644
--- a/torch/ao/quantization/fx/prepare.py
+++ b/torch/ao/quantization/fx/prepare.py
@@ -118,7 +118,7 @@ __all__ = [
     "DO_NOT_OBS_DTYPE_LIST",
     "add_matched_node_name_to_set",
-    "get_arg_target_compute_dtype_as_input_to_node",
+    "get_arg_target_is_dynamic_as_input_to_node",
     "get_arg_target_dtype_as_input_to_node",
     "get_arg_target_dtype_as_output",
     "get_target_activation_dtype_for_node",
@@ -157,7 +157,7 @@ def is_activation_post_process_node(node: Node, modules: Dict[str, torch.nn.Modu
 def is_input_arg_dtype_supported_by_backend(
     arg: Argument,
     node: Node,
-    node_name_to_target_dtype: Dict[str, Dict[str, Optional[Union[torch.dtype, type]]]],
+    node_name_to_target_dtype: Dict[str, Dict[str, Optional[Tuple[Union[torch.dtype, type], bool]]]],
     dtype_config: DTypeConfig,
     backend_config: BackendConfig,
 ) -> bool:
@@ -174,27 +174,29 @@ def is_input_arg_dtype_supported_by_backend(
     is_bias = node_arg_is_bias(node, arg, backend_config)
     is_activation = not is_weight and not is_bias
     if is_activation:
-        is_dynamic = dtype_config.is_dynamic
-        if is_dynamic:
-            input_activation_dtype = dtype_config.input_dtype
-            # TODO: change this after the is_dynamic refactor is landed
-            compute_dtype = node_name_to_target_dtype[node.name].get("input_activation_compute_dtype", None)
-            return input_activation_dtype is None or \
-                compute_dtype == input_activation_dtype
+        qconfig_info = node_name_to_target_dtype[node.name].get(
+            "input_activation_dtype")
+        if qconfig_info is not None:
+            qconfig_dtype, qconfig_is_dynamic = qconfig_info
         else:
-            input_activation_dtype = dtype_config.input_dtype
-            return input_activation_dtype is None or \
-                node_name_to_target_dtype[node.name]["input_activation_dtype"] == input_activation_dtype
+            qconfig_dtype, qconfig_is_dynamic = None, None
+        # TODO(future PR): remove the cast to bool below after figuring
+        # out why backend_config has is_dynamic set to None in some cases.
+        return (dtype_config.input_dtype is None) or \
+            (dtype_config.input_dtype == qconfig_dtype and
+             bool(dtype_config.is_dynamic) == bool(qconfig_is_dynamic))
     elif is_weight:
         weight_dtype = dtype_config.weight_dtype
-        return weight_dtype is None or node_name_to_target_dtype[node.name]["weight_dtype"] == weight_dtype
+        return weight_dtype is None or \
+            node_name_to_target_dtype[node.name]["weight_dtype"][0] == weight_dtype  # type: ignore[index]
     else:  # bias
         bias_dtype = dtype_config.bias_dtype
-        return bias_dtype is None or node_name_to_target_dtype[node.name]["bias_dtype"] == bias_dtype
+        return bias_dtype is None or \
+            node_name_to_target_dtype[node.name]["bias_dtype"][0] == bias_dtype  # type: ignore[index]

 def is_output_dtype_supported_by_backend(
     node: Node,
-    node_name_to_target_dtype: Dict[str, Dict[str, Optional[Union[torch.dtype, type]]]],
+    node_name_to_target_dtype: Dict[str, Dict[str, Optional[Tuple[Union[torch.dtype, type], bool]]]],
     dtype_config: DTypeConfig,
 ) -> bool:
     """ Check if the configured qconfig for the output
@@ -202,7 +204,7 @@ is_output_dtype_supported_by_backend(
     """
     output_dtype = dtype_config.output_dtype
     return output_dtype is None or \
-        output_dtype == node_name_to_target_dtype[node.name]["output_activation_dtype"]
+        output_dtype == node_name_to_target_dtype[node.name]["output_activation_dtype"][0]  # type: ignore[index]

 def is_observer_in_same_graph(node, modules, node_name_to_target_dtype):
     """ Check if observer in same graph
@@ -219,7 +221,7 @@ def is_observer_in_same_graph(node, modules, node_name_to_target_dtype):
 def is_pattern_dtype_config_supported_by_backend(
     pattern: Optional[Pattern],
     matched_node_pattern: Optional[NodePattern],
-    node_name_to_target_dtype: Dict[str, Dict[str, Optional[Union[torch.dtype, type]]]],
+    node_name_to_target_dtype: Dict[str, Dict[str, Optional[Tuple[Union[torch.dtype, type], bool]]]],
     backend_config: BackendConfig,
 ) -> bool:
     """ Check is the dtype configuration of a pattern is supported by
@@ -335,11 +337,27 @@ def get_target_activation_dtype_for_node(
     qhandler: Optional[QuantizeHandler],
     modules: Dict[str, torch.nn.Module],
     cache_for_no_tensor_check: Dict[Node, bool],
-) -> Dict[str, Optional[Union[torch.dtype, type]]]:
+) -> Dict[str, Optional[Tuple[Union[torch.dtype, type], bool]]]:
     """
-    Returns the expected dtype of the input and output of this node after
-    convert. If the value is not None, it represents the dtype of the
-    Tensor. If the value is None, it means the value is not a Tensor.
+    For each op attribute in the op's input activation, output activation,
+    weight, bias - returns the settings of dtype and is_dynamic we expect
+    for the `quantize` call in the reference model representation, or None
+    if there is no `quantize` call needed.
+
+    For example, if we have a node corresponding to `op0` in
+
+      x0 -> op0 -> x1
+
+    And we want a reference quantized representation to be
+
+      x0 -> quant_static -> dequant -> op0 -> quant_dynamic -> dequant -> x1
+
+    Then this function will return
+
+    {
+      'input_activation': {'dtype': torch.quint8, is_dynamic: False},
+      'output_activation': {'dtype': torch.quint8, is_dynamic: True},
+    }

     Note: this is for activations only, weight dtypes are not handled here.
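
To make the new bookkeeping concrete, here is a small illustrative sketch (not part of the patch; the node names and values are hypothetical) of what node_name_to_target_dtype entries look like after this change: each value is a (dtype, is_dynamic) tuple instead of a bare dtype plus a separate *_compute_dtype key.

    import torch

    # Hypothetical example: a statically quantized conv and a dynamically
    # quantized linear, mirroring the logic in
    # get_target_activation_dtype_for_node above.
    node_name_to_target_dtype = {
        "conv1": {
            "input_activation_dtype": (torch.quint8, False),
            "weight_dtype": (torch.qint8, False),
            "bias_dtype": (torch.float, False),
            "output_activation_dtype": (torch.quint8, False),
        },
        "linear1": {
            # dynamic quantization: only the input is quantized dynamically,
            # the output stays in fp32
            "input_activation_dtype": (torch.quint8, True),
            "weight_dtype": (torch.qint8, False),
            "bias_dtype": (torch.float, False),
            "output_activation_dtype": (torch.float, False),
        },
    }

    dtype, is_dynamic = node_name_to_target_dtype["linear1"]["input_activation_dtype"]
    assert dtype == torch.quint8 and is_dynamic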
@@ -349,15 +367,15 @@ def get_target_activation_dtype_for_node(
     if node.op == 'placeholder':
         if inputs_seen_counter in input_quantized_idxs:
             return {
-                "input_activation_dtype": torch.quint8,
-                "output_activation_dtype": torch.quint8,
+                "input_activation_dtype": (torch.quint8, False),
+                "output_activation_dtype": (torch.quint8, False),
             }
         else:
             # if dtype is fp32 (default), do nothing
             # note: other dtypes are not supported
             return {
-                "input_activation_dtype": torch.float,
-                "output_activation_dtype": torch.float,
+                "input_activation_dtype": (torch.float, False),
+                "output_activation_dtype": (torch.float, False),
             }

     elif node.op in ('call_module', 'call_method', 'call_function'):
@@ -375,8 +393,8 @@ def get_target_activation_dtype_for_node(
             node.target == operator.getitem
         if is_getitem:
             return {
-                "input_activation_dtype": torch.float,
-                "output_activation_dtype": torch.float,
+                "input_activation_dtype": (torch.float, False),
+                "output_activation_dtype": (torch.float, False),
             }

         # get qconfig to determine the eventual dtype of this node
@@ -384,39 +402,53 @@ def get_target_activation_dtype_for_node(
         if qhandler is not None and qhandler.input_output_observed():
             act_dtype, weight_dtype, act_compute_dtype = \
                 get_qconfig_dtypes(qconfig)
+            input_act_is_dynamic = act_compute_dtype is not None
+
+            # Currently `QConfig` only has one `activation` field.
+            # For static quantization, it is reused for both input
+            # and output activation. For dynamic quantization, this
+            # field is currently only used for the input activation,
+            # with the output activation being in fp32.
+            # In the future this may change as we add more fields
+            # to the `QConfig` object.
+            output_act_dtype = act_dtype \
+                if input_act_is_dynamic is not True else torch.float
+
             bias_dtype = torch.float16 \
-                if act_dtype == torch.float16 and weight_dtype == torch.float16 \
-                else torch.float
+                if (
+                    act_dtype == torch.float16
+                    and weight_dtype == torch.float16
+                    and act_compute_dtype is None
+                ) else torch.float
             return {
-                "input_activation_dtype": act_dtype,
-                "input_activation_compute_dtype": act_compute_dtype,
-                "weight_dtype": weight_dtype,
-                "bias_dtype": bias_dtype,
-                "output_activation_dtype": act_dtype,
+                "input_activation_dtype": (act_dtype, input_act_is_dynamic),
+                "weight_dtype": (weight_dtype, False),
+                "bias_dtype": (bias_dtype, False),
+                "output_activation_dtype": (output_act_dtype, False),
             }
         return {
-            "input_activation_dtype": torch.float,
-            "output_activation_dtype": torch.float,
+            "input_activation_dtype": (torch.float, False),
+            "output_activation_dtype": (torch.float, False),
         }

     elif node.op == 'get_attr':
         return {
-            "input_activation_dtype": torch.float,
-            "output_activation_dtype": torch.float,
+            "input_activation_dtype": (torch.float, False),
+            "output_activation_dtype": (torch.float, False),
         }

     elif node.op == 'output':
         if outputs_seen_counter in output_quantized_idxs:
             return {
-                "input_activation_dtype": torch.quint8,
-                "output_activation_dtype": torch.quint8
+                "input_activation_dtype": (torch.quint8, False),
+                "output_activation_dtype": (torch.quint8, False),
             }
         else:
             # if dtype is fp32 (default), do nothing
             # note: other dtypes are not supported
             return {
-                "input_activation_dtype": torch.float,
-                "output_activation_dtype": torch.float,
+                "input_activation_dtype": (torch.float, False),
+                "output_activation_dtype": (torch.float, False),
             }

     else:
@@ -425,10 +457,10 @@ def get_target_activation_dtype_for_node(

 def get_arg_target_dtype_as_output(
     arg: Node,
     modules: Dict[str, torch.nn.Module],
-    node_name_to_target_dtype: Dict[str, Dict[str, Optional[Union[torch.dtype, type]]]],
+    node_name_to_target_dtype: Dict[str, Dict[str, Optional[Tuple[Union[torch.dtype, type], bool]]]],
 ) -> Optional[Union[torch.dtype, type]]:
     """ Get the target output activation dtype for
-    the argumnet in the original graph, skipping inserted observers
+    the argument in the original graph, skipping inserted observers
     We are assuming that the observers are inserted correctly, and the dtype for
     argument in quantized graph will match what is specified by the qconfig
     """
@@ -436,15 +468,20 @@ def get_arg_target_dtype_as_output(
     if is_activation_post_process_node(arg, modules):
         observed_arg = arg.args[0]
         assert isinstance(observed_arg, Node), "Currently we only support observing Node"
-        return node_name_to_target_dtype[observed_arg.name]["output_activation_dtype"]
+        return node_name_to_target_dtype[observed_arg.name]["output_activation_dtype"][0]  # type: ignore[index]
     else:
-        return node_name_to_target_dtype[arg.name]["output_activation_dtype"]
+        target_dtype_info = \
+            node_name_to_target_dtype[arg.name]["output_activation_dtype"]
+        if target_dtype_info is not None:
+            return target_dtype_info[0]
+        else:
+            return None

 def get_arg_target_dtype_as_input_to_node(
     arg: Node,
     node: Node,
     modules: Dict[str, torch.nn.Module],
-    node_name_to_target_dtype: Dict[str, Dict[str, Optional[Union[torch.dtype, type]]]],
+    node_name_to_target_dtype: Dict[str, Dict[str, Optional[Tuple[Union[torch.dtype, type], bool]]]],
     backend_config: BackendConfig,
 ) -> Optional[Union[torch.dtype, type]]:
     """ Get the target argument dtype for the argument `arg`, as input
@@ -455,22 +492,22 @@ def get_arg_target_dtype_as_input_to_node(
     is_bias = node_arg_is_bias(node, arg, backend_config)
     is_activation = not is_weight and not is_bias
     if is_activation:
-        return node_name_to_target_dtype[node.name]["input_activation_dtype"]
+        return node_name_to_target_dtype[node.name]["input_activation_dtype"][0]  # type: ignore[index]
     elif is_weight:
         if node.target in NON_QUANTIZABLE_WEIGHT_OPS:
             return None
         else:
-            return node_name_to_target_dtype[node.name]["weight_dtype"]
+            return node_name_to_target_dtype[node.name]["weight_dtype"][0]  # type: ignore[index]
     else:
-        return node_name_to_target_dtype[node.name]["bias_dtype"]
+        return node_name_to_target_dtype[node.name]["bias_dtype"][0]  # type: ignore[index]

-def get_arg_target_compute_dtype_as_input_to_node(
+def get_arg_target_is_dynamic_as_input_to_node(
     arg: Node,
     node: Node,
     modules: Dict[str, torch.nn.Module],
-    node_name_to_target_dtype: Dict[str, Dict[str, Union[torch.dtype, type, None]]],
+    node_name_to_target_dtype: Dict[str, Dict[str, Tuple[Union[torch.dtype, type, None], bool]]],
     backend_config: BackendConfig,
-) -> Union[torch.dtype, type, None]:
+) -> bool:
     """ Get the target argument dtype for the argument `arg`, as input
     to node `node`
     """
@@ -479,10 +516,10 @@ def get_arg_target_compute_dtype_as_input_to_node(
     is_bias = node_arg_is_bias(node, arg, backend_config)
     is_activation = not is_weight and not is_bias
     if is_activation and \
-            "input_activation_compute_dtype" in node_name_to_target_dtype[node.name]:
-        return node_name_to_target_dtype[node.name]["input_activation_compute_dtype"]
+            "input_activation_dtype" in node_name_to_target_dtype[node.name]:
+        return node_name_to_target_dtype[node.name]["input_activation_dtype"][1]
     else:
-        return None
+        return False

 def maybe_insert_input_observer_for_arg_or_kwarg(
     node: Union[Node, Any],
@@ -491,7 +528,7 @@ def maybe_insert_input_observer_for_arg_or_kwarg(
     model: torch.nn.Module,
     modules: Dict[str, torch.nn.Module],
     graph: Graph,
-    node_name_to_target_dtype: Dict[str, Dict[str, Optional[Union[torch.dtype, type]]]],
+    node_name_to_target_dtype: Dict[str, Dict[str, Optional[Tuple[Union[torch.dtype, type], bool]]]],
     qhandler: Optional[QuantizeHandler],
     prepare_custom_config: PrepareCustomConfig,
     backend_config: BackendConfig,
@@ -537,28 +574,32 @@ def maybe_insert_input_observer_for_arg_or_kwarg(
             modules, node_name_to_target_dtype, backend_config)
-        arg_as_input_target_compute_dtype = \
-            get_arg_target_compute_dtype_as_input_to_node(
-                arg, node, modules, node_name_to_target_dtype, backend_config)
-        needs_obs = (
-            # if the dtypes are different, we need an observer
-            (arg_as_output_target_dtype != arg_as_input_target_dtype) and
-            # except if the second dtype is float, a dequant will be inserted
-            # without an observer in convert
-            # TODO(future PR): change this so a placeholder is inserted for
-            # future dequants, to make the logic easier to understand
-            (arg_as_input_target_dtype != torch.float) and
-            # if arg output dtype is in DO_NOT_OBS_DTYPE_LIST do not insert observer
-            (arg_as_output_target_dtype not in DO_NOT_OBS_DTYPE_LIST) and
-            # if qconfig is reuse_input qconfig, we won't insert extra observer for input
-            not is_reuse_input_qconfig_ or
-            # need to add input observer for dynamic quantization
-            # only add observer for first input for now, we may need to extend
-            # qconfig_dict and backend_config to support more general configurations
-            # of dynamic quantization, e.g. dynamically quantizing second input, third
-            # input etc.
-            (arg_as_input_target_compute_dtype in [torch.quint8, torch.int8, torch.float16]) and arg is node.args[0]
-        )
+        arg_as_input_target_is_dynamic = \
+            get_arg_target_is_dynamic_as_input_to_node(
+                arg, node, modules, node_name_to_target_dtype, backend_config)  # type: ignore[arg-type]
+        needs_obs = \
+            (
+                # the following code block is for static quantization
+                (not arg_as_input_target_is_dynamic) and
+                # if the dtypes are different, we need an observer
+                (arg_as_output_target_dtype != arg_as_input_target_dtype) and
+                # except if the second dtype is float, a dequant will be inserted
+                # without an observer in convert
+                # TODO(future PR): change this so a placeholder is inserted for
+                # future dequants, to make the logic easier to understand
+                (arg_as_input_target_dtype != torch.float) and
+                # if arg output dtype is in DO_NOT_OBS_DTYPE_LIST do not insert observer
+                (arg_as_output_target_dtype not in DO_NOT_OBS_DTYPE_LIST) and
+                # if qconfig is reuse_input qconfig, we won't insert extra observer for input
+                not is_reuse_input_qconfig_
+            ) or (
+                # need to add input observer for dynamic quantization
+                # only add observer for first input for now, we may need to extend
+                # qconfig_dict and backend_config to support more general configurations
+                # of dynamic quantization, e.g. dynamically quantizing second input, third
+                # input etc.
+                arg_as_input_target_is_dynamic and arg is node.args[0]
+            )
     else:
         # custom flow for standalone modules
@@ -628,7 +669,7 @@ def maybe_insert_input_observers_for_node(
     model: torch.nn.Module,
     modules: Dict[str, torch.nn.Module],
     graph: Graph,
-    node_name_to_target_dtype: Dict[str, Dict[str, Optional[Union[torch.dtype, type]]]],
+    node_name_to_target_dtype: Dict[str, Dict[str, Optional[Tuple[Union[torch.dtype, type], bool]]]],
    qhandler: Optional[QuantizeHandler],
    prepare_custom_config: PrepareCustomConfig,
    backend_config: BackendConfig,
@@ -683,7 +724,7 @@ def maybe_insert_input_equalization_observers_for_node(
     model: torch.nn.Module,
     modules: Dict[str, torch.nn.Module],
     graph: Graph,
-    node_name_to_target_dtype: Dict[str, Dict[str, Optional[Union[torch.dtype, type]]]],
+    node_name_to_target_dtype: Dict[str, Dict[str, Optional[Tuple[Union[torch.dtype, type], bool]]]],
     is_branch: bool,
     backend_config: BackendConfig,
 ) -> None:
@@ -728,7 +769,7 @@ def maybe_insert_output_observer_for_node(
     modules: Dict[str, torch.nn.Module],
     graph: Graph,
     matches: Dict[str, _MatchResultWithQConfig],
-    node_name_to_target_dtype: Dict[str, Dict[str, Optional[Union[torch.dtype, type]]]],
+    node_name_to_target_dtype: Dict[str, Dict[str, Optional[Tuple[Union[torch.dtype, type], bool]]]],
     matched_pattern: Any,
     qhandler: Optional[QuantizeHandler],
     is_qat: bool,
@@ -750,7 +791,7 @@ def maybe_insert_output_observer_for_node(
     is_standalone_module = qhandler is not None and qhandler.is_standalone_module()

-    dtype = node_name_to_target_dtype[node.name]["output_activation_dtype"]
+    dtype, is_dynamic = node_name_to_target_dtype[node.name]["output_activation_dtype"]  # type: ignore[misc]
     should_insert_observer = dtype not in DO_NOT_OBS_DTYPE_LIST + [torch.float]
     # TODO(future PR): move the following logic to
     # should_insert_observer_for_output
@@ -778,7 +819,7 @@ def maybe_insert_observers_before_graph_output(
     graph_output_node: Node,
     output_quantized_idxs: List[int],
-    node_name_to_target_dtype: Dict[str, Dict[str, Optional[Union[torch.dtype, type]]]],
+    node_name_to_target_dtype: Dict[str, Dict[str, Optional[Tuple[Union[torch.dtype, type], bool]]]],
     qconfig_map: Dict[str, QConfigAny],
     model: torch.nn.Module,
     modules: Dict[str, torch.nn.Module],
@@ -807,7 +848,7 @@ def maybe_insert_observers_before_graph_output(
 def _recursive_maybe_replace_node_with_obs(
     maybe_node: Argument,
     target_dtype: torch.dtype,
-    node_name_to_target_dtype: Dict[str, Dict[str, Optional[Union[torch.dtype, type]]]],
+    node_name_to_target_dtype: Dict[str, Dict[str, Optional[Tuple[Union[torch.dtype, type], bool]]]],
     qconfig_map: Dict[str, QConfigAny],
     model: torch.nn.Module,
     modules: Dict[str, torch.nn.Module],
@@ -879,17 +920,18 @@ def _recursive_maybe_replace_node_with_obs(
 def maybe_propagate_dtype_for_node(
     node: Node,
     target_dtype: Union[torch.dtype, type],
-    node_name_to_target_dtype: Dict[str, Dict[str, Optional[Union[torch.dtype, type]]]],
+    node_name_to_target_dtype: Dict[str, Dict[str, Optional[Tuple[Union[torch.dtype, type], bool]]]],
     matches: Dict[str, _MatchResultWithQConfig],
 ) -> None:
     """
-    Assigns `target_dtype` to `node`. If `node` is a general tensor shape op
+    Assigns `target_dtype` to `node`, setting `is_dynamic` to False. If `node`
+    is a general tensor shape op
     (see GeneralTensorShapeOpQuantizeHandler in quantization_patterns.py for more details)
     also call this function recursively on the first argument, to propagate
     the dtype to the caller.
""" - node_name_to_target_dtype[node.name]["input_activation_dtype"] = target_dtype - node_name_to_target_dtype[node.name]["output_activation_dtype"] = target_dtype + node_name_to_target_dtype[node.name]["input_activation_dtype"] = (target_dtype, False) + node_name_to_target_dtype[node.name]["output_activation_dtype"] = (target_dtype, False) # if this is a copy node, propagate to first arg root_node, _, pattern, qhandler, qconfig = matches.get( node.name, (None, None, None, None, None)) @@ -901,7 +943,7 @@ def maybe_propagate_dtype_for_node( def propagate_dtypes_for_known_nodes( graph: Graph, - node_name_to_target_dtype: Dict[str, Dict[str, Optional[Union[torch.dtype, type]]]], + node_name_to_target_dtype: Dict[str, Dict[str, Optional[Tuple[Union[torch.dtype, type], bool]]]], matches: Dict[str, _MatchResultWithQConfig], ) -> None: """ @@ -1120,7 +1162,9 @@ def insert_observers_for_model( # } # # TODO: rename this to node_name_to_target_dtype_info - node_name_to_target_dtype: Dict[str, Dict[str, Optional[Union[torch.dtype, type]]]] = defaultdict(dict) + node_name_to_target_dtype: Dict[ + str, Dict[str, Optional[Tuple[Union[torch.dtype, type], bool]]] + ] = defaultdict(dict) cache_for_no_tensor_check: Dict[Node, bool] = {} inputs_seen_counter = 0 @@ -1177,8 +1221,8 @@ def insert_observers_for_model( node.name, (None, None, None, None, None)) equalization_qconfig = equalization_config_map.get(node.name, None) - this_node_dtype = node_name_to_target_dtype[node.name] - output_not_a_tensor = this_node_dtype is None + this_node_dtype_info = node_name_to_target_dtype[node.name] + output_not_a_tensor = this_node_dtype_info is None # TODO(future PR): consider stopping matching getitem is_getitem = node.op == 'call_function' and \ node.target == operator.getitem diff --git a/torch/ao/quantization/fx/qconfig_mapping_utils.py b/torch/ao/quantization/fx/qconfig_mapping_utils.py index 087c8983330c2..dd226df9a7650 100644 --- a/torch/ao/quantization/fx/qconfig_mapping_utils.py +++ b/torch/ao/quantization/fx/qconfig_mapping_utils.py @@ -245,9 +245,11 @@ def is_qconfig_supported_by_dtype_configs(qconfig: QConfig, dtype_configs: List[ qconfig_activation_dtype, qconfig_weight_dtype, qconfig_compute_dtype = \ get_qconfig_dtypes(qconfig) qconfig_bias_dtype = torch.float16 \ - if qconfig_activation_dtype == torch.float16 and \ - qconfig_weight_dtype == torch.float16 \ - else torch.float + if ( + qconfig_activation_dtype == torch.float16 + and qconfig_weight_dtype == torch.float16 + and not is_dynamic + ) else torch.float if is_dynamic: is_match = input_dtype == qconfig_compute_dtype and \ diff --git a/torch/ao/quantization/fx/utils.py b/torch/ao/quantization/fx/utils.py index faadf7771c9ce..81e05dbe9bbbd 100644 --- a/torch/ao/quantization/fx/utils.py +++ b/torch/ao/quantization/fx/utils.py @@ -158,7 +158,8 @@ def get_quantize_node_info(activation_post_process: Callable) -> Optional[Tuple[ if hasattr(activation_post_process, "compute_dtype"): compute_dtype = activation_post_process.compute_dtype # type: ignore[attr-defined] quantize_op : Optional[Union[Callable, str]] = None - if dtype in [torch.quint8, torch.qint8]: + if dtype in [torch.quint8, torch.qint8] and \ + not hasattr(activation_post_process, 'compute_dtype'): node_type = "call_function" scale, zero_point = activation_post_process.calculate_qparams() # type: ignore[attr-defined] if is_per_channel(activation_post_process.qscheme): # type: ignore[attr-defined] @@ -170,11 +171,8 @@ def get_quantize_node_info(activation_post_process: Callable) -> 
             zero_point = int(zero_point)
             qparams = {"_scale_": scale, "_zero_point_": zero_point, "_dtype_": dtype}
             quantize_op = torch.quantize_per_tensor
-    elif dtype == torch.float16:
-        node_type = "call_method"
-        quantize_op = "to"
-        qparams = {"_dtype_": dtype}
-    elif dtype == torch.float32 and compute_dtype in [torch.quint8, torch.qint8, torch.float16]:
+    elif compute_dtype in [torch.quint8, torch.qint8, torch.float16]:
+        # TODO(future PR): switch compute_dtype to is_dynamic
         # dynamic quantization
         node_type = "call_function"
         quantize_op = torch.quantize_per_tensor_dynamic
@@ -182,6 +180,10 @@ def get_quantize_node_info(activation_post_process: Callable) -> Optional[Tuple[
         # reduce_range = activation_post_process.reduce_range
         reduce_range = torch.backends.quantized.engine == "fbgemm"
         qparams = {"_dtype_": compute_dtype, "_reduce_range_": reduce_range}
+    elif dtype == torch.float16:
+        node_type = "call_method"
+        quantize_op = "to"
+        qparams = {"_dtype_": dtype}
     else:
         warnings.warn(f"Unsupported activation_post_process in get_quantize_node_info: {activation_post_process}")
         return None
diff --git a/torch/ao/quantization/observer.py b/torch/ao/quantization/observer.py
index 3781379faffd7..9c4090c3ef1ac 100644
--- a/torch/ao/quantization/observer.py
+++ b/torch/ao/quantization/observer.py
@@ -131,7 +132,8 @@ class ObserverBase(ABC, nn.Module):
     the collected statistics.

     Args:
-        dtype: Quantized data type
+        dtype: dtype argument to the `quantize` node needed to implement the
+               reference model spec.
     """

     def __init__(self, dtype):
@@ -155,7 +156,8 @@ class UniformQuantizationObserverBase(ObserverBase):
     scale and zero_point.

     Args:
-        dtype: Quantized data type.
+        dtype: dtype argument to the `quantize` node needed to implement the
+               reference model spec.
         qscheme: Quantization scheme to be used.
         reduce_range: Reduces the range of the quantized data type by 1 bit.
                       This is sometimes required to avoid instruction overflow.
@@ -382,7 +384,8 @@ class MinMaxObserver(UniformQuantizationObserverBase):
     tensors, and uses this statistic to compute the quantization parameters.

     Args:
-        dtype: Quantized data type
+        dtype: dtype argument to the `quantize` node needed to implement the
+               reference model spec.
         qscheme: Quantization scheme to be used
         reduce_range: Reduces the range of the quantized data type by 1 bit
         quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup.
@@ -520,7 +523,8 @@ class MovingAverageMinMaxObserver(MinMaxObserver):

     Args:
         averaging_constant: Averaging constant for min/max.
-        dtype: Quantized data type
+        dtype: dtype argument to the `quantize` node needed to implement the
+               reference model spec.
         qscheme: Quantization scheme to be used
         reduce_range: Reduces the range of the quantized data type by 1 bit
         quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup.
@@ -610,7 +614,8 @@ class PerChannelMinMaxObserver(UniformQuantizationObserverBase):

     Args:
         ch_axis: Channel axis
-        dtype: Quantized data type
+        dtype: dtype argument to the `quantize` node needed to implement the
+               reference model spec.
         qscheme: Quantization scheme to be used
         reduce_range: Reduces the range of the quantized data type by 1 bit
         quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup.
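
The docstring updates above redefine an observer's dtype as the dtype argument of the quantize call in the reference model. For the dynamic path, the get_quantize_node_info change in torch/ao/quantization/fx/utils.py above selects that branch purely on the presence of compute_dtype and emits torch.quantize_per_tensor_dynamic; a minimal runtime sketch (not part of the patch; tensor values are arbitrary):

    import torch

    x = torch.randn(2, 3)
    # (input, dtype, reduce_range): scale and zero_point are computed from x
    # at call time, which is what makes the quantization "dynamic".
    xq = torch.quantize_per_tensor_dynamic(x, torch.quint8, False)
    print(xq.dtype)  # torch.quint8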
@@ -884,7 +889,8 @@ class HistogramObserver(UniformQuantizationObserverBase):
         bins: Number of bins to use for the histogram
         upsample_rate: Factor by which the histograms are upsampled, this is
                        used to interpolate histograms with varying ranges across observations
-        dtype: Quantized data type
+        dtype: dtype argument to the `quantize` node needed to implement the
+               reference model spec
         qscheme: Quantization scheme to be used
         reduce_range: Reduces the range of the quantized data type by 1 bit
         eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`.
@@ -1308,9 +1314,14 @@ class PlaceholderObserver(ObserverBase):
     ranges.

     Args:
-        dtype: Quantized data type
+        dtype: dtype argument to the `quantize` node needed to implement the
+               reference model spec.
         custom_op_name: (temporary) specify this observer for an operator that doesn't
                         require any observation (Can be used in Graph Mode Passes for special
                         case ops).
+        compute_dtype: if set, marks the future quantize function to use
+                       dynamic quantization instead of static quantization.
+                       Note: this field will be removed in the near future and
+                       replaced with `is_dynamic`.
     """

     def __init__(
@@ -1322,6 +1333,7 @@ def __init__(
         self.dtype = dtype
         self.custom_op = custom_op_name
         # used for configuration of computation type for dynamic quantization
+        # TODO(future PR): replace this with `is_dynamic`
         if compute_dtype:
             self.compute_dtype = compute_dtype
@@ -1539,7 +1551,7 @@ def load_observer_state_dict(mod, obs_dict):
 """

 default_dynamic_quant_observer = PlaceholderObserver.with_args(
-    dtype=torch.float, compute_dtype=torch.quint8
+    dtype=torch.quint8, compute_dtype=torch.quint8
 )
 """
 Default observer for dynamic quantization.
diff --git a/torch/ao/quantization/qconfig.py b/torch/ao/quantization/qconfig.py
index a4766ee8bdaee..17310fd3aec17 100644
--- a/torch/ao/quantization/qconfig.py
+++ b/torch/ao/quantization/qconfig.py
@@ -157,7 +157,7 @@ def __new__(cls, activation=torch.nn.Identity, weight=torch.nn.Identity):
 Default dynamic qconfig.
 """

-float16_dynamic_qconfig = QConfig(activation=PlaceholderObserver.with_args(dtype=torch.float32, compute_dtype=torch.float16),
+float16_dynamic_qconfig = QConfig(activation=PlaceholderObserver.with_args(dtype=torch.float16, compute_dtype=torch.float16),
                                   weight=PlaceholderObserver.with_args(dtype=torch.float16))
 """
 Dynamic qconfig with weights quantized to `torch.float16`.
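
Taken together, the observer.py and qconfig.py changes switch dynamic qconfigs to carry the quantized dtype directly on the activation observer, with compute_dtype (later to become is_dynamic) marking the quantize call as dynamic. A sketch of the convention (not part of the patch; the qconfig name is made up and the values mirror the updated default_dynamic_quant_observer):

    import torch
    from torch.ao.quantization.observer import PlaceholderObserver, default_weight_observer
    from torch.ao.quantization.qconfig import QConfig
    from torch.ao.quantization.utils import get_qconfig_dtypes

    # Hypothetical int8 dynamic qconfig following the new convention.
    my_int8_dynamic_qconfig = QConfig(
        activation=PlaceholderObserver.with_args(dtype=torch.quint8, compute_dtype=torch.quint8),
        weight=default_weight_observer,
    )
    act_dtype, weight_dtype, act_compute_dtype = get_qconfig_dtypes(my_int8_dynamic_qconfig)
    print(act_dtype, weight_dtype, act_compute_dtype)  # torch.quint8 torch.qint8 torch.quint8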
diff --git a/torch/ao/quantization/utils.py b/torch/ao/quantization/utils.py
index 012718a55a04e..5d3af232a8883 100644
--- a/torch/ao/quantization/utils.py
+++ b/torch/ao/quantization/utils.py
@@ -187,7 +187,10 @@ def activation_is_statically_quantized(qconfig):
     """ Given a qconfig, decide if the activation needs to be
     quantized or not, this includes quantizing to quint8, qint8 and float16
     """
-    return activation_dtype(qconfig) in [torch.quint8, torch.qint8, torch.float16]
+    return (
+        activation_dtype(qconfig) in [torch.quint8, torch.qint8, torch.float16]
+        and (not activation_is_dynamically_quantized(qconfig))
+    )

 def activation_is_dynamically_quantized(qconfig):
     """ Given a qconfig, decide if the activation needs to be
@@ -196,8 +199,7 @@ def activation_is_dynamically_quantized(qconfig):
     """
     activation_dtype, _, activation_compute_dtype = \
         get_qconfig_dtypes(qconfig)
-    return activation_dtype == torch.float and \
-        activation_compute_dtype in [torch.quint8, torch.qint8, torch.float16]
+    return activation_compute_dtype in [torch.quint8, torch.qint8, torch.float16]

 def activation_is_int8_quantized(qconfig):
     """ Given a qconfig, decide if the activation needs to be
@@ -230,10 +232,11 @@ def op_is_int8_dynamically_quantized(qconfig) -> bool:
     activation_dtype, weight_dtype, activation_compute_dtype = \
         get_qconfig_dtypes(qconfig)
     return (
-        activation_dtype is torch.float and
+        activation_dtype is torch.quint8 and
         # for now, the lines below assume fbgemm or qnnpack
         weight_dtype is torch.qint8 and
         activation_compute_dtype is torch.quint8
+        # TODO(future PR): add is_dynamic
     )

 def get_qconfig_dtypes(qconfig):
@@ -252,15 +255,15 @@ def get_quant_type(qconfig):
     weight = qconfig.weight()
     static_dtypes = [torch.quint8, torch.qint8, torch.quint4x2]
     if weight.dtype in static_dtypes:
-        if activation.dtype in static_dtypes:
-            return QuantType.STATIC
-        elif hasattr(activation, 'compute_dtype') and activation.compute_dtype in static_dtypes:
+        if hasattr(activation, 'compute_dtype') and activation.compute_dtype in static_dtypes:
             return QuantType.DYNAMIC
+        elif activation.dtype in static_dtypes:
+            return QuantType.STATIC
         else:
             return QuantType.WEIGHT_ONLY

     if weight.dtype == torch.float16:
-        if activation.dtype == torch.float:
+        if hasattr(activation, 'compute_dtype') and activation.compute_dtype in static_dtypes:
             return QuantType.DYNAMIC
         elif activation.dtype == torch.float16:
             return QuantType.STATIC
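
Finally, a rough sketch (not part of the patch, and assuming the patch is applied) of how the updated utils.py helpers classify qconfigs now that dynamic qconfigs carry dtype=torch.quint8 on the activation observer; the compute_dtype check takes priority over the plain dtype check:

    import torch
    from torch.ao.quantization import default_dynamic_qconfig, get_default_qconfig
    from torch.ao.quantization.quant_type import QuantType
    from torch.ao.quantization.utils import (
        activation_is_dynamically_quantized,
        activation_is_statically_quantized,
        get_quant_type,
    )

    # A static int8 qconfig still classifies as STATIC.
    static_qconfig = get_default_qconfig("fbgemm")
    assert get_quant_type(static_qconfig) == QuantType.STATIC
    assert activation_is_statically_quantized(static_qconfig)

    # The default dynamic qconfig now has activation dtype torch.quint8, but the
    # presence of compute_dtype keeps it classified as DYNAMIC, not STATIC.
    assert get_quant_type(default_dynamic_qconfig) == QuantType.DYNAMIC
    assert activation_is_dynamically_quantized(default_dynamic_qconfig)
    assert not activation_is_statically_quantized(default_dynamic_qconfig)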