Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding to unity substitution set #1594

Open
wants to merge 93 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
93 commits
Select commit Hold shift + click to select a range
510c2d9
Start on pcg builder
lockshaw Jun 4, 2024
7b55ed1
Add tests and some implementation for pcg builder
lockshaw Jun 4, 2024
c379efd
Add pcg tests, make dtgen constructors explicit to fix bug
lockshaw Jun 10, 2024
35fa653
Add remainder of PCG tests
lockshaw Jun 10, 2024
865a28e
Merge remote-tracking branch 'origin/repo-refactor' into pcg-builder
lockshaw Jun 10, 2024
f379539
Fix build issues in local-execution
lockshaw Jun 10, 2024
2dbb3b9
Format
lockshaw Jun 10, 2024
4050c99
Address Reyna comments, add topological_order function for PCG
lockshaw Jun 17, 2024
42c1968
Pre multidigraph refactor
lockshaw Jun 19, 2024
3be816f
Removing visitable from sp code
lockshaw Jun 21, 2024
6d68324
Add open dataflow graph, start to replace pcg dataflow graph
lockshaw Jun 23, 2024
64a3403
Start refactoring substitutions
lockshaw Jun 24, 2024
7d4c7be
Add utility functions to support pattern matching
lockshaw Jun 25, 2024
9ab9eb2
Pre-refactor inputs
lockshaw Jun 26, 2024
7ae7c65
Merge remote-tracking branch 'origin/repo-refactor' into dataflow-graph
lockshaw Jun 26, 2024
f9b129e
Fix proj url
lockshaw Jun 26, 2024
cf73f08
Get back to substitutions, now with unordered graph inputs
lockshaw Jul 7, 2024
5fd666d
Get substitutions building
lockshaw Jul 13, 2024
5f0c88a
substitutions-tests now builds
lockshaw Jul 13, 2024
3228f2d
Fix bug in filter, pass some initial substitution tests
lockshaw Jul 14, 2024
5f4cc01
Add tests for fmt::to_string, fix some substitutions bugs
lockshaw Jul 15, 2024
ad60be0
Pass initial unit tests for find_pattern_matches
lockshaw Jul 15, 2024
a972da2
Start on unit tests for pcg pattern
lockshaw Jul 15, 2024
bcf776e
Pass initial test for find_pattern_matches
lockshaw Jul 19, 2024
e28400e
Merge remote-tracking branch 'origin/repo-refactor' into dataflow-graph
lockshaw Jul 19, 2024
fe6d65d
Fix small build issue in tests
lockshaw Jul 19, 2024
e647af7
Format
lockshaw Jul 19, 2024
8b58760
Sync tests in CI with tests in proj
lockshaw Jul 19, 2024
1fafb9d
Fix minor build errors in kernels and local-execution
lockshaw Jul 19, 2024
0804314
Format
lockshaw Jul 19, 2024
dd5465c
Remove outdated code
lockshaw Jul 20, 2024
29ec5b8
More outdated code removal
lockshaw Jul 20, 2024
ff41743
More cleanup, add test for sp decomposition
lockshaw Jul 20, 2024
e71d200
Pull apart containers.h
lockshaw Jul 21, 2024
c06710c
More sp testing and fixes
lockshaw Jul 21, 2024
2f75566
Break up graph algorithms.h
lockshaw Jul 21, 2024
c81d3a4
Pre- full SP algo commit
lockshaw Jul 23, 2024
2a11c7e
Add initial implementation and tests for cbc decomposition and invers…
lockshaw Jul 23, 2024
71a9e0f
Pass test for get_inverse_line_graph
lockshaw Jul 24, 2024
25eb1db
Add new multidigraph
lockshaw Jul 24, 2024
64f1932
Fix get_inverse_line_graph to return a MultiDiGraph instead of a DiGraph
lockshaw Jul 24, 2024
31c8d17
Add tests for parallel and series reduction finding
lockshaw Jul 24, 2024
19e7e28
Add really rough implementation of valdez sp decomposition
lockshaw Jul 24, 2024
3791e86
Fix local-execution build
lockshaw Jul 25, 2024
267b72d
Add implementations and tests for applying series/parallel reductions
lockshaw Jul 25, 2024
bb2769a
Format
lockshaw Jul 26, 2024
39cb7b3
Clean up sp decomposition interface and tests
lockshaw Jul 27, 2024
ce0234d
Format
lockshaw Jul 27, 2024
3dc3ec6
Add comments for top-level substitutions functions, add proj doxygen …
lockshaw Jul 31, 2024
ee518c2
Start sketching out substitutions code
lockshaw Jul 31, 2024
f69b95a
Merge branch 'dataflow-graph' into substitutions-fix
lockshaw Jul 31, 2024
3c06b88
Fix build errors
lockshaw Aug 1, 2024
3d6f681
Add ability to permute node ids
lockshaw Aug 1, 2024
098a9d1
Cleanup and start to test new substitutions code
lockshaw Aug 4, 2024
9bd4f14
Add test case for evaluate_substitution_output
lockshaw Aug 5, 2024
101083b
Add naive isomorphism detection code
lockshaw Aug 5, 2024
9fec50c
Add graph inputs to open dataflow graph isomorphism
lockshaw Aug 6, 2024
7c60736
Add input permutation to evaluate_substitution_output
lockshaw Aug 6, 2024
cb6eab2
Fix permute_node_ids
lockshaw Aug 8, 2024
2f3d67a
Add test for permute_input_ids
lockshaw Aug 8, 2024
03cbd02
Migrate over to mutable implementation of apply_substitution
lockshaw Aug 23, 2024
4a8deae
Add fast isomorphism checking and an initial implementation of full s…
lockshaw Aug 24, 2024
0757e94
Pass initial full substitutions test
lockshaw Aug 24, 2024
ba0a174
Cleanup old isomorphism checking code
lockshaw Aug 24, 2024
4dfa403
Merge remote-tracking branch 'origin/repo-refactor' into substitution…
lockshaw Aug 24, 2024
f156f96
Fix post-merge bugs
lockshaw Aug 24, 2024
5f09298
Fix broken pcg builder test
lockshaw Aug 26, 2024
deff4f8
Format
lockshaw Aug 26, 2024
d71d24f
Reorganize code and remove some outdated code pre-code-review
lockshaw Aug 26, 2024
1a63f90
Format
lockshaw Aug 26, 2024
1d4ab09
Restarting work on this after working on export-model-arch
lockshaw Sep 10, 2024
7928864
Merge remote-tracking branch 'flexflow/repo-refactor' into substituti…
lockshaw Sep 10, 2024
f2c8e7b
Merge branch 'substitution-builder' into master
Jan 9, 2025
030b0e8
Merge branch 'master' into merge-substitution-builder
victorli2002 Jan 15, 2025
f5c49c7
Adding in some a simple function to get the currently available subst…
Jan 17, 2025
3e4c357
Merge branch 'master' into merge-substitution-builder
victorli2002 Jan 17, 2025
30f2b6e
Merge branch 'master' into merge-substitution-builder
lockshaw Jan 20, 2025
47cc58a
nonnegative_int additions, code cleanup, etc.
Jan 24, 2025
f8df37e
Merge remote-tracking branch 'origin/master' into victor-substitution…
lockshaw Jan 25, 2025
3728251
A bunch more moving over to nonnegative_int
lockshaw Jan 28, 2025
f27d31b
Even more nonnegative_int updating
lockshaw Jan 31, 2025
9f8762e
Fix build
lockshaw Jan 31, 2025
5edb6f0
Fix failing tests
lockshaw Feb 1, 2025
97338c7
Merge branch 'master' into merge-substitution-builder
lockshaw Feb 1, 2025
88370c0
Format
lockshaw Feb 1, 2025
9c2007e
Merge remote-tracking branch 'refs/remotes/victorli2002/merge-substit…
lockshaw Feb 1, 2025
600e074
Format
lockshaw Feb 1, 2025
0dd487c
Adding more to unity substitution set
Feb 7, 2025
1c18cdd
Merge branch 'master' into merge-substitution-builder
victorli2002 Feb 7, 2025
12f8717
Fixing unity substituion tests
Feb 10, 2025
d5c9be9
Adding more to unity substitution set
Feb 21, 2025
1394173
Merge remote-tracking branch 'origin/master' into merge-substitution-…
Feb 21, 2025
4fe85ed
Updating unity substitution set to work with new PCG interface
Feb 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Pass initial test for find_pattern_matches
  • Loading branch information
