Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Torch][PTQ] Examples are updated for the new PTQ TORCH backend #2246

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
XFail test is added
  • Loading branch information
daniil-lyakhov committed Nov 8, 2023
commit abe39aabd45404a2ecf3aab5bf6236ba6ba516ac
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
strict digraph {
"0 /Input_1_0" [id=0, type=Input_1];
"1 /ReadVariable_0" [id=1, type=ReadVariable];
"4 /Conv_0" [id=4, type=Conv];
"6 /Conv2_0" [id=6, type=Conv2];
"7 /Add_0" [id=7, type=Add];
"8 /Final_node_0" [id=8, type=Final_node];
"0 /Input_1_0" -> "4 /Conv_0";
"1 /ReadVariable_0" -> "7 /Add_0";
"6 /Conv2_0" -> "7 /Add_0";
"7 /Add_0" -> "8 /Final_node_0";
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
strict digraph {
"0 /Input_1_0" [id=0, type=Input_1];
"1 /ReadVariable_0" [id=1, type=ReadVariable];
"2 /Weights_0" [id=2, type=Weights];
"3 /AnyNodeBetweenWeightAndConv_0" [id=3, type=AnyNodeBetweenWeightAndConv];
"4 /Conv_0" [id=4, type=Conv];
"5 /Weights2_0" [id=5, type=Weights2];
"6 /Conv2_0" [id=6, type=Conv2];
"7 /Add_0" [id=7, type=Add];
"8 /Final_node_0" [id=8, type=Final_node];
"0 /Input_1_0" -> "4 /Conv_0";
"1 /ReadVariable_0" -> "7 /Add_0";
"2 /Weights_0" -> "3 /AnyNodeBetweenWeightAndConv_0";
"3 /AnyNodeBetweenWeightAndConv_0" -> "4 /Conv_0";
"5 /Weights2_0" -> "6 /Conv2_0";
"6 /Conv2_0" -> "7 /Add_0";
"7 /Add_0" -> "8 /Final_node_0";
}
29 changes: 24 additions & 5 deletions tests/common/quantization/test_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@

import pytest

from nncf.quantization.passes import filter_constant_nodes
from nncf.quantization.passes import remove_nodes_and_reconnect_graph
from tests.post_training.test_templates.models import NNCFGraphDropoutRemovingCase
from tests.post_training.test_templates.models import NNCFGraphToTestConstantFiltering
from tests.shared.nx_graph import compare_nx_graph_with_reference
from tests.shared.paths import TEST_ROOT

Expand All @@ -28,13 +30,14 @@ class TestModes(Enum):
WRONG_PARALLEL_EDGES = "wrong_parallel_edges"


def _check_graphs(dot_file_name, nncf_graph) -> None:
nx_graph = nncf_graph.get_graph_for_structure_analysis()
path_to_dot = DATA_ROOT / dot_file_name
compare_nx_graph_with_reference(nx_graph, path_to_dot, check_edge_attrs=True)


@pytest.mark.parametrize("mode", [TestModes.VALID, TestModes.WRONG_TENSOR_SHAPE, TestModes.WRONG_PARALLEL_EDGES])
def test_remove_nodes_and_reconnect_graph(mode: TestModes):
def _check_graphs(dot_file_name, nncf_graph) -> None:
nx_graph = nncf_graph.get_graph_for_structure_analysis()
path_to_dot = DATA_ROOT / dot_file_name
compare_nx_graph_with_reference(nx_graph, path_to_dot, check_edge_attrs=True)

dot_reference_path_before = Path("passes") / "dropout_synthetic_model_before.dot"
dot_reference_path_after = Path("passes") / "dropout_synthetic_model_after.dot"
dropout_metatype = "DROPOUT_METATYPE"
Expand All @@ -52,3 +55,19 @@ def _check_graphs(dot_file_name, nncf_graph) -> None:
_check_graphs(dot_reference_path_before, nncf_graph)
remove_nodes_and_reconnect_graph(nncf_graph, [dropout_metatype])
_check_graphs(dot_reference_path_after, nncf_graph)


@pytest.mark.xfail
def test_filter_constant_nodes():
dot_reference_path_before = Path("passes") / "test_constant_filtering_model_before.dot"
dot_reference_path_after = Path("passes") / "test_constant_filtering_model_after.dot"

constant_metatype = "CONSTANT_METATYPE"
read_variable_metatype = "READ_VARIABLE_METATYPE"

nncf_graph = NNCFGraphToTestConstantFiltering(constant_metatype, read_variable_metatype).nncf_graph
_check_graphs(dot_reference_path_before, nncf_graph)
filter_constant_nodes(
nncf_graph, read_variable_metatypes=[read_variable_metatype], constant_nodes_metatypes=[constant_metatype]
)
_check_graphs(dot_reference_path_after, nncf_graph)
27 changes: 27 additions & 0 deletions tests/post_training/test_templates/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,3 +297,30 @@ def __init__(
dtype=Dtype.FLOAT,
parallel_input_port_ids=list(range(1, 10)),
)


class NNCFGraphToTestConstantFiltering:
def __init__(self, constant_metatype, read_variable_metatype, nncf_graph_cls=NNCFGraph) -> None:
nodes = [
NodeWithType("Input_1", InputNoopMetatype),
NodeWithType("Conv", None),
NodeWithType("Weights", constant_metatype),
NodeWithType("AnyNodeBetweenWeightAndConv", None),
NodeWithType("Weights2", constant_metatype),
NodeWithType("Conv2", None),
NodeWithType("ReadVariable", read_variable_metatype),
NodeWithType("Add", None),
NodeWithType("Final_node", None),
]

edges = [
("Input_1", "Conv"),
("Weights", "AnyNodeBetweenWeightAndConv"),
("AnyNodeBetweenWeightAndConv", "Conv"),
("Weights2", "Conv2"),
("Conv2", "Add"),
("ReadVariable", "Add"),
("Add", "Final_node"),
]
original_mock_graph = create_mock_graph(nodes, edges)
self.nncf_graph = get_nncf_graph_from_mock_nx_graph(original_mock_graph, nncf_graph_cls)