Skip to content

Commit

Permalink
Support splitted graph in filter_constant_nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Nov 8, 2023
1 parent 5c0ab0b commit 87c8d95
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 6 deletions.
2 changes: 1 addition & 1 deletion nncf/quantization/algorithms/min_max/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,7 @@ def _get_quantization_target_points(
self._backend_entity.shapeof_metatypes,
self._backend_entity.dropout_metatypes,
self._backend_entity.read_variable_metatypes,
nncf_graph_contains_constants=backend != BackendType.TORCH,
self._backend_entity.constant_metatypes,
)

quantizer_setup = self._get_quantizer_setup(nncf_graph, inference_nncf_graph, hw_patterns, ignored_patterns)
Expand Down
7 changes: 7 additions & 0 deletions nncf/quantization/algorithms/min_max/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,13 @@ def read_variable_metatypes(self) -> List[OperatorMetatype]:
Property for the backend-specific metatypes that also can be interpreted as inputs (ReadValue).
"""

@property
@abstractmethod
def constant_metatypes(self) -> List[OperatorMetatype]:
"""
Property for the backend-specific metatypes that can be interpreted as constants.
"""

@property
@abstractmethod
def overflow_fix_metatypes(self) -> List[OperatorMetatype]:
Expand Down
4 changes: 4 additions & 0 deletions nncf/quantization/algorithms/min_max/onnx_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ def dropout_metatypes(self) -> List[OperatorMetatype]:
def read_variable_metatypes(self) -> List[OperatorMetatype]:
return []

@property
def constant_metatypes(self) -> List[OperatorMetatype]:
return [om.ONNXConstantMetatype, om.ONNXConstantOfShapeMetatype]

@property
def scales_unification_map(self) -> Dict[OperatorMetatype, OperatorMetatype]:
return {om.ONNXConcatMetatype: self.overflow_fix_metatypes}
Expand Down
4 changes: 4 additions & 0 deletions nncf/quantization/algorithms/min_max/openvino_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ def dropout_metatypes(self) -> List[OperatorMetatype]:
def read_variable_metatypes(self) -> List[OperatorMetatype]:
return [om.OVReadValueMetatype]

@property
def constant_metatypes(self) -> List[OperatorMetatype]:
return [om.OVConstantMetatype]

@property
def scales_unification_map(self) -> Dict[OperatorMetatype, OperatorMetatype]:
return {om.OVConcatMetatype: self.overflow_fix_metatypes}
Expand Down
4 changes: 4 additions & 0 deletions nncf/quantization/algorithms/min_max/torch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ def dropout_metatypes(self) -> List[OperatorMetatype]:
def read_variable_metatypes(self) -> List[OperatorMetatype]:
return []

@property
def constant_metatypes(self) -> List[OperatorMetatype]:
return []

@property
def conv_metatypes(self) -> List[OperatorMetatype]:
return [om.PTModuleConv1dMetatype, om.PTModuleConv2dMetatype, om.PTModuleConv3dMetatype]
Expand Down
17 changes: 12 additions & 5 deletions nncf/quantization/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def transform_to_inference_graph(
shapeof_metatypes: List[OperatorMetatype],
dropout_metatypes: List[OperatorMetatype],
read_variable_metatypes: Optional[List[OperatorMetatype]] = None,
nncf_graph_contains_constants: bool = True,
constant_metatypes: Optional[List[OperatorMetatype]] = None,
) -> NNCFGraph:
"""
This method contains inplace pipeline of the passes that uses to provide inference graph without constant flows.
Expand All @@ -33,13 +33,13 @@ def transform_to_inference_graph(
:param dropout_metatypes: List of backend-specific Dropout metatypes.
:param read_variable_metatypes: List of backend-specific metatypes
that also can be interpreted as inputs (ReadValue).
:param nncf_graph_contains_constants: Whether NNCFGraph contains constant nodes or not.
:param constant_metatypes: List of backend-specific metatypes
that can be interpreted as constants.
:return: NNCFGraph in the inference style.
"""
remove_shapeof_subgraphs(nncf_graph, shapeof_metatypes, read_variable_metatypes)
remove_nodes_and_reconnect_graph(nncf_graph, dropout_metatypes)
if nncf_graph_contains_constants:
filter_constant_nodes(nncf_graph, read_variable_metatypes)
filter_constant_nodes(nncf_graph, read_variable_metatypes, constant_metatypes)
return nncf_graph


Expand Down Expand Up @@ -143,7 +143,9 @@ def remove_nodes_and_reconnect_graph(


def filter_constant_nodes(
nncf_graph: NNCFGraph, read_variable_metatypes: Optional[List[OperatorMetatype]] = None
nncf_graph: NNCFGraph,
read_variable_metatypes: Optional[List[OperatorMetatype]] = None,
constant_nodes_metatypes: Optional[List[OperatorMetatype]] = None,
) -> NNCFGraph:
"""
Removes all Constant nodes from NNCFGraph inplace, making it inference graph.
Expand All @@ -152,11 +154,16 @@ def filter_constant_nodes(
:param nncf_graph: NNCFGraph instance for the transformation.
:param read_variable_metatypes: List of backend-specific metatypes
that also can be interpreted as inputs (ReadValue).
:param constant_nodes_metatypes: List of backend-specific metatypes
that can be interpreted as constants.
:return: NNCFGraph without Constant nodes.
"""
read_variable_metatypes = read_variable_metatypes if read_variable_metatypes else []
constant_nodes_metatypes = constant_nodes_metatypes if constant_nodes_metatypes else []
input_nodes = nncf_graph.get_input_nodes()
similar_input_nodes = nncf_graph.get_nodes_by_metatypes(read_variable_metatypes)
potential_input_nodes = [node for node in nncf_graph.get_all_nodes() if not nncf_graph.get_input_edges(node)]
potential_input_nodes = [node for node in potential_input_nodes if node.metatype not in constant_nodes_metatypes]

start_nodes = input_nodes + similar_input_nodes

Expand Down

0 comments on commit 87c8d95

Please sign in to comment.