lockshaw committed Jul 19, 2024
commit bcf776e56fb3862c239242db8192f985faaae3e9
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
#include "local-execution/variadic_tensor_ref.h"
#include "op-attrs/computation_graph_op_attrs.h"
#include "pcg/computation_graph.h"
#include "utils/bidict.h"
#include "utils/bidict/bidict.h"
#include "utils/stack_map.h"
#include <typeindex>
#include <unordered_map>
Expand Down
1 change: 1 addition & 0 deletions lib/substitutions/include/substitutions/pcg_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ TensorAttributePattern get_tensor_pattern(PCGPattern const &,
PatternValue const &);
OperatorAttributePattern get_operator_pattern(PCGPattern const &,
PatternNode const &);
std::unordered_set<PatternInput> get_inputs(PCGPattern const &);

bool assignment_satisfies(SubParallelComputationGraph const &,
PCGPattern const &,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ features = [
]

includes = [
"utils/bidict.h",
"utils/bidict/bidict.h",
"utils/graph.h",
"<utility>",
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ includes = [
"substitutions/unlabelled/unlabelled_graph_pattern.dtg.h",
"substitutions/unlabelled/pattern_value.dtg.h",
"substitutions/unlabelled/pattern_input.dtg.h",
"utils/bidict.h",
"utils/bidict/bidict.h",
]

[[fields]]
Expand All @@ -18,5 +18,9 @@ name = "subpattern_2"
type = "::FlexFlow::UnlabelledGraphPattern"

[[fields]]
name = "subpattern_1_outputs_to_subpattern_2_inputs"
name = "full_pattern_values_to_subpattern_1_inputs"
type = "::FlexFlow::bidict<::FlexFlow::PatternValue, ::FlexFlow::PatternInput>"

[[fields]]
name = "full_pattern_values_to_subpattern_2_inputs"
type = "::FlexFlow::bidict<::FlexFlow::PatternValue, ::FlexFlow::PatternInput>"
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ UnlabelledDataflowGraphPatternMatch empty_unlabelled_pattern_match();
std::unordered_set<Node> matched_nodes(UnlabelledDataflowGraphPatternMatch const &);
std::optional<UnlabelledDataflowGraphPatternMatch> merge_unlabelled_dataflow_graph_pattern_matches(UnlabelledDataflowGraphPatternMatch const &subpattern_1,
UnlabelledDataflowGraphPatternMatch const &subpattern_2,
bidict<PatternValue, PatternInput> const &outputs_of_1_to_inputs_of_2);
bidict<PatternValue, PatternInput> const &merged_graph_values_to_inputs_of_1,
bidict<PatternValue, PatternInput> const &merged_graph_values_to_inputs_of_2);

} // namespace FlexFlow

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ features = [
]

