Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Torch][PTQ] Examples are updated for the new PTQ TORCH backend #2246

Merged
Prev Previous commit
Next Next commit
Support split graph in filter_constant_nodes
  • Loading branch information
daniil-lyakhov committed Nov 8, 2023
commit db0ad8df19da7e4c796e17dbc876c4ea9c9f2e36
2 changes: 1 addition & 1 deletion nncf/quantization/algorithms/min_max/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,7 @@ def _get_quantization_target_points(
self._backend_entity.shapeof_metatypes,
self._backend_entity.dropout_metatypes,
self._backend_entity.read_variable_metatypes,
nncf_graph_contains_constants=backend != BackendType.TORCH,
self._backend_entity.constant_metatypes,
)

quantizer_setup = self._get_quantizer_setup(nncf_graph, inference_nncf_graph, hw_patterns, ignored_patterns)
Expand Down
7 changes: 7 additions & 0 deletions nncf/quantization/algorithms/min_max/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,13 @@ def read_variable_metatypes(self) -> List[OperatorMetatype]:
Property for the backend-specific metatypes that can also be interpreted as inputs (ReadValue).
"""

@property
@abstractmethod
def constant_metatypes(self) -> List[OperatorMetatype]:
    """
    Backend-specific metatypes whose nodes can be interpreted as constants.

    :return: List of operator metatypes that can be interpreted as constants.
    """

@property
@abstractmethod
def overflow_fix_metatypes(self) -> List[OperatorMetatype]:
Expand Down
4 changes: 4 additions & 0 deletions nncf/quantization/algorithms/min_max/onnx_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ def dropout_metatypes(self) -> List[OperatorMetatype]:
def read_variable_metatypes(self) -> List[OperatorMetatype]:
return []

@property
def constant_metatypes(self) -> List[OperatorMetatype]:
    """
    ONNX metatypes that can be interpreted as constants.

    :return: List holding the ONNX Constant and ConstantOfShape metatypes.
    """
    constant_types = [om.ONNXConstantMetatype, om.ONNXConstantOfShapeMetatype]
    return constant_types

@property
def scales_unification_map(self) -> Dict[OperatorMetatype, OperatorMetatype]:
    """
    Mapping from the ONNX Concat metatype to the value of `overflow_fix_metatypes`.

    :return: Single-entry dictionary keyed by the ONNX Concat metatype.
    """
    # NOTE(review): the mapped value is whatever `overflow_fix_metatypes` returns,
    # which looks like a list of metatypes rather than the single metatype the
    # return annotation suggests - confirm against the base class.
    producer_metatypes = self.overflow_fix_metatypes
    return {om.ONNXConcatMetatype: producer_metatypes}
Expand Down
4 changes: 4 additions & 0 deletions nncf/quantization/algorithms/min_max/openvino_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ def dropout_metatypes(self) -> List[OperatorMetatype]:
def read_variable_metatypes(self) -> List[OperatorMetatype]:
return [om.OVReadValueMetatype]

@property
def constant_metatypes(self) -> List[OperatorMetatype]:
    """
    OpenVINO metatypes that can be interpreted as constants.

    :return: List holding the OpenVINO Constant metatype.
    """
    constant_types = [om.OVConstantMetatype]
    return constant_types

@property
def scales_unification_map(self) -> Dict[OperatorMetatype, OperatorMetatype]:
    """
    Mapping from the OpenVINO Concat metatype to the value of `overflow_fix_metatypes`.

    :return: Single-entry dictionary keyed by the OpenVINO Concat metatype.
    """
    # NOTE(review): the mapped value is whatever `overflow_fix_metatypes` returns,
    # which looks like a list of metatypes rather than the single metatype the
    # return annotation suggests - confirm against the base class.
    producer_metatypes = self.overflow_fix_metatypes
    return {om.OVConcatMetatype: producer_metatypes}
Expand Down
4 changes: 4 additions & 0 deletions nncf/quantization/algorithms/min_max/torch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ def dropout_metatypes(self) -> List[OperatorMetatype]:
def read_variable_metatypes(self) -> List[OperatorMetatype]:
return []

@property
def constant_metatypes(self) -> List[OperatorMetatype]:
    """
    Torch metatypes that can be interpreted as constants.

    :return: Always an empty list; this backend reports no constant metatypes
        (presumably because its NNCFGraph holds no constant nodes - to be confirmed).
    """
    no_constant_types: List[OperatorMetatype] = []
    return no_constant_types

@property
def conv_metatypes(self) -> List[OperatorMetatype]:
    """
    Torch module convolution metatypes.

    :return: List holding the 1D, 2D and 3D module convolution metatypes, in that order.
    """
    conv_types = [
        om.PTModuleConv1dMetatype,
        om.PTModuleConv2dMetatype,
        om.PTModuleConv3dMetatype,
    ]
    return conv_types
Expand Down
19 changes: 13 additions & 6 deletions nncf/quantization/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def transform_to_inference_graph(
shapeof_metatypes: List[OperatorMetatype],
dropout_metatypes: List[OperatorMetatype],
read_variable_metatypes: Optional[List[OperatorMetatype]] = None,
nncf_graph_contains_constants: bool = True,
constant_metatypes: Optional[List[OperatorMetatype]] = None,
) -> NNCFGraph:
"""
This method runs an in-place pipeline of passes that produces an inference graph without constant flows.
Expand All @@ -33,13 +33,13 @@ def transform_to_inference_graph(
:param dropout_metatypes: List of backend-specific Dropout metatypes.
:param read_variable_metatypes: List of backend-specific metatypes
that also can be interpreted as inputs (ReadValue).
:param nncf_graph_contains_constants: Whether NNCFGraph contains constant nodes or not.
:param constant_metatypes: List of backend-specific metatypes
that can be interpreted as constants.
:return: NNCFGraph in the inference style.
"""
remove_shapeof_subgraphs(nncf_graph, shapeof_metatypes, read_variable_metatypes)
remove_nodes_and_reconnect_graph(nncf_graph, dropout_metatypes)
if nncf_graph_contains_constants:
filter_constant_nodes(nncf_graph, read_variable_metatypes)
filter_constant_nodes(nncf_graph, read_variable_metatypes, constant_metatypes)
return nncf_graph


Expand Down Expand Up @@ -143,7 +143,9 @@ def remove_nodes_and_reconnect_graph(


def filter_constant_nodes(
nncf_graph: NNCFGraph, read_variable_metatypes: Optional[List[OperatorMetatype]] = None
nncf_graph: NNCFGraph,
read_variable_metatypes: Optional[List[OperatorMetatype]] = None,
constant_nodes_metatypes: Optional[List[OperatorMetatype]] = None,
) -> NNCFGraph:
"""
Removes all Constant nodes from the NNCFGraph in place, turning it into an inference graph.
Expand All @@ -152,13 +154,18 @@ def filter_constant_nodes(
:param nncf_graph: NNCFGraph instance for the transformation.
:param read_variable_metatypes: List of backend-specific metatypes
that also can be interpreted as inputs (ReadValue).
:param constant_nodes_metatypes: List of backend-specific metatypes
that can be interpreted as constants.
:return: NNCFGraph without Constant nodes.
"""
read_variable_metatypes = read_variable_metatypes if read_variable_metatypes else []
constant_nodes_metatypes = constant_nodes_metatypes if constant_nodes_metatypes else []
input_nodes = nncf_graph.get_input_nodes()
similar_input_nodes = nncf_graph.get_nodes_by_metatypes(read_variable_metatypes)
potential_input_nodes = [node for node in nncf_graph.get_all_nodes() if not nncf_graph.get_input_edges(node)]
potential_input_nodes = [node for node in potential_input_nodes if node.metatype not in constant_nodes_metatypes]

start_nodes = input_nodes + similar_input_nodes
start_nodes = input_nodes + similar_input_nodes + potential_input_nodes

if not start_nodes:
return nncf_graph
Expand Down