From 87c8d952fafa0b6ce539feed8311384c470bc80c Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Wed, 8 Nov 2023 14:10:08 +0100 Subject: [PATCH] Support splitted graph in filter_constant_nodes --- .../algorithms/min_max/algorithm.py | 2 +- nncf/quantization/algorithms/min_max/backend.py | 7 +++++++ .../algorithms/min_max/onnx_backend.py | 4 ++++ .../algorithms/min_max/openvino_backend.py | 4 ++++ .../algorithms/min_max/torch_backend.py | 4 ++++ nncf/quantization/passes.py | 17 ++++++++++++----- 6 files changed, 32 insertions(+), 6 deletions(-) diff --git a/nncf/quantization/algorithms/min_max/algorithm.py b/nncf/quantization/algorithms/min_max/algorithm.py index e698c77946b..0191ea90423 100644 --- a/nncf/quantization/algorithms/min_max/algorithm.py +++ b/nncf/quantization/algorithms/min_max/algorithm.py @@ -540,7 +540,7 @@ def _get_quantization_target_points( self._backend_entity.shapeof_metatypes, self._backend_entity.dropout_metatypes, self._backend_entity.read_variable_metatypes, - nncf_graph_contains_constants=backend != BackendType.TORCH, + self._backend_entity.constant_metatypes, ) quantizer_setup = self._get_quantizer_setup(nncf_graph, inference_nncf_graph, hw_patterns, ignored_patterns) diff --git a/nncf/quantization/algorithms/min_max/backend.py b/nncf/quantization/algorithms/min_max/backend.py index 2c105309c8e..ad3f6d8acf9 100644 --- a/nncf/quantization/algorithms/min_max/backend.py +++ b/nncf/quantization/algorithms/min_max/backend.py @@ -74,6 +74,13 @@ def read_variable_metatypes(self) -> List[OperatorMetatype]: Property for the backend-specific metatypes that also can be interpreted as inputs (ReadValue). """ + @property + @abstractmethod + def constant_metatypes(self) -> List[OperatorMetatype]: + """ + Property for the backend-specific metatypes that can be interpreted as constants. + """ + @property @abstractmethod def overflow_fix_metatypes(self) -> List[OperatorMetatype]: diff --git a/nncf/quantization/algorithms/min_max/onnx_backend.py b/nncf/quantization/algorithms/min_max/onnx_backend.py index bdd04bb22c3..ed9cae28b3f 100644 --- a/nncf/quantization/algorithms/min_max/onnx_backend.py +++ b/nncf/quantization/algorithms/min_max/onnx_backend.py @@ -80,6 +80,10 @@ def dropout_metatypes(self) -> List[OperatorMetatype]: def read_variable_metatypes(self) -> List[OperatorMetatype]: return [] + @property + def constant_metatypes(self) -> List[OperatorMetatype]: + return [om.ONNXConstantMetatype, om.ONNXConstantOfShapeMetatype] + @property def scales_unification_map(self) -> Dict[OperatorMetatype, OperatorMetatype]: return {om.ONNXConcatMetatype: self.overflow_fix_metatypes} diff --git a/nncf/quantization/algorithms/min_max/openvino_backend.py b/nncf/quantization/algorithms/min_max/openvino_backend.py index 60498893999..a44a5c7127e 100644 --- a/nncf/quantization/algorithms/min_max/openvino_backend.py +++ b/nncf/quantization/algorithms/min_max/openvino_backend.py @@ -86,6 +86,10 @@ def dropout_metatypes(self) -> List[OperatorMetatype]: def read_variable_metatypes(self) -> List[OperatorMetatype]: return [om.OVReadValueMetatype] + @property + def constant_metatypes(self) -> List[OperatorMetatype]: + return [om.OVConstantMetatype] + @property def scales_unification_map(self) -> Dict[OperatorMetatype, OperatorMetatype]: return {om.OVConcatMetatype: self.overflow_fix_metatypes} diff --git a/nncf/quantization/algorithms/min_max/torch_backend.py b/nncf/quantization/algorithms/min_max/torch_backend.py index e6dc15e46a9..d500fc8cc77 100644 --- a/nncf/quantization/algorithms/min_max/torch_backend.py +++ b/nncf/quantization/algorithms/min_max/torch_backend.py @@ -73,6 +73,10 @@ def dropout_metatypes(self) -> List[OperatorMetatype]: def read_variable_metatypes(self) -> List[OperatorMetatype]: return [] + @property + def constant_metatypes(self) -> List[OperatorMetatype]: + return [] + @property def conv_metatypes(self) -> List[OperatorMetatype]: return [om.PTModuleConv1dMetatype, om.PTModuleConv2dMetatype, om.PTModuleConv3dMetatype] diff --git a/nncf/quantization/passes.py b/nncf/quantization/passes.py index 055f1f27a5b..4d303eac0f7 100644 --- a/nncf/quantization/passes.py +++ b/nncf/quantization/passes.py @@ -23,7 +23,7 @@ def transform_to_inference_graph( shapeof_metatypes: List[OperatorMetatype], dropout_metatypes: List[OperatorMetatype], read_variable_metatypes: Optional[List[OperatorMetatype]] = None, - nncf_graph_contains_constants: bool = True, + constant_metatypes: Optional[List[OperatorMetatype]] = None, ) -> NNCFGraph: """ This method contains inplace pipeline of the passes that uses to provide inference graph without constant flows. @@ -33,13 +33,13 @@ def transform_to_inference_graph( :param dropout_metatypes: List of backend-specific Dropout metatypes. :param read_variable_metatypes: List of backend-specific metatypes that also can be interpreted as inputs (ReadValue). - :param nncf_graph_contains_constants: Whether NNCFGraph contains constant nodes or not. + :param constant_metatypes: List of backend-specific metatypes + that can be interpreted as constants. :return: NNCFGraph in the inference style. """ remove_shapeof_subgraphs(nncf_graph, shapeof_metatypes, read_variable_metatypes) remove_nodes_and_reconnect_graph(nncf_graph, dropout_metatypes) - if nncf_graph_contains_constants: - filter_constant_nodes(nncf_graph, read_variable_metatypes) + filter_constant_nodes(nncf_graph, read_variable_metatypes, constant_metatypes) return nncf_graph @@ -143,7 +143,9 @@ def remove_nodes_and_reconnect_graph( def filter_constant_nodes( - nncf_graph: NNCFGraph, read_variable_metatypes: Optional[List[OperatorMetatype]] = None + nncf_graph: NNCFGraph, + read_variable_metatypes: Optional[List[OperatorMetatype]] = None, + constant_nodes_metatypes: Optional[List[OperatorMetatype]] = None, ) -> NNCFGraph: """ Removes all Constant nodes from NNCFGraph inplace, making it inference graph. @@ -152,11 +154,16 @@ def filter_constant_nodes( :param nncf_graph: NNCFGraph instance for the transformation. :param read_variable_metatypes: List of backend-specific metatypes that also can be interpreted as inputs (ReadValue). + :param constant_nodes_metatypes: List of backend-specific metatypes + that can be interpreted as constants. :return: NNCFGraph without Constant nodes. """ read_variable_metatypes = read_variable_metatypes if read_variable_metatypes else [] + constant_nodes_metatypes = constant_nodes_metatypes if constant_nodes_metatypes else [] input_nodes = nncf_graph.get_input_nodes() similar_input_nodes = nncf_graph.get_nodes_by_metatypes(read_variable_metatypes) + potential_input_nodes = [node for node in nncf_graph.get_all_nodes() if not nncf_graph.get_input_edges(node)] + potential_input_nodes = [node for node in potential_input_nodes if node.metatype not in constant_nodes_metatypes] start_nodes = input_nodes + similar_input_nodes