Skip to content

Commit

Permalink
Fix when rewrite rule gets registered to multiple op types; update co…
Browse files Browse the repository at this point in the history
…nstness of rule methods; enable dropout elimination (#1098)
  • Loading branch information
kkaranasos authored May 24, 2019
1 parent 9129a65 commit ee62179
Show file tree
Hide file tree
Showing 19 changed files with 54 additions and 48 deletions.
6 changes: 3 additions & 3 deletions include/onnxruntime/core/optimizer/rewrite_rule.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class RewriteRule {
@param[in] node The Node to apply the rewrite to.
@param[out] rule_effect Enum to indicate if and how the graph was modified as a result of the rule application.
@returns Status indicating success or providing error information */
common::Status CheckConditionAndApply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) {
common::Status CheckConditionAndApply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const {
return SatisfyCondition(graph, node) ? Apply(graph, node, rule_effect) : Status::OK();
}

Expand All @@ -79,11 +79,11 @@ class RewriteRule {
evaluated if this condition function returns true. This can include a more complex pattern matching (conditions
on the ascending or descending nodes of the node for which this rule was triggered) or some other properties
of the nodes. */
virtual bool SatisfyCondition(const Graph& graph, const Node& node) = 0;
virtual bool SatisfyCondition(const Graph& graph, const Node& node) const = 0;

/** This is the actual body of the rule that performs the graph transformation. The transformation happens in-place.
The return-value of node may be different from the input-value due to rewriting.
The value of "rule_effect" indicates whether and how the graph was modified by the rule. */
virtual common::Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) = 0;
virtual common::Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const = 0;
};
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,14 @@ class RuleBasedGraphTransformer : public GraphTransformer {
/** Gets the list of registered rewrite rules that will be triggered on nodes with the given op type
by this rule-based transformer.
@returns a pointer to the vector containing all the registered rewrite rules. */
const std::vector<std::unique_ptr<RewriteRule>>* GetRewriteRulesForOpType(const std::string& op_type) const {
const std::vector<std::reference_wrapper<const RewriteRule>>* GetRewriteRulesForOpType(const std::string& op_type) const {
auto rules = op_type_to_rules_.find(op_type);
return (rules != op_type_to_rules_.cend()) ? &rules->second : nullptr;
}

/** Gets the rewrite rules that are evaluated on all nodes irrespective of their op type.
@returns a pointer to the vector containing all such rewrite rules or nullptr if no such rule. */
const std::vector<std::unique_ptr<RewriteRule>>* GetAnyOpRewriteRules() const {
const std::vector<std::reference_wrapper<const RewriteRule>>* GetAnyOpRewriteRules() const {
return &any_op_type_rules_;
}

Expand All @@ -62,16 +62,18 @@ class RuleBasedGraphTransformer : public GraphTransformer {
applying rules on this node.
@returns Status indicating success or providing error information. */
common::Status ApplyRulesOnNode(Graph& graph, Node& node,
const std::vector<std::unique_ptr<RewriteRule>>& rules,
const std::vector<std::reference_wrapper<const RewriteRule>>& rules,
RewriteRule::RewriteRuleEffect& rule_effect) const;

private:
using RuleEffect = RewriteRule::RewriteRuleEffect;

// The list of unique pointers for all rules (so that rules can be registered for several op types).
std::vector<std::unique_ptr<RewriteRule>> rules_;
// Map that associates a node's op type with the vector of rules that are registered to be triggered for that node.
std::unordered_map<std::string, std::vector<std::unique_ptr<RewriteRule>>> op_type_to_rules_;
std::unordered_map<std::string, std::vector<std::reference_wrapper<const RewriteRule>>> op_type_to_rules_;
// Rules that will be evaluated regardless of the op type of the node.
std::vector<std::unique_ptr<RewriteRule>> any_op_type_rules_;
std::vector<std::reference_wrapper<const RewriteRule>> any_op_type_rules_;

// Performs a single top-down traversal of the graph and applies all registered rules.
common::Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override;
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/optimizer/conv_add_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ using namespace ONNX_NAMESPACE;
using namespace ::onnxruntime::common;
namespace onnxruntime {

Status ConvAddFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& modified) {
Status ConvAddFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& modified) const {
auto& conv_node = node;
const auto& add_node = *conv_node.OutputNodesBegin();
const auto& conv_inputs = conv_node.InputDefs();
Expand Down Expand Up @@ -107,7 +107,7 @@ Status ConvAddFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& modifie
return Status::OK();
}

bool ConvAddFusion::SatisfyCondition(const Graph& graph, const Node& node) {
bool ConvAddFusion::SatisfyCondition(const Graph& graph, const Node& node) const {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", {1}) ||
node.GetOutputEdgesCount() != 1) {
return false;
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/optimizer/conv_add_fusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ class ConvAddFusion : public RewriteRule {
}

private:
bool SatisfyCondition(const Graph& graph, const Node& node) override;
bool SatisfyCondition(const Graph& graph, const Node& node) const override;

Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) override;
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override;
};

} // namespace onnxruntime
4 changes: 2 additions & 2 deletions onnxruntime/core/optimizer/conv_bn_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ using namespace ONNX_NAMESPACE;
using namespace ::onnxruntime::common;
namespace onnxruntime {

Status ConvBNFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) {
Status ConvBNFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const {
auto& conv_node = node;
const Node& bn_node = *conv_node.OutputNodesBegin();

Expand Down Expand Up @@ -142,7 +142,7 @@ Status ConvBNFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_eff
return Status::OK();
}

bool ConvBNFusion::SatisfyCondition(const Graph& graph, const Node& node) {
bool ConvBNFusion::SatisfyCondition(const Graph& graph, const Node& node) const {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", {1}) ||
node.GetOutputEdgesCount() != 1) {
return false;
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/optimizer/conv_bn_fusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ class ConvBNFusion : public RewriteRule {
}

private:
bool SatisfyCondition(const Graph& graph, const Node& node) override;
bool SatisfyCondition(const Graph& graph, const Node& node) const override;

Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) override;
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override;
};

} // namespace onnxruntime
4 changes: 2 additions & 2 deletions onnxruntime/core/optimizer/conv_mul_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ using namespace ONNX_NAMESPACE;
using namespace ::onnxruntime::common;
namespace onnxruntime {

Status ConvMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) {
Status ConvMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const {
auto& conv_node = node;
const auto& mul_node = *conv_node.OutputNodesBegin();
const auto& conv_inputs = conv_node.InputDefs();
Expand Down Expand Up @@ -105,7 +105,7 @@ Status ConvMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_ef
return Status::OK();
}

bool ConvMulFusion::SatisfyCondition(const Graph& graph, const Node& node) {
bool ConvMulFusion::SatisfyCondition(const Graph& graph, const Node& node) const {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", {1}) ||
node.GetOutputEdgesCount() != 1) {
return false;
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/optimizer/conv_mul_fusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ class ConvMulFusion : public RewriteRule {
}

private:
bool SatisfyCondition(const Graph& graph, const Node& node) override;
bool SatisfyCondition(const Graph& graph, const Node& node) const override;

Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) override;
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override;
};

} // namespace onnxruntime
4 changes: 2 additions & 2 deletions onnxruntime/core/optimizer/dropout_elimination.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@

namespace onnxruntime {

Status EliminateDropout::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) {
Status EliminateDropout::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const {
if (graph_utils::RemoveNode(graph, node)) {
rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
}

return Status::OK();
}

bool EliminateDropout::SatisfyCondition(const Graph& graph, const Node& node) {
bool EliminateDropout::SatisfyCondition(const Graph& graph, const Node& node) const {
// We currently support elimination for Dropout operator v1, v6, v7, and v10.
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Dropout", {1, 6, 7, 10})) {
return false;
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/optimizer/dropout_elimination.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ class EliminateDropout : public RewriteRule {
}

private:
bool SatisfyCondition(const Graph& graph, const Node& node) override;
bool SatisfyCondition(const Graph& graph, const Node& node) const override;

Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) override;
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override;
};

} // namespace onnxruntime
2 changes: 2 additions & 0 deletions onnxruntime/core/optimizer/graph_transformer_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "core/optimizer/conv_activation_fusion.h"
#include "core/optimizer/gemm_activation_fusion.h"
#include "core/optimizer/matmul_add_fusion.h"
#include "core/optimizer/dropout_elimination.h"

namespace onnxruntime {

Expand All @@ -27,6 +28,7 @@ std::vector<std::unique_ptr<RewriteRule>> GenerateRewriteRules(TransformerLevel
case TransformerLevel::Level1:
rules.push_back(std::make_unique<EliminateIdentity>());
rules.push_back(std::make_unique<EliminateSlice>());
rules.push_back(std::make_unique<EliminateDropout>());
break;

case TransformerLevel::Level2:
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/optimizer/identity_elimination.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@

namespace onnxruntime {

Status EliminateIdentity::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) {
Status EliminateIdentity::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const {
if (graph_utils::RemoveNode(graph, node)) {
rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
}

return Status::OK();
}

bool EliminateIdentity::SatisfyCondition(const Graph& graph, const Node& node) {
bool EliminateIdentity::SatisfyCondition(const Graph& graph, const Node& node) const {
return graph_utils::IsSingleInSingleOutNode(node) &&
!graph.IsNodeOutputsInGraphOutputs(node);
}
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/optimizer/identity_elimination.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ class EliminateIdentity : public RewriteRule {
}

private:
bool SatisfyCondition(const Graph& graph, const Node& node) override;
bool SatisfyCondition(const Graph& graph, const Node& node) const override;

Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) override;
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override;
}; // namespace onnxruntime

} // namespace onnxruntime
20 changes: 11 additions & 9 deletions onnxruntime/core/optimizer/rule_based_graph_transformer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,23 @@ Status RuleBasedGraphTransformer::Register(std::unique_ptr<RewriteRule> rule) {
auto op_types = rule->TargetOpTypes();
// If the target op types are empty, this rule will be evaluated for all op types.
if (op_types.empty()) {
any_op_type_rules_.push_back(std::move(rule));
any_op_type_rules_.push_back(*rule);
} else {
std::for_each(op_types.cbegin(), op_types.cend(),
[&](const auto& op_type) { op_type_to_rules_[op_type].push_back(std::move(rule)); });
[&](const auto& op_type) { op_type_to_rules_[op_type].push_back(*rule); });
}

// Save unique pointer at the rules_ list.
rules_.push_back(std::move(rule));

return Status::OK();
}

Status RuleBasedGraphTransformer::ApplyRulesOnNode(Graph& graph, Node& node,
const std::vector<std::unique_ptr<RewriteRule>>& rules,
const std::vector<std::reference_wrapper<const RewriteRule>>& rules,
RuleEffect& rule_effect) const {
for (const auto& rule : rules) {
ORT_RETURN_IF_ERROR(rule->CheckConditionAndApply(graph, node, rule_effect));
for (const RewriteRule& rule : rules) {
ORT_RETURN_IF_ERROR(rule.CheckConditionAndApply(graph, node, rule_effect));
// If the current node was removed as a result of a rule, stop rule application for that node.
if (rule_effect == RuleEffect::kRemovedCurrentNode) {
break;
Expand Down Expand Up @@ -56,7 +60,7 @@ Status RuleBasedGraphTransformer::ApplyImpl(Graph& graph, bool& modified, int gr
// First apply rewrite rules that are registered for the op type of the current node; then apply rules that are
// registered to be applied regardless of the op type; then recursively apply rules to subgraphs (if any).
// Stop further rule application for the current node, if the node gets removed by a rule.
const std::vector<std::unique_ptr<RewriteRule>>* rules = nullptr;
const std::vector<std::reference_wrapper<const RewriteRule>>* rules = nullptr;

rules = GetRewriteRulesForOpType(node->OpType());
if (rules) {
Expand Down Expand Up @@ -84,9 +88,7 @@ Status RuleBasedGraphTransformer::ApplyImpl(Graph& graph, bool& modified, int gr
}

size_t RuleBasedGraphTransformer::RulesCount() const {
return any_op_type_rules_.size() +
std::accumulate(op_type_to_rules_.cbegin(), op_type_to_rules_.cend(), size_t(0),
[](size_t sum, const auto& rules) { return sum + rules.second.size(); });
return rules_.size();
}

} // namespace onnxruntime
4 changes: 2 additions & 2 deletions onnxruntime/core/optimizer/slice_elimination.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@

namespace onnxruntime {

Status EliminateSlice::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) {
Status EliminateSlice::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const {
if (graph_utils::RemoveNode(graph, node)) {
rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
}

return Status::OK();
}

bool EliminateSlice::SatisfyCondition(const Graph& graph, const Node& node) {
bool EliminateSlice::SatisfyCondition(const Graph& graph, const Node& node) const {
// We currently support elimination for Slice operator v1.
// TODO Extend to support Slice operator v10, which includes "steps" and all attributes are now given as inputs.
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Slice", {1})) {
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/optimizer/slice_elimination.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ class EliminateSlice : public RewriteRule {
}

private:
bool SatisfyCondition(const Graph& graph, const Node& node) override;
bool SatisfyCondition(const Graph& graph, const Node& node) const override;

Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) override;
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override;
};

} // namespace onnxruntime
4 changes: 2 additions & 2 deletions onnxruntime/core/optimizer/unsqueeze_elimination.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ using namespace ::onnxruntime::common;

namespace onnxruntime {

Status UnsqueezeElimination::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) {
Status UnsqueezeElimination::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const {
// Get "axes" attribute.
const ONNX_NAMESPACE::AttributeProto* attr = graph_utils::GetNodeAttribute(node, "axes");
if (attr == nullptr || attr->type() != AttributeProto_AttributeType_INTS) {
Expand Down Expand Up @@ -74,7 +74,7 @@ Status UnsqueezeElimination::Apply(Graph& graph, Node& node, RewriteRuleEffect&
return Status::OK();
} // namespace onnxruntime

bool UnsqueezeElimination::SatisfyCondition(const Graph& graph, const Node& node) {
bool UnsqueezeElimination::SatisfyCondition(const Graph& graph, const Node& node) const {
// Attempt to remove an Unsqueeze operator only if it gets an initializer as input.
return node.GetInputEdgesCount() == 0 &&
!graph.IsNodeOutputsInGraphOutputs(node);
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/optimizer/unsqueeze_elimination.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ class UnsqueezeElimination : public RewriteRule {
}

private:
bool SatisfyCondition(const Graph& graph, const Node& node) override;
bool SatisfyCondition(const Graph& graph, const Node& node) const override;

Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) override;
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override;
};

} // namespace onnxruntime
6 changes: 3 additions & 3 deletions onnxruntime/test/optimizer/dummy_graph_transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,13 @@ class DummyRewriteRule : public RewriteRule {
}

private:
bool rewrite_rule_invoked_;
mutable bool rewrite_rule_invoked_;

bool SatisfyCondition(const Graph& /*graph*/, const Node& /*node*/) override {
bool SatisfyCondition(const Graph& /*graph*/, const Node& /*node*/) const override {
return true;
}

Status Apply(Graph& /*graph*/, Node& /*node*/, RewriteRuleEffect& /*rule_effect*/) override {
Status Apply(Graph& /*graph*/, Node& /*node*/, RewriteRuleEffect& /*rule_effect*/) const override {
rewrite_rule_invoked_ = true;
return Status::OK();
}
Expand Down

0 comments on commit ee62179

Please sign in to comment.