Skip to content

Commit

Permalink
Add a FunctionPattern, remove unused attributes in CallPattern (apach…
Browse files Browse the repository at this point in the history
…e#7151)

* Add a FunctionPattern, remove unused attributes in CallPattern

* update docs
  • Loading branch information
Matthew Brookhart authored and electriclilies committed Feb 18, 2021
1 parent cd5ba0f commit e5b744b
Show file tree
Hide file tree
Showing 10 changed files with 220 additions and 54 deletions.
19 changes: 19 additions & 0 deletions docs/langref/relay_pattern.rst
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,19 @@ The next example is matching a pattern of batch_norm -> get(0) -> relu. Note tha
out = relay.nn.relu(tuple_get_item_node)
pat.match(out)
If we have a pattern that crosses a function boundary, we might want to match the Function itself


.. code-block:: python
def test_match_func():
x = relay.var("x")
y = relay.var("y")
wc1 = wildcard()
wc2 = wildcard()
func_pattern = FunctionPattern([wc1, wc2], wc1 + wc2)
assert func_pattern.match(relay.Function([x, y], x + y))
The next example is matching a constant node regarding its values. This is useful to check
if a specific parameter in a subgraph has been bound or not.

Expand Down Expand Up @@ -283,6 +296,7 @@ The high level design is to introduce a language of patterns for now we propose
| is_tuple_get_item(pattern, index = None)
| pattern1 `|` pattern2
| dominates(parent_pattern, path_pattern, child_pattern)
| FunctionPattern(params, body)

The above language then provides a matching interface with both can select sub-graphs as well as verify that the graph does match the pattern.

Expand Down Expand Up @@ -332,6 +346,11 @@ Domination

Match child pattern, find a match for the parent pattern, insuring that the child ultimately dominates the parrent (i.e., no nodes outside the pattern use outputs of the parent), and that ever node betwen the child and the pattern matches the path pattern.

Function Pattern
****************

Match a Function with a body and parameters

Applications
============

Expand Down
69 changes: 43 additions & 26 deletions include/tvm/relay/dataflow_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,34 +148,9 @@ class CallPatternNode : public DFPatternNode {
/*! \brief The arguments(inputs) of the call */
tvm::Array<relay::DFPattern> args;

/*! \brief The additional attributes */
Attrs attrs;

/*!
* \brief The type arguments passed to polymorphic(template) function.
*
* This is the advance feature that is only used when the function is
* polymorphic. It is safe to be ignored in most cases. For example, in the
* following code, the type_args of addone call is [int].
*
* \code
*
* template<typename T>
* T addone(T a) { return a + 1; }
*
* void main() {
* int x = addone<int>(10);
* }
*
* \endcode
*/
tvm::Array<Type> type_args;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("op", &op);
v->Visit("args", &args);
v->Visit("attrs", &attrs);
v->Visit("type_args", &type_args);
}

static constexpr const char* _type_key = "relay.dataflow_pattern.CallPattern";
Expand All @@ -184,10 +159,52 @@ class CallPatternNode : public DFPatternNode {

class CallPattern : public DFPattern {
public:
TVM_DLL CallPattern(DFPattern op, Array<DFPattern> args, Attrs attrs, Array<Type> type_args);
TVM_DLL CallPattern(DFPattern op, Array<DFPattern> args);
TVM_DEFINE_OBJECT_REF_METHODS(CallPattern, DFPattern, CallPatternNode);
};

/*!
* \brief Relay Function container
* \sa Function
*/
class FunctionPatternNode : public DFPatternNode {
public:
/*! \brief Function parameters */
tvm::Array<DFPattern> params;
/*!
* \brief
* The expression which represents the computation of the function,
* the expression may reference the parameters, and the type of it
* or sub-expressions may reference the type variables.
*/
DFPattern body;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("params", &params);
v->Visit("body", &body);
}

static constexpr const char* _type_key = "relay.dataflow_pattern.FunctionPattern";
TVM_DECLARE_FINAL_OBJECT_INFO(FunctionPatternNode, DFPatternNode);
};

/*!
* \brief Managed reference to FunctionNode.
* \sa FunctionNode
*/
class FunctionPattern : public DFPattern {
public:
/*!
* \brief Constructor
* \param params The parameters of the function.
* \param body The body of the function.
*/
TVM_DLL FunctionPattern(tvm::Array<DFPattern> params, DFPattern body);

TVM_DEFINE_OBJECT_REF_METHODS(FunctionPattern, DFPattern, FunctionPatternNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionPatternNode);
};

/*! \brief Tuple of multiple Exprs */
class TuplePattern;
/*! \brief Tuple container */
Expand Down
3 changes: 3 additions & 0 deletions include/tvm/relay/dataflow_pattern_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
virtual R VisitDFPattern_(const DataTypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const DominatorPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const ExprPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const FunctionPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const ShapePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const TupleGetItemPatternNode* op,
Args... args) DFPATTERN_FUNCTOR_DEFAULT;
Expand All @@ -112,6 +113,7 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
RELAY_DFPATTERN_FUNCTOR_DISPATCH(DataTypePatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(DominatorPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(ExprPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(FunctionPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(ShapePatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode);
Expand All @@ -138,6 +140,7 @@ class DFPatternVisitor : public DFPatternFunctor<void(const DFPattern&)> {
void VisitDFPattern_(const DataTypePatternNode* op) override;
void VisitDFPattern_(const DominatorPatternNode* op) override;
void VisitDFPattern_(const ExprPatternNode* op) override;
void VisitDFPattern_(const FunctionPatternNode* op) override;
void VisitDFPattern_(const ShapePatternNode* op) override;
void VisitDFPattern_(const TupleGetItemPatternNode* op) override;
void VisitDFPattern_(const TuplePatternNode* op) override;
Expand Down
34 changes: 23 additions & 11 deletions python/tvm/relay/dataflow_pattern/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,24 +504,36 @@ class CallPattern(DFPattern):
args: List[realy.dataflow_pattern.DFPattern]
The arguments to the call.
attrs: Optional[tvm.ir.attrs.Attrs]
Attributes to the call, can be None
type_args: Optional[List[tvm.ir.type.Type]]
The additional type arguments, this is only
used in advanced usecase of template functions.
"""

def __init__(
self,
op: "DFPattern",
args: List["DFPattern"],
attrs: Optional[tvm.ir.attrs.Attrs] = None,
type_args: Optional[List[tvm.ir.type.Type]] = None,
):
if not type_args:
type_args = []
self.__init_handle_by_constructor__(ffi.CallPattern, op, args, attrs, type_args)
self.__init_handle_by_constructor__(ffi.CallPattern, op, args)


@register_df_node
class FunctionPattern(DFPattern):
"""A pattern matching a function node in Relay.
Parameters
----------
params: List[realy.dataflow_pattern.DFPattern]
The parameters to the Function.
body: realy.dataflow_pattern.DFPattern
The body fo the Function
"""

def __init__(
self,
params: List["DFPattern"],
body: "DFPattern",
):
self.__init_handle_by_constructor__(ffi.FunctionPattern, params, body)


@register_df_node
Expand Down
42 changes: 34 additions & 8 deletions src/relay/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Ex
bool VisitDFPattern_(const DataTypePatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const FunctionPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const ShapePatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
Expand Down Expand Up @@ -264,10 +265,8 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex
is_expr_op(call_node->args[1], "divide"))) {
bool out = false;
for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
auto div = CallPattern(op->op, {arg_node->args[arg_id], op->args[1]}, op->attrs,
op->type_args);
auto mul = CallPattern(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div},
arg_node->attrs, arg_node->type_args);
auto div = CallPattern(op->op, {arg_node->args[arg_id], op->args[1]});
auto mul = CallPattern(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div});
out = VisitDFPattern(mul, expr);
if (out) {
return true;
Expand All @@ -286,10 +285,8 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex
if (is_pattern_op(arg_node, "divide") && is_expr_op(expr, "divide") &&
(is_expr_op(call_node->args[0], "multiply") ||
is_expr_op(call_node->args[1], "multiply"))) {
auto mul = CallPattern(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]},
op->attrs, op->type_args);
auto div = CallPattern(arg_node->op, {mul, arg_node->args[1]}, arg_node->attrs,
arg_node->type_args);
auto mul = CallPattern(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]});
auto div = CallPattern(arg_node->op, {mul, arg_node->args[1]});
return VisitDFPattern(div, expr);
}
}
Expand Down Expand Up @@ -356,6 +353,26 @@ bool DFPatternMatcher::VisitDFPattern_(const ExprPatternNode* op, const Expr& ex
return StructuralEqual()(op->expr, expr);
}

bool DFPatternMatcher::VisitDFPattern_(const FunctionPatternNode* op, const Expr& expr) {
bool matches = false;
if (const auto* func = expr.as<FunctionNode>()) {
matches = true;
size_t i = 0;
if (op->params.size() == func->params.size()) {
while (matches && i < op->params.size()) {
matches &= VisitDFPattern(op->params[i], func->params[i]);
++i;
}
} else {
matches = false;
}
if (matches) {
matches &= VisitDFPattern(op->body, func->body);
}
}
return matches;
}

bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) {
bool matches = false;
if (const auto* tuple_get_item_node = expr.as<TupleGetItemNode>()) {
Expand Down Expand Up @@ -601,13 +618,22 @@ class PatternGrouper {
// Get fuzzy patterns
std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> fuzzy_matches;
for (auto node : pattern_graph_.topological_order_) {
// Don't treat fuzzy Dominator patterns input variables for partition
if (auto op = node->ref_.as<DominatorPatternNode>()) {
for (auto fuzzy_op : {op->parent, op->path}) {
for (auto match : node_map[fuzzy_op]) {
fuzzy_matches.insert(match);
}
}
}
// Don't treat Function params as input variables for partition
if (auto op = node->ref_.as<FunctionPatternNode>()) {
for (auto fuzzy_op : op->params) {
for (auto match : node_map[fuzzy_op]) {
fuzzy_matches.insert(match);
}
}
}
}

// Create input variables
Expand Down
30 changes: 22 additions & 8 deletions src/relay/ir/dataflow_pattern.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,27 +81,41 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "ConstantPattern()";
});

CallPattern::CallPattern(DFPattern op, Array<DFPattern> args, Attrs attrs, Array<Type> type_args) {
CallPattern::CallPattern(DFPattern op, Array<DFPattern> args) {
ObjectPtr<CallPatternNode> n = make_object<CallPatternNode>();
n->op = std::move(op);
n->args = std::move(args);
n->attrs = std::move(attrs);
n->type_args = std::move(type_args);
data_ = std::move(n);
}

TVM_REGISTER_NODE_TYPE(CallPatternNode);

TVM_REGISTER_GLOBAL("relay.dataflow_pattern.CallPattern")
.set_body_typed([](DFPattern op, Array<DFPattern> args, Attrs attrs, Array<Type> type_args) {
return CallPattern(op, args, attrs, type_args);
});
.set_body_typed([](DFPattern op, Array<DFPattern> args) { return CallPattern(op, args); });

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<CallPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const CallPatternNode*>(ref.get());
p->stream << "CallPatternNode(" << node->op << ", " << node->args << ", " << node->attrs
<< ", " << node->type_args << ")";
p->stream << "CallPatternNode(" << node->op << ", " << node->args << ")";
});

FunctionPattern::FunctionPattern(Array<DFPattern> params, DFPattern body) {
ObjectPtr<FunctionPatternNode> n = make_object<FunctionPatternNode>();
n->params = std::move(params);
n->body = std::move(body);
data_ = std::move(n);
}
TVM_REGISTER_NODE_TYPE(FunctionPatternNode);

TVM_REGISTER_GLOBAL("relay.dataflow_pattern.FunctionPattern")
.set_body_typed([](Array<DFPattern> params, DFPattern body) {
return FunctionPattern(params, body);
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<FunctionPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const FunctionPatternNode*>(ref.get());
p->stream << "FunctionPatternNode(" << node->params << ", " << node->body << ")";
});

TuplePattern::TuplePattern(tvm::Array<DFPattern> fields) {
Expand Down
7 changes: 7 additions & 0 deletions src/relay/ir/dataflow_pattern_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ void DFPatternVisitor::VisitDFPattern_(const DominatorPatternNode* op) {

void DFPatternVisitor::VisitDFPattern_(const ExprPatternNode* op) {}

void DFPatternVisitor::VisitDFPattern_(const FunctionPatternNode* op) {
for (auto param : op->params) {
VisitDFPattern(param);
}
VisitDFPattern(op->body);
}

void DFPatternVisitor::VisitDFPattern_(const ShapePatternNode* op) { VisitDFPattern(op->pattern); }

void DFPatternVisitor::VisitDFPattern_(const TupleGetItemPatternNode* op) {
Expand Down
7 changes: 7 additions & 0 deletions src/relay/ir/indexed_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,13 @@ IndexedGraph<DFPattern> CreateIndexedGraph(const DFPattern& pattern) {

void VisitDFPattern_(const ExprPatternNode* op, NodePtr parent) override {}

void VisitDFPattern_(const FunctionPatternNode* op, NodePtr parent) override {
for (auto param : op->params) {
VisitDFPattern(param, graph_.node_map_[GetRef<DFPattern>(op)]);
}
VisitDFPattern(op->body, graph_.node_map_[GetRef<DFPattern>(op)]);
}

void VisitDFPattern_(const ShapePatternNode* op, NodePtr parent) override {
VisitDFPattern(op->pattern, graph_.node_map_[GetRef<DFPattern>(op)]);
}
Expand Down
2 changes: 1 addition & 1 deletion src/relay/transforms/simplify_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class SimplifyReshape {
x_ = WildcardPattern(make_object<WildcardPatternNode>());
auto reshape1 = AltPattern(ExprPattern(reshape_op), ExprPattern(reverse_reshape_op));
auto reshape2 = AltPattern(ExprPattern(reshape_op), ExprPattern(reverse_reshape_op));
pattern_ = CallPattern(reshape1, {CallPattern(reshape2, {x_}, Attrs{}, {})}, Attrs{}, {});
pattern_ = CallPattern(reshape1, {CallPattern(reshape2, {x_})});
}

Expr callback(const Expr& pre, const Expr& post, const Map<DFPattern, Array<Expr>>& node_map) {
Expand Down
Loading

0 comments on commit e5b744b

Please sign in to comment.