includes = [
"utils/bidict.h",
"utils/bidict/bidict.h",
"utils/graph/node/node.dtg.h",
"utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h",
"substitutions/unlabelled/pattern_input.dtg.h",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "substitutions/unlabelled/pattern_node.dtg.h"
#include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h"
#include "substitutions/unlabelled/pattern_input.dtg.h"
#include "substitutions/unlabelled/unlabelled_graph_pattern_subgraph_result.dtg.h"

namespace FlexFlow {

Expand All @@ -26,8 +27,8 @@ std::vector<PatternValue>
std::vector<PatternValue>
get_outputs_from_pattern_node(UnlabelledGraphPattern const &, PatternNode const &);

UnlabelledGraphPattern get_subgraph(UnlabelledGraphPattern const &,
std::unordered_set<PatternNode> const &);
UnlabelledGraphPatternSubgraphResult get_subgraph(UnlabelledGraphPattern const &,
std::unordered_set<PatternNode> const &);

} // namespace FlexFlow

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
namespace = "FlexFlow"
name = "UnlabelledGraphPatternSubgraphResult"
features = [ ]

includes = [
"substitutions/unlabelled/unlabelled_graph_pattern.dtg.h",
"substitutions/unlabelled/pattern_value.dtg.h",
"substitutions/unlabelled/pattern_input.dtg.h",
"utils/bidict/bidict.h",
]

[[fields]]
name = "subpattern"
type = "::FlexFlow::UnlabelledGraphPattern"

[[fields]]
name = "full_pattern_values_to_subpattern_inputs"
type = "::FlexFlow::bidict<::FlexFlow::PatternValue, ::FlexFlow::PatternInput>"
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ std::vector<UnlabelledDataflowGraphPatternMatch>
}
} else {
PatternSplit split = find_even_split(pattern);
auto subpatterns = apply_split(pattern, split);
PatternSplitResult subpatterns = apply_split(pattern, split);
std::vector<UnlabelledDataflowGraphPatternMatch> prefix_matches =
find_pattern_matches(subpatterns.subpattern_1, graph, additional_criterion);
std::vector<UnlabelledDataflowGraphPatternMatch> postfix_matches =
Expand All @@ -87,7 +87,8 @@ std::vector<UnlabelledDataflowGraphPatternMatch>
std::optional<UnlabelledDataflowGraphPatternMatch> unsplit =
merge_unlabelled_dataflow_graph_pattern_matches(prefix_match,
postfix_match,
subpatterns.subpattern_1_outputs_to_subpattern_2_inputs);
subpatterns.full_pattern_values_to_subpattern_1_inputs,
subpatterns.full_pattern_values_to_subpattern_2_inputs);
if (unsplit.has_value() && unlabelled_pattern_does_match(pattern, graph, unsplit.value(), additional_criterion)) {
matches.push_back(unsplit.value());
}
Expand Down
21 changes: 17 additions & 4 deletions lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ OpenDataflowSubgraphResult subgraph_matched(OpenDataflowGraphView const &g,
// return true;
// }

struct ConcreteFromPattern {
ConcreteFromPattern(UnlabelledDataflowGraphPatternMatch const &match,
struct SubgraphConcreteFromPattern {
SubgraphConcreteFromPattern(UnlabelledDataflowGraphPatternMatch const &match,
bidict<OpenDataflowValue, DataflowGraphInput> const &full_graph_values_to_subgraph_inputs)
: match(match), full_graph_values_to_subgraph_inputs(full_graph_values_to_subgraph_inputs)
{ }
Expand Down Expand Up @@ -124,7 +124,7 @@ bool pattern_matches_subgraph_under(UnlabelledGraphPattern const &pattern,
bidict<OpenDataflowValue, DataflowGraphInput> const &full_graph_values_to_subgraph_inputs,
UnlabelledDataflowGraphPatternMatch const &match,
MatchAdditionalCriterion const &additional_criterion) {
ConcreteFromPattern concrete_from_pattern{match, full_graph_values_to_subgraph_inputs};
SubgraphConcreteFromPattern concrete_from_pattern{match, full_graph_values_to_subgraph_inputs};

std::unordered_set<Node> concrete_nodes = get_nodes(subgraph);
std::unordered_set<Node> concrete_nodes_from_match = transform(get_nodes(pattern), concrete_from_pattern);
Expand Down Expand Up @@ -174,7 +174,20 @@ bool unlabelled_pattern_does_match(
assert (keys(match.node_assignment) == get_nodes(pattern));
assert (keys(match.node_assignment.reversed()) == get_nodes(matched_subgraph));

return pattern_matches_subgraph_under(pattern, matched_subgraph, subgraph_result.full_graph_values_to_subgraph_inputs, match, additional_criterion);
MatchAdditionalCriterion through_subgraph_operation = MatchAdditionalCriterion{
additional_criterion.node_criterion,
[&](PatternValue const &pv, OpenDataflowValue const &v) {
return v.visit<bool>(overload {
[&](DataflowOutput const &) { return additional_criterion.value_criterion(pv, v); },
[&](DataflowGraphInput const &subgraph_input) {
OpenDataflowValue full_graph_value = subgraph_result.full_graph_values_to_subgraph_inputs.at_r(subgraph_input);
return additional_criterion.value_criterion(pv, full_graph_value);
}
});
},
};

return pattern_matches_subgraph_under(pattern, matched_subgraph, subgraph_result.full_graph_values_to_subgraph_inputs, match, through_subgraph_operation);
}

} // namespace FlexFlow
20 changes: 7 additions & 13 deletions lib/substitutions/src/substitutions/unlabelled/pattern_split.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,14 @@ PatternSplit find_even_split(UnlabelledGraphPattern const &pattern) {

PatternSplitResult
apply_split(UnlabelledGraphPattern const &p, PatternSplit const &s) {
OpenDataflowSubgraphResult raw_second_subgraph_result = get_subgraph(p.raw_graph, transform(s.second, [](PatternNode const &pn) { return pn.raw_node; }));

bidict<PatternValue, PatternInput> subpattern_1_outputs_to_subpattern_2_inputs;
for (auto const &kv : raw_second_subgraph_result.full_graph_values_to_subgraph_inputs) {
OpenDataflowValue open_dataflow_value = kv.first;
DataflowGraphInput dataflow_graph_input = kv.second;
subpattern_1_outputs_to_subpattern_2_inputs.equate(
pattern_value_from_raw_open_dataflow_value(open_dataflow_value), PatternInput{dataflow_graph_input});
}

UnlabelledGraphPatternSubgraphResult first_subgraph_result = get_subgraph(p, s.first);
UnlabelledGraphPatternSubgraphResult second_subgraph_result = get_subgraph(p, s.second);

return PatternSplitResult{
get_subgraph(p, s.first),
UnlabelledGraphPattern{raw_second_subgraph_result.graph},
subpattern_1_outputs_to_subpattern_2_inputs,
first_subgraph_result.subpattern,
second_subgraph_result.subpattern,
first_subgraph_result.full_pattern_values_to_subpattern_inputs,
second_subgraph_result.full_pattern_values_to_subpattern_inputs
};
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
#include "substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.h"
#include "utils/containers.h"
#include "utils/containers/filtermap_keys.h"
#include "utils/bidict/try_merge_nondisjoint_bidicts.h"
#include "utils/containers/try_merge_nondisjoint_unordered_maps.h"

namespace FlexFlow {

Expand All @@ -10,34 +13,11 @@ UnlabelledDataflowGraphPatternMatch empty_unlabelled_pattern_match() {
};
}

template <typename L, typename R>
std::optional<bidict<L, R>> try_merge_nondisjoint_bidicts(bidict<L, R> const &d1,
bidict<L, R> const &d2) {
bidict<L, R> result;
for (L const &l : set_union(keys(d1), keys(d2))) {
if (d1.contains_l(l) && d2.contains_l(l)) {
if (d1.at_l(l) == d2.at_l(l)) {
result.equate(l, d1.at_l(l));
} else {
return std::nullopt;
}
} else if (d1.contains_l(l)) {
result.equate(l, d1.at_l(l));
} else {
assert (d2.contains_l(l));

result.equate(l, d2.at_l(l));
}
}

return result;
}


std::optional<UnlabelledDataflowGraphPatternMatch>
merge_unlabelled_dataflow_graph_pattern_matches(UnlabelledDataflowGraphPatternMatch const &subpattern_1,
UnlabelledDataflowGraphPatternMatch const &subpattern_2,
bidict<PatternValue, PatternInput> const &outputs_of_1_to_inputs_of_2) {
bidict<PatternValue, PatternInput> const &merged_graph_values_to_inputs_of_1,
bidict<PatternValue, PatternInput> const &merged_graph_values_to_inputs_of_2) {
bidict<PatternNode, Node> merged_node_assignment = ({
std::optional<bidict<PatternNode, Node>> result = try_merge_nondisjoint_bidicts(
subpattern_1.node_assignment, subpattern_2.node_assignment);
Expand All @@ -47,11 +27,37 @@ std::optional<UnlabelledDataflowGraphPatternMatch>
result.value();
});

assert (all_of(keys(subpattern_2.input_assignment), [&](PatternInput const &i) { return outputs_of_1_to_inputs_of_2.contains_r(i); }));
std::unordered_map<PatternInput, OpenDataflowValue> merged_input_assignment = ({
std::unordered_map<PatternValue, OpenDataflowValue> lifted_input_assignment_1 = map_keys(
subpattern_1.input_assignment,
[&](PatternInput const &pi1) {
return merged_graph_values_to_inputs_of_1.at_r(pi1);
}
);
std::unordered_map<PatternValue, OpenDataflowValue> lifted_input_assignment_2 = map_keys(
subpattern_2.input_assignment,
[&](PatternInput const &pi2) {
return merged_graph_values_to_inputs_of_2.at_r(pi2);
}
);
std::optional<std::unordered_map<PatternValue, OpenDataflowValue>> merged = try_merge_nondisjoint_unordered_maps(
lifted_input_assignment_1, lifted_input_assignment_2);
if (!merged.has_value()) {
return std::nullopt;
}
filtermap_keys(merged.value(),
[](PatternValue const &v) -> std::optional<PatternInput> {
if (v.has<PatternInput>()) {
return v.get<PatternInput>();
} else {
return std::nullopt;
}
});
});

return UnlabelledDataflowGraphPatternMatch{
merged_node_assignment,
subpattern_1.input_assignment,
merged_input_assignment,
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,22 @@ std::vector<PatternValue>
[](DataflowOutput const &o) { return pattern_value_from_raw_open_dataflow_value(OpenDataflowValue{o}); });
}

UnlabelledGraphPattern get_subgraph(UnlabelledGraphPattern const &p,
std::unordered_set<PatternNode> const &n) {
OpenDataflowGraphView raw_subgraph =
get_subgraph(p.raw_graph, transform(n, [](PatternNode const &pn) { return pn.raw_node; })).graph;
return UnlabelledGraphPattern{
raw_subgraph,
UnlabelledGraphPatternSubgraphResult get_subgraph(UnlabelledGraphPattern const &p,
std::unordered_set<PatternNode> const &n) {
OpenDataflowSubgraphResult raw_result =
get_subgraph(p.raw_graph, transform(n, [](PatternNode const &pn) { return pn.raw_node; }));
bidict<PatternValue, PatternInput> full_pattern_values_to_subpattern_inputs = transform(
raw_result.full_graph_values_to_subgraph_inputs,
[](OpenDataflowValue const &v, DataflowGraphInput const &i) {
return std::make_pair(
pattern_value_from_raw_open_dataflow_value(v),
PatternInput{i}
);
}
);
return UnlabelledGraphPatternSubgraphResult{
UnlabelledGraphPattern{raw_result.graph},
full_pattern_values_to_subpattern_inputs ,
};
}

Expand Down
56 changes: 28 additions & 28 deletions lib/substitutions/test/src/substitutions/pcg_pattern.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,33 +101,33 @@ TEST_SUITE(FF_TEST_SUITE) {
PCGPattern pattern = PCGPattern{g};

std::unordered_set<UnlabelledDataflowGraphPatternMatch> result = without_order(find_pattern_matches(pattern, sub_pcg_from_full_pcg(pcg)));
//
// UnlabelledDataflowGraphPatternMatch match1 = UnlabelledDataflowGraphPatternMatch{
// bidict<PatternNode, Node>{
// {op_pattern_1_node, x_matmul.raw_graph_node},
// {op_pattern_2_node, y_matmul.raw_graph_node},
// },
// bidict<PatternInput, OpenDataflowValue>{
// {PatternInput{pt_a}, OpenDataflowValue{a_tensor.raw_graph_output}},
// {PatternInput{pt_b}, OpenDataflowValue{x_weights.raw_graph_output}},
// {PatternInput{pt_c}, OpenDataflowValue{y_weights.raw_graph_output}},
// }
// };
//
// UnlabelledDataflowGraphPatternMatch match2 = UnlabelledDataflowGraphPatternMatch{
// bidict<PatternNode, Node>{
// {op_pattern_1_node, y_matmul.raw_graph_node},
// {op_pattern_2_node, x_matmul.raw_graph_node},
// },
// bidict<PatternInput, OpenDataflowValue>{
// {PatternInput{pt_a}, OpenDataflowValue{a_tensor.raw_graph_output}},
// {PatternInput{pt_b}, OpenDataflowValue{y_weights.raw_graph_output}},
// {PatternInput{pt_c}, OpenDataflowValue{x_weights.raw_graph_output}},
// }
// };
//
// std::unordered_set<UnlabelledDataflowGraphPatternMatch> correct = {match1, match2};
//
// CHECK(result == correct);

UnlabelledDataflowGraphPatternMatch match1 = UnlabelledDataflowGraphPatternMatch{
bidict<PatternNode, Node>{
{op_pattern_1_node, x_matmul.raw_graph_node},
{op_pattern_2_node, y_matmul.raw_graph_node},
},
bidict<PatternInput, OpenDataflowValue>{
{PatternInput{pt_a}, OpenDataflowValue{a_tensor.raw_graph_output}},
{PatternInput{pt_b}, OpenDataflowValue{x_weights.raw_graph_output}},
{PatternInput{pt_c}, OpenDataflowValue{y_weights.raw_graph_output}},
}
};

UnlabelledDataflowGraphPatternMatch match2 = UnlabelledDataflowGraphPatternMatch{
bidict<PatternNode, Node>{
{op_pattern_1_node, y_matmul.raw_graph_node},
{op_pattern_2_node, x_matmul.raw_graph_node},
},
bidict<PatternInput, OpenDataflowValue>{
{PatternInput{pt_a}, OpenDataflowValue{a_tensor.raw_graph_output}},
{PatternInput{pt_b}, OpenDataflowValue{y_weights.raw_graph_output}},
{PatternInput{pt_c}, OpenDataflowValue{x_weights.raw_graph_output}},
}
};

std::unordered_set<UnlabelledDataflowGraphPatternMatch> correct = {match1, match2};

CHECK(result == correct);
}
}
Loading