diff --git a/docs/langref/relay_pattern.rst b/docs/langref/relay_pattern.rst index 8b34b7619840..ff02e50eb5fb 100644 --- a/docs/langref/relay_pattern.rst +++ b/docs/langref/relay_pattern.rst @@ -167,6 +167,19 @@ The next example is matching a pattern of batch_norm -> get(0) -> relu. Note tha out = relay.nn.relu(tuple_get_item_node) pat.match(out) +If we have a pattern that crosses a function boundary, we might want to match the Function itself + + +.. code-block:: python + + def test_match_func(): + x = relay.var("x") + y = relay.var("y") + wc1 = wildcard() + wc2 = wildcard() + func_pattern = FunctionPattern([wc1, wc2], wc1 + wc2) + assert func_pattern.match(relay.Function([x, y], x + y)) + The next example is matching a constant node regarding its values. This is useful to check if a specific parameter in a subgraph has been bound or not. @@ -283,6 +296,7 @@ The high level design is to introduce a language of patterns for now we propose | is_tuple_get_item(pattern, index = None) | pattern1 `|` pattern2 | dominates(parent_pattern, path_pattern, child_pattern) + | FunctionPattern(params, body) The above language then provides a matching interface with both can select sub-graphs as well as verify that the graph does match the pattern. @@ -332,6 +346,11 @@ Domination Match child pattern, find a match for the parent pattern, insuring that the child ultimately dominates the parrent (i.e., no nodes outside the pattern use outputs of the parent), and that ever node betwen the child and the pattern matches the path pattern. +Function Pattern +**************** + +Match a Function with a body and parameters + Applications ============ diff --git a/include/tvm/relay/dataflow_pattern.h b/include/tvm/relay/dataflow_pattern.h index 11ac7e39f4a3..909a4fe44eb1 100644 --- a/include/tvm/relay/dataflow_pattern.h +++ b/include/tvm/relay/dataflow_pattern.h @@ -148,34 +148,9 @@ class CallPatternNode : public DFPatternNode { /*! \brief The arguments(inputs) of the call */ tvm::Array args; - /*! \brief The additional attributes */ - Attrs attrs; - - /*! - * \brief The type arguments passed to polymorphic(template) function. - * - * This is the advance feature that is only used when the function is - * polymorphic. It is safe to be ignored in most cases. For example, in the - * following code, the type_args of addone call is [int]. - * - * \code - * - * template - * T addone(T a) { return a + 1; } - * - * void main() { - * int x = addone(10); - * } - * - * \endcode - */ - tvm::Array type_args; - void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("op", &op); v->Visit("args", &args); - v->Visit("attrs", &attrs); - v->Visit("type_args", &type_args); } static constexpr const char* _type_key = "relay.dataflow_pattern.CallPattern"; @@ -184,10 +159,52 @@ class CallPatternNode : public DFPatternNode { class CallPattern : public DFPattern { public: - TVM_DLL CallPattern(DFPattern op, Array args, Attrs attrs, Array type_args); + TVM_DLL CallPattern(DFPattern op, Array args); TVM_DEFINE_OBJECT_REF_METHODS(CallPattern, DFPattern, CallPatternNode); }; +/*! + * \brief Relay Function container + * \sa Function + */ +class FunctionPatternNode : public DFPatternNode { + public: + /*! \brief Function parameters */ + tvm::Array params; + /*! + * \brief + * The expression which represents the computation of the function, + * the expression may reference the parameters, and the type of it + * or sub-expressions may reference the type variables. + */ + DFPattern body; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("params", ¶ms); + v->Visit("body", &body); + } + + static constexpr const char* _type_key = "relay.dataflow_pattern.FunctionPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(FunctionPatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to FunctionNode. + * \sa FunctionNode + */ +class FunctionPattern : public DFPattern { + public: + /*! + * \brief Constructor + * \param params The parameters of the function. + * \param body The body of the function. + */ + TVM_DLL FunctionPattern(tvm::Array params, DFPattern body); + + TVM_DEFINE_OBJECT_REF_METHODS(FunctionPattern, DFPattern, FunctionPatternNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionPatternNode); +}; + /*! \brief Tuple of multiple Exprs */ class TuplePattern; /*! \brief Tuple container */ diff --git a/include/tvm/relay/dataflow_pattern_functor.h b/include/tvm/relay/dataflow_pattern_functor.h index 364daac81cc8..f04977b86ccb 100644 --- a/include/tvm/relay/dataflow_pattern_functor.h +++ b/include/tvm/relay/dataflow_pattern_functor.h @@ -87,6 +87,7 @@ class DFPatternFunctor { virtual R VisitDFPattern_(const DataTypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; virtual R VisitDFPattern_(const DominatorPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; virtual R VisitDFPattern_(const ExprPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const FunctionPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; virtual R VisitDFPattern_(const ShapePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; virtual R VisitDFPattern_(const TupleGetItemPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; @@ -112,6 +113,7 @@ class DFPatternFunctor { RELAY_DFPATTERN_FUNCTOR_DISPATCH(DataTypePatternNode); RELAY_DFPATTERN_FUNCTOR_DISPATCH(DominatorPatternNode); RELAY_DFPATTERN_FUNCTOR_DISPATCH(ExprPatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(FunctionPatternNode); RELAY_DFPATTERN_FUNCTOR_DISPATCH(ShapePatternNode); RELAY_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode); RELAY_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode); @@ -138,6 +140,7 @@ class DFPatternVisitor : public DFPatternFunctor { void VisitDFPattern_(const DataTypePatternNode* op) override; void VisitDFPattern_(const DominatorPatternNode* op) override; void VisitDFPattern_(const ExprPatternNode* op) override; + void VisitDFPattern_(const FunctionPatternNode* op) override; void VisitDFPattern_(const ShapePatternNode* op) override; void VisitDFPattern_(const TupleGetItemPatternNode* op) override; void VisitDFPattern_(const TuplePatternNode* op) override; diff --git a/python/tvm/relay/dataflow_pattern/__init__.py b/python/tvm/relay/dataflow_pattern/__init__.py index 7178bff2c114..233c696fd716 100644 --- a/python/tvm/relay/dataflow_pattern/__init__.py +++ b/python/tvm/relay/dataflow_pattern/__init__.py @@ -504,24 +504,36 @@ class CallPattern(DFPattern): args: List[realy.dataflow_pattern.DFPattern] The arguments to the call. - attrs: Optional[tvm.ir.attrs.Attrs] - Attributes to the call, can be None - - type_args: Optional[List[tvm.ir.type.Type]] - The additional type arguments, this is only - used in advanced usecase of template functions. """ def __init__( self, op: "DFPattern", args: List["DFPattern"], - attrs: Optional[tvm.ir.attrs.Attrs] = None, - type_args: Optional[List[tvm.ir.type.Type]] = None, ): - if not type_args: - type_args = [] - self.__init_handle_by_constructor__(ffi.CallPattern, op, args, attrs, type_args) + self.__init_handle_by_constructor__(ffi.CallPattern, op, args) + + +@register_df_node +class FunctionPattern(DFPattern): + """A pattern matching a function node in Relay. + + Parameters + ---------- + params: List[realy.dataflow_pattern.DFPattern] + The parameters to the Function. + + body: realy.dataflow_pattern.DFPattern + The body fo the Function + + """ + + def __init__( + self, + params: List["DFPattern"], + body: "DFPattern", + ): + self.__init_handle_by_constructor__(ffi.FunctionPattern, params, body) @register_df_node diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 44b87633d208..c5cc3dd17429 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -54,6 +54,7 @@ class DFPatternMatcher : public DFPatternFunctorargs[1], "divide"))) { bool out = false; for (size_t arg_id = 0; arg_id < 2; ++arg_id) { - auto div = CallPattern(op->op, {arg_node->args[arg_id], op->args[1]}, op->attrs, - op->type_args); - auto mul = CallPattern(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div}, - arg_node->attrs, arg_node->type_args); + auto div = CallPattern(op->op, {arg_node->args[arg_id], op->args[1]}); + auto mul = CallPattern(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div}); out = VisitDFPattern(mul, expr); if (out) { return true; @@ -286,10 +285,8 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex if (is_pattern_op(arg_node, "divide") && is_expr_op(expr, "divide") && (is_expr_op(call_node->args[0], "multiply") || is_expr_op(call_node->args[1], "multiply"))) { - auto mul = CallPattern(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]}, - op->attrs, op->type_args); - auto div = CallPattern(arg_node->op, {mul, arg_node->args[1]}, arg_node->attrs, - arg_node->type_args); + auto mul = CallPattern(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]}); + auto div = CallPattern(arg_node->op, {mul, arg_node->args[1]}); return VisitDFPattern(div, expr); } } @@ -356,6 +353,26 @@ bool DFPatternMatcher::VisitDFPattern_(const ExprPatternNode* op, const Expr& ex return StructuralEqual()(op->expr, expr); } +bool DFPatternMatcher::VisitDFPattern_(const FunctionPatternNode* op, const Expr& expr) { + bool matches = false; + if (const auto* func = expr.as()) { + matches = true; + size_t i = 0; + if (op->params.size() == func->params.size()) { + while (matches && i < op->params.size()) { + matches &= VisitDFPattern(op->params[i], func->params[i]); + ++i; + } + } else { + matches = false; + } + if (matches) { + matches &= VisitDFPattern(op->body, func->body); + } + } + return matches; +} + bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) { bool matches = false; if (const auto* tuple_get_item_node = expr.as()) { @@ -601,6 +618,7 @@ class PatternGrouper { // Get fuzzy patterns std::unordered_set fuzzy_matches; for (auto node : pattern_graph_.topological_order_) { + // Don't treat fuzzy Dominator patterns input variables for partition if (auto op = node->ref_.as()) { for (auto fuzzy_op : {op->parent, op->path}) { for (auto match : node_map[fuzzy_op]) { @@ -608,6 +626,14 @@ class PatternGrouper { } } } + // Don't treat Function params as input variables for partition + if (auto op = node->ref_.as()) { + for (auto fuzzy_op : op->params) { + for (auto match : node_map[fuzzy_op]) { + fuzzy_matches.insert(match); + } + } + } } // Create input variables diff --git a/src/relay/ir/dataflow_pattern.cc b/src/relay/ir/dataflow_pattern.cc index 4664e5fc8168..46c53c8bd96c 100644 --- a/src/relay/ir/dataflow_pattern.cc +++ b/src/relay/ir/dataflow_pattern.cc @@ -81,27 +81,41 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "ConstantPattern()"; }); -CallPattern::CallPattern(DFPattern op, Array args, Attrs attrs, Array type_args) { +CallPattern::CallPattern(DFPattern op, Array args) { ObjectPtr n = make_object(); n->op = std::move(op); n->args = std::move(args); - n->attrs = std::move(attrs); - n->type_args = std::move(type_args); data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(CallPatternNode); TVM_REGISTER_GLOBAL("relay.dataflow_pattern.CallPattern") - .set_body_typed([](DFPattern op, Array args, Attrs attrs, Array type_args) { - return CallPattern(op, args, attrs, type_args); - }); + .set_body_typed([](DFPattern op, Array args) { return CallPattern(op, args); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); - p->stream << "CallPatternNode(" << node->op << ", " << node->args << ", " << node->attrs - << ", " << node->type_args << ")"; + p->stream << "CallPatternNode(" << node->op << ", " << node->args << ")"; + }); + +FunctionPattern::FunctionPattern(Array params, DFPattern body) { + ObjectPtr n = make_object(); + n->params = std::move(params); + n->body = std::move(body); + data_ = std::move(n); +} +TVM_REGISTER_NODE_TYPE(FunctionPatternNode); + +TVM_REGISTER_GLOBAL("relay.dataflow_pattern.FunctionPattern") + .set_body_typed([](Array params, DFPattern body) { + return FunctionPattern(params, body); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "FunctionPatternNode(" << node->params << ", " << node->body << ")"; }); TuplePattern::TuplePattern(tvm::Array fields) { diff --git a/src/relay/ir/dataflow_pattern_functor.cc b/src/relay/ir/dataflow_pattern_functor.cc index 7e9f828c8aa8..aaa4f84b3254 100644 --- a/src/relay/ir/dataflow_pattern_functor.cc +++ b/src/relay/ir/dataflow_pattern_functor.cc @@ -62,6 +62,13 @@ void DFPatternVisitor::VisitDFPattern_(const DominatorPatternNode* op) { void DFPatternVisitor::VisitDFPattern_(const ExprPatternNode* op) {} +void DFPatternVisitor::VisitDFPattern_(const FunctionPatternNode* op) { + for (auto param : op->params) { + VisitDFPattern(param); + } + VisitDFPattern(op->body); +} + void DFPatternVisitor::VisitDFPattern_(const ShapePatternNode* op) { VisitDFPattern(op->pattern); } void DFPatternVisitor::VisitDFPattern_(const TupleGetItemPatternNode* op) { diff --git a/src/relay/ir/indexed_graph.cc b/src/relay/ir/indexed_graph.cc index 456bf02a0611..4ba053c429de 100644 --- a/src/relay/ir/indexed_graph.cc +++ b/src/relay/ir/indexed_graph.cc @@ -261,6 +261,13 @@ IndexedGraph CreateIndexedGraph(const DFPattern& pattern) { void VisitDFPattern_(const ExprPatternNode* op, NodePtr parent) override {} + void VisitDFPattern_(const FunctionPatternNode* op, NodePtr parent) override { + for (auto param : op->params) { + VisitDFPattern(param, graph_.node_map_[GetRef(op)]); + } + VisitDFPattern(op->body, graph_.node_map_[GetRef(op)]); + } + void VisitDFPattern_(const ShapePatternNode* op, NodePtr parent) override { VisitDFPattern(op->pattern, graph_.node_map_[GetRef(op)]); } diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc index 079b86715a48..cb42ab09aae4 100644 --- a/src/relay/transforms/simplify_expr.cc +++ b/src/relay/transforms/simplify_expr.cc @@ -46,7 +46,7 @@ class SimplifyReshape { x_ = WildcardPattern(make_object()); auto reshape1 = AltPattern(ExprPattern(reshape_op), ExprPattern(reverse_reshape_op)); auto reshape2 = AltPattern(ExprPattern(reshape_op), ExprPattern(reverse_reshape_op)); - pattern_ = CallPattern(reshape1, {CallPattern(reshape2, {x_}, Attrs{}, {})}, Attrs{}, {}); + pattern_ = CallPattern(reshape1, {CallPattern(reshape2, {x_})}); } Expr callback(const Expr& pre, const Expr& post, const Map>& node_map) { diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py index d4c169bc603e..d99e55b7c33f 100644 --- a/tests/python/relay/test_dataflow_pattern.py +++ b/tests/python/relay/test_dataflow_pattern.py @@ -62,6 +62,19 @@ def test_CallPattern(): assert isinstance(c.args[1], WildcardPattern) +def test_FunctionPattern(): + wc1 = wildcard() + wc2 = wildcard() + c = is_op("add")(wc1, wc2) + f = FunctionPattern([wc1, wc2], c) + assert isinstance(f, FunctionPattern) + assert isinstance(f.params[0], WildcardPattern) + assert isinstance(f.params[1], WildcardPattern) + assert isinstance(f.body, CallPattern) + assert isinstance(f.body.args[0], WildcardPattern) + assert isinstance(f.body.args[1], WildcardPattern) + + def test_TuplePattern(): wc1 = wildcard() wc2 = wildcard() @@ -167,6 +180,24 @@ def test_no_match_call(): assert not add_pattern.match(x - y) +def test_match_func(): + x = relay.var("x") + y = relay.var("y") + wc1 = wildcard() + wc2 = wildcard() + func_pattern = FunctionPattern([wc1, wc2], wc1 + wc2) + assert func_pattern.match(relay.Function([x, y], x + y)) + + +def test_no_match_func(): + x = relay.var("x") + y = relay.var("y") + wc1 = wildcard() + wc2 = wildcard() + func_pattern = FunctionPattern([wc1, wc2], wc1 + wc2) + assert not func_pattern.match(relay.Function([x, y], x - y)) + + def test_match_option(): x = relay.var("x") w = relay.var("w") @@ -1300,6 +1331,36 @@ def test_partition_option(): assert tvm.ir.structural_equal(func(x, w, b), pattern2.partition(relu)) +def test_partition_function(): + x = relay.var("x") + w = relay.var("w") + b = relay.var("b") + + x1 = relay.var("x1") + w1 = relay.var("w1") + + wc_x = wildcard() + wc_w = wildcard() + wc_b = wildcard() + wc_x1 = wildcard() + wc_w1 = wildcard() + + func_pattern = FunctionPattern([wc_x1, wc_w1], is_op("nn.conv2d")(wc_x1, wc_w1)) + pattern = func_pattern(wc_x, wc_w) + wc_b + + func = relay.Function([x1, w1], relay.nn.conv2d(x1, w1)) + expr = func(x, w) + b + b + + x2 = relay.var("x2") + w2 = relay.var("w2") + b2 = relay.var("b2") + func2 = relay.Function([x2, w2, b2], func(x2, w2) + b2).with_attr( + "PartitionedFromPattern", "nn.conv2d_FunctionCall_add_" + ) + expr2 = func2(x, w, b) + b + assert tvm.ir.structural_equal(pattern.partition(expr), expr2) + + def test_match_match(): add_pattern = is_op("add")(wildcard(), wildcard())