diff --git a/include/onnxruntime/core/optimizer/rewrite_rule.h b/include/onnxruntime/core/optimizer/rewrite_rule.h index f481439e5ff00..fa8583bb1c922 100644 --- a/include/onnxruntime/core/optimizer/rewrite_rule.h +++ b/include/onnxruntime/core/optimizer/rewrite_rule.h @@ -66,7 +66,7 @@ class RewriteRule { @param[in] node The Node to apply the rewrite to. @param[out] rule_effect Enum to indicate if and how the graph was modified as a result of the rule application. @returns Status indicating success or providing error information */ - common::Status CheckConditionAndApply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) { + common::Status CheckConditionAndApply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const { return SatisfyCondition(graph, node) ? Apply(graph, node, rule_effect) : Status::OK(); } @@ -79,11 +79,11 @@ class RewriteRule { evaluated if this condition function returns true. This can include a more complex pattern matching (conditions on the ascending or descending nodes of the node for which this rule was triggered) or some other properties of the nodes. */ - virtual bool SatisfyCondition(const Graph& graph, const Node& node) = 0; + virtual bool SatisfyCondition(const Graph& graph, const Node& node) const = 0; /** This is the actual body of the rule that performs the graph transformation. The transformation happens in-place. The return-value of node may be different from the input-value due to rewriting. The value of "rule_effect" indicates whether and how the graph was modified by the rule. */ - virtual common::Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) = 0; + virtual common::Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const = 0; }; } // namespace onnxruntime diff --git a/include/onnxruntime/core/optimizer/rule_based_graph_transformer.h b/include/onnxruntime/core/optimizer/rule_based_graph_transformer.h index 4a3fe8159a4bb..97ce347873d8c 100644 --- a/include/onnxruntime/core/optimizer/rule_based_graph_transformer.h +++ b/include/onnxruntime/core/optimizer/rule_based_graph_transformer.h @@ -39,14 +39,14 @@ class RuleBasedGraphTransformer : public GraphTransformer { /** Gets the list of registered rewrite rules that will be triggered on nodes with the given op type by this rule-based transformer. @returns a pointer to the vector containing all the registered rewrite rules. */ - const std::vector>* GetRewriteRulesForOpType(const std::string& op_type) const { + const std::vector>* GetRewriteRulesForOpType(const std::string& op_type) const { auto rules = op_type_to_rules_.find(op_type); return (rules != op_type_to_rules_.cend()) ? &rules->second : nullptr; } /** Gets the rewrite rules that are evaluated on all nodes irrespective of their op type. @returns a pointer to the vector containing all such rewrite rules or nullptr if no such rule. */ - const std::vector>* GetAnyOpRewriteRules() const { + const std::vector>* GetAnyOpRewriteRules() const { return &any_op_type_rules_; } @@ -62,16 +62,18 @@ class RuleBasedGraphTransformer : public GraphTransformer { applying rules on this node. @returns Status indicating success or providing error information. */ common::Status ApplyRulesOnNode(Graph& graph, Node& node, - const std::vector>& rules, + const std::vector>& rules, RewriteRule::RewriteRuleEffect& rule_effect) const; private: using RuleEffect = RewriteRule::RewriteRuleEffect; + // The list of unique pointers for all rules (so that rules can be registered for several op types). + std::vector> rules_; // Map that associates a node's op type with the vector of rules that are registered to be triggered for that node. - std::unordered_map>> op_type_to_rules_; + std::unordered_map>> op_type_to_rules_; // Rules that will be evaluated regardless of the op type of the node. - std::vector> any_op_type_rules_; + std::vector> any_op_type_rules_; // Performs a single top-down traversal of the graph and applies all registered rules. common::Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override; diff --git a/onnxruntime/core/optimizer/conv_add_fusion.cc b/onnxruntime/core/optimizer/conv_add_fusion.cc index 59a8c010138a2..d0284e93ce2a7 100644 --- a/onnxruntime/core/optimizer/conv_add_fusion.cc +++ b/onnxruntime/core/optimizer/conv_add_fusion.cc @@ -9,7 +9,7 @@ using namespace ONNX_NAMESPACE; using namespace ::onnxruntime::common; namespace onnxruntime { -Status ConvAddFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& modified) { +Status ConvAddFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& modified) const { auto& conv_node = node; const auto& add_node = *conv_node.OutputNodesBegin(); const auto& conv_inputs = conv_node.InputDefs(); @@ -107,7 +107,7 @@ Status ConvAddFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& modifie return Status::OK(); } -bool ConvAddFusion::SatisfyCondition(const Graph& graph, const Node& node) { +bool ConvAddFusion::SatisfyCondition(const Graph& graph, const Node& node) const { if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", {1}) || node.GetOutputEdgesCount() != 1) { return false; diff --git a/onnxruntime/core/optimizer/conv_add_fusion.h b/onnxruntime/core/optimizer/conv_add_fusion.h index 3fe4e92b5abcf..7763e249bd118 100644 --- a/onnxruntime/core/optimizer/conv_add_fusion.h +++ b/onnxruntime/core/optimizer/conv_add_fusion.h @@ -23,9 +23,9 @@ class ConvAddFusion : public RewriteRule { } private: - bool SatisfyCondition(const Graph& graph, const Node& node) override; + bool SatisfyCondition(const Graph& graph, const Node& node) const override; - Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) override; + Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override; }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/conv_bn_fusion.cc b/onnxruntime/core/optimizer/conv_bn_fusion.cc index f13bf64eafa6e..1d22d7c02f93c 100644 --- a/onnxruntime/core/optimizer/conv_bn_fusion.cc +++ b/onnxruntime/core/optimizer/conv_bn_fusion.cc @@ -9,7 +9,7 @@ using namespace ONNX_NAMESPACE; using namespace ::onnxruntime::common; namespace onnxruntime { -Status ConvBNFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) { +Status ConvBNFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const { auto& conv_node = node; const Node& bn_node = *conv_node.OutputNodesBegin(); @@ -142,7 +142,7 @@ Status ConvBNFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_eff return Status::OK(); } -bool ConvBNFusion::SatisfyCondition(const Graph& graph, const Node& node) { +bool ConvBNFusion::SatisfyCondition(const Graph& graph, const Node& node) const { if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", {1}) || node.GetOutputEdgesCount() != 1) { return false; diff --git a/onnxruntime/core/optimizer/conv_bn_fusion.h b/onnxruntime/core/optimizer/conv_bn_fusion.h index e23095bfdf49c..cdce82035f032 100644 --- a/onnxruntime/core/optimizer/conv_bn_fusion.h +++ b/onnxruntime/core/optimizer/conv_bn_fusion.h @@ -23,9 +23,9 @@ class ConvBNFusion : public RewriteRule { } private: - bool SatisfyCondition(const Graph& graph, const Node& node) override; + bool SatisfyCondition(const Graph& graph, const Node& node) const override; - Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) override; + Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override; }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/conv_mul_fusion.cc b/onnxruntime/core/optimizer/conv_mul_fusion.cc index dd27f0357ff39..0e5cbfc5d583d 100644 --- a/onnxruntime/core/optimizer/conv_mul_fusion.cc +++ b/onnxruntime/core/optimizer/conv_mul_fusion.cc @@ -9,7 +9,7 @@ using namespace ONNX_NAMESPACE; using namespace ::onnxruntime::common; namespace onnxruntime { -Status ConvMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) { +Status ConvMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const { auto& conv_node = node; const auto& mul_node = *conv_node.OutputNodesBegin(); const auto& conv_inputs = conv_node.InputDefs(); @@ -105,7 +105,7 @@ Status ConvMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_ef return Status::OK(); } -bool ConvMulFusion::SatisfyCondition(const Graph& graph, const Node& node) { +bool ConvMulFusion::SatisfyCondition(const Graph& graph, const Node& node) const { if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", {1}) || node.GetOutputEdgesCount() != 1) { return false; diff --git a/onnxruntime/core/optimizer/conv_mul_fusion.h b/onnxruntime/core/optimizer/conv_mul_fusion.h index 62a39b624570a..bb6a35bf7f01a 100644 --- a/onnxruntime/core/optimizer/conv_mul_fusion.h +++ b/onnxruntime/core/optimizer/conv_mul_fusion.h @@ -22,9 +22,9 @@ class ConvMulFusion : public RewriteRule { } private: - bool SatisfyCondition(const Graph& graph, const Node& node) override; + bool SatisfyCondition(const Graph& graph, const Node& node) const override; - Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) override; + Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override; }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/dropout_elimination.cc b/onnxruntime/core/optimizer/dropout_elimination.cc index f08643f0297bc..ff2f1a3477e18 100644 --- a/onnxruntime/core/optimizer/dropout_elimination.cc +++ b/onnxruntime/core/optimizer/dropout_elimination.cc @@ -10,7 +10,7 @@ namespace onnxruntime { -Status EliminateDropout::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) { +Status EliminateDropout::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const { if (graph_utils::RemoveNode(graph, node)) { rule_effect = RewriteRuleEffect::kRemovedCurrentNode; } @@ -18,7 +18,7 @@ Status EliminateDropout::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule return Status::OK(); } -bool EliminateDropout::SatisfyCondition(const Graph& graph, const Node& node) { +bool EliminateDropout::SatisfyCondition(const Graph& graph, const Node& node) const { // We currently support elimination for Dropout operator v1, v6, v7, and v10. if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Dropout", {1, 6, 7, 10})) { return false; diff --git a/onnxruntime/core/optimizer/dropout_elimination.h b/onnxruntime/core/optimizer/dropout_elimination.h index 2310eaa4366cb..e840767497e66 100644 --- a/onnxruntime/core/optimizer/dropout_elimination.h +++ b/onnxruntime/core/optimizer/dropout_elimination.h @@ -23,9 +23,9 @@ class EliminateDropout : public RewriteRule { } private: - bool SatisfyCondition(const Graph& graph, const Node& node) override; + bool SatisfyCondition(const Graph& graph, const Node& node) const override; - Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) override; + Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override; }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 9350554433822..593bd58c2eb56 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -11,6 +11,7 @@ #include "core/optimizer/conv_activation_fusion.h" #include "core/optimizer/gemm_activation_fusion.h" #include "core/optimizer/matmul_add_fusion.h" +#include "core/optimizer/dropout_elimination.h" namespace onnxruntime { @@ -28,6 +29,7 @@ std::vector> GenerateRewriteRules(TransformerLevel rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); + rules.push_back(std::make_unique()); break; case TransformerLevel::Level2: diff --git a/onnxruntime/core/optimizer/identity_elimination.cc b/onnxruntime/core/optimizer/identity_elimination.cc index 236b98f5887d5..09df7fab42942 100644 --- a/onnxruntime/core/optimizer/identity_elimination.cc +++ b/onnxruntime/core/optimizer/identity_elimination.cc @@ -10,7 +10,7 @@ namespace onnxruntime { -Status EliminateIdentity::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) { +Status EliminateIdentity::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const { if (graph_utils::RemoveNode(graph, node)) { rule_effect = RewriteRuleEffect::kRemovedCurrentNode; } @@ -18,7 +18,7 @@ Status EliminateIdentity::Apply(Graph& graph, Node& node, RewriteRuleEffect& rul return Status::OK(); } -bool EliminateIdentity::SatisfyCondition(const Graph& graph, const Node& node) { +bool EliminateIdentity::SatisfyCondition(const Graph& graph, const Node& node) const { return graph_utils::IsSingleInSingleOutNode(node) && !graph.IsNodeOutputsInGraphOutputs(node); } diff --git a/onnxruntime/core/optimizer/identity_elimination.h b/onnxruntime/core/optimizer/identity_elimination.h index b90d2164e01d8..55d8c2d8fa33f 100644 --- a/onnxruntime/core/optimizer/identity_elimination.h +++ b/onnxruntime/core/optimizer/identity_elimination.h @@ -23,9 +23,9 @@ class EliminateIdentity : public RewriteRule { } private: - bool SatisfyCondition(const Graph& graph, const Node& node) override; + bool SatisfyCondition(const Graph& graph, const Node& node) const override; - Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) override; + Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override; }; // namespace onnxruntime } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/rule_based_graph_transformer.cc b/onnxruntime/core/optimizer/rule_based_graph_transformer.cc index 7bafbed87b119..b25e4505a8b4e 100644 --- a/onnxruntime/core/optimizer/rule_based_graph_transformer.cc +++ b/onnxruntime/core/optimizer/rule_based_graph_transformer.cc @@ -13,19 +13,23 @@ Status RuleBasedGraphTransformer::Register(std::unique_ptr rule) { auto op_types = rule->TargetOpTypes(); // If the target op types are empty, this rule will be evaluated for all op types. if (op_types.empty()) { - any_op_type_rules_.push_back(std::move(rule)); + any_op_type_rules_.push_back(*rule); } else { std::for_each(op_types.cbegin(), op_types.cend(), - [&](const auto& op_type) { op_type_to_rules_[op_type].push_back(std::move(rule)); }); + [&](const auto& op_type) { op_type_to_rules_[op_type].push_back(*rule); }); } + + // Save unique pointer at the rules_ list. + rules_.push_back(std::move(rule)); + return Status::OK(); } Status RuleBasedGraphTransformer::ApplyRulesOnNode(Graph& graph, Node& node, - const std::vector>& rules, + const std::vector>& rules, RuleEffect& rule_effect) const { - for (const auto& rule : rules) { - ORT_RETURN_IF_ERROR(rule->CheckConditionAndApply(graph, node, rule_effect)); + for (const RewriteRule& rule : rules) { + ORT_RETURN_IF_ERROR(rule.CheckConditionAndApply(graph, node, rule_effect)); // If the current node was removed as a result of a rule, stop rule application for that node. if (rule_effect == RuleEffect::kRemovedCurrentNode) { break; @@ -56,7 +60,7 @@ Status RuleBasedGraphTransformer::ApplyImpl(Graph& graph, bool& modified, int gr // First apply rewrite rules that are registered for the op type of the current node; then apply rules that are // registered to be applied regardless of the op type; then recursively apply rules to subgraphs (if any). // Stop further rule application for the current node, if the node gets removed by a rule. - const std::vector>* rules = nullptr; + const std::vector>* rules = nullptr; rules = GetRewriteRulesForOpType(node->OpType()); if (rules) { @@ -84,9 +88,7 @@ Status RuleBasedGraphTransformer::ApplyImpl(Graph& graph, bool& modified, int gr } size_t RuleBasedGraphTransformer::RulesCount() const { - return any_op_type_rules_.size() + - std::accumulate(op_type_to_rules_.cbegin(), op_type_to_rules_.cend(), size_t(0), - [](size_t sum, const auto& rules) { return sum + rules.second.size(); }); + return rules_.size(); } } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/slice_elimination.cc b/onnxruntime/core/optimizer/slice_elimination.cc index 3a42b6170d75a..65ebd81811218 100644 --- a/onnxruntime/core/optimizer/slice_elimination.cc +++ b/onnxruntime/core/optimizer/slice_elimination.cc @@ -8,7 +8,7 @@ namespace onnxruntime { -Status EliminateSlice::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) { +Status EliminateSlice::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const { if (graph_utils::RemoveNode(graph, node)) { rule_effect = RewriteRuleEffect::kRemovedCurrentNode; } @@ -16,7 +16,7 @@ Status EliminateSlice::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_e return Status::OK(); } -bool EliminateSlice::SatisfyCondition(const Graph& graph, const Node& node) { +bool EliminateSlice::SatisfyCondition(const Graph& graph, const Node& node) const { // We currently support elimination for Slice operator v1. // TODO Extend to support Slice operator v10, which includes "steps" and all attributes are now given as inputs. if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Slice", {1})) { diff --git a/onnxruntime/core/optimizer/slice_elimination.h b/onnxruntime/core/optimizer/slice_elimination.h index 28d689c558097..8a9ed2947417a 100644 --- a/onnxruntime/core/optimizer/slice_elimination.h +++ b/onnxruntime/core/optimizer/slice_elimination.h @@ -23,9 +23,9 @@ class EliminateSlice : public RewriteRule { } private: - bool SatisfyCondition(const Graph& graph, const Node& node) override; + bool SatisfyCondition(const Graph& graph, const Node& node) const override; - Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) override; + Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override; }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/unsqueeze_elimination.cc b/onnxruntime/core/optimizer/unsqueeze_elimination.cc index 549d415bf244e..c35d078546fa6 100644 --- a/onnxruntime/core/optimizer/unsqueeze_elimination.cc +++ b/onnxruntime/core/optimizer/unsqueeze_elimination.cc @@ -10,7 +10,7 @@ using namespace ::onnxruntime::common; namespace onnxruntime { -Status UnsqueezeElimination::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) { +Status UnsqueezeElimination::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const { // Get "axes" attribute. const ONNX_NAMESPACE::AttributeProto* attr = graph_utils::GetNodeAttribute(node, "axes"); if (attr == nullptr || attr->type() != AttributeProto_AttributeType_INTS) { @@ -74,7 +74,7 @@ Status UnsqueezeElimination::Apply(Graph& graph, Node& node, RewriteRuleEffect& return Status::OK(); } // namespace onnxruntime -bool UnsqueezeElimination::SatisfyCondition(const Graph& graph, const Node& node) { +bool UnsqueezeElimination::SatisfyCondition(const Graph& graph, const Node& node) const { // Attempt to remove an Unsqueeze operator only if it gets an initializer as input. return node.GetInputEdgesCount() == 0 && !graph.IsNodeOutputsInGraphOutputs(node); diff --git a/onnxruntime/core/optimizer/unsqueeze_elimination.h b/onnxruntime/core/optimizer/unsqueeze_elimination.h index e8e4dad40057f..3150513c13642 100644 --- a/onnxruntime/core/optimizer/unsqueeze_elimination.h +++ b/onnxruntime/core/optimizer/unsqueeze_elimination.h @@ -23,9 +23,9 @@ class UnsqueezeElimination : public RewriteRule { } private: - bool SatisfyCondition(const Graph& graph, const Node& node) override; + bool SatisfyCondition(const Graph& graph, const Node& node) const override; - Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) override; + Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override; }; } // namespace onnxruntime diff --git a/onnxruntime/test/optimizer/dummy_graph_transformer.h b/onnxruntime/test/optimizer/dummy_graph_transformer.h index 1bff4af37fcf3..8116d0ba5f18a 100644 --- a/onnxruntime/test/optimizer/dummy_graph_transformer.h +++ b/onnxruntime/test/optimizer/dummy_graph_transformer.h @@ -41,13 +41,13 @@ class DummyRewriteRule : public RewriteRule { } private: - bool rewrite_rule_invoked_; + mutable bool rewrite_rule_invoked_; - bool SatisfyCondition(const Graph& /*graph*/, const Node& /*node*/) override { + bool SatisfyCondition(const Graph& /*graph*/, const Node& /*node*/) const override { return true; } - Status Apply(Graph& /*graph*/, Node& /*node*/, RewriteRuleEffect& /*rule_effect*/) override { + Status Apply(Graph& /*graph*/, Node& /*node*/, RewriteRuleEffect& /*rule_effect*/) const override { rewrite_rule_invoked_ = true; return Status::OK(); }