Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a FunctionPattern, remove unused attributes in CallPattern #7151

Merged
merged 2 commits into from
Dec 23, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 43 additions & 26 deletions include/tvm/relay/dataflow_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,34 +148,9 @@ class CallPatternNode : public DFPatternNode {
/*! \brief The arguments(inputs) of the call */
tvm::Array<relay::DFPattern> args;

/*! \brief The additional attributes */
Attrs attrs;

/*!
* \brief The type arguments passed to polymorphic(template) function.
*
* This is the advance feature that is only used when the function is
* polymorphic. It is safe to be ignored in most cases. For example, in the
* following code, the type_args of addone call is [int].
*
* \code
*
* template<typename T>
* T addone(T a) { return a + 1; }
*
* void main() {
* int x = addone<int>(10);
* }
*
* \endcode
*/
tvm::Array<Type> type_args;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("op", &op);
v->Visit("args", &args);
v->Visit("attrs", &attrs);
v->Visit("type_args", &type_args);
}

static constexpr const char* _type_key = "relay.dataflow_pattern.CallPattern";
Expand All @@ -184,10 +159,52 @@ class CallPatternNode : public DFPatternNode {

class CallPattern : public DFPattern {
public:
TVM_DLL CallPattern(DFPattern op, Array<DFPattern> args, Attrs attrs, Array<Type> type_args);
TVM_DLL CallPattern(DFPattern op, Array<DFPattern> args);
TVM_DEFINE_OBJECT_REF_METHODS(CallPattern, DFPattern, CallPatternNode);
};

/*!
* \brief Relay Function container
* \sa Function
*/
class FunctionPatternNode : public DFPatternNode {
public:
/*! \brief Function parameters */
tvm::Array<DFPattern> params;
/*!
* \brief
* The expression which represents the computation of the function,
* the expression may reference the parameters, and the type of it
* or sub-expressions may reference the type variables.
*/
DFPattern body;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("params", &params);
v->Visit("body", &body);
}

static constexpr const char* _type_key = "relay.dataflow_pattern.FunctionPattern";
TVM_DECLARE_FINAL_OBJECT_INFO(FunctionPatternNode, DFPatternNode);
};

/*!
* \brief Managed reference to FunctionNode.
* \sa FunctionNode
*/
class FunctionPattern : public DFPattern {
public:
/*!
* \brief Constructor
* \param params The parameters of the function.
* \param body The body of the function.
*/
TVM_DLL FunctionPattern(tvm::Array<DFPattern> params, DFPattern body);

TVM_DEFINE_OBJECT_REF_METHODS(FunctionPattern, DFPattern, FunctionPatternNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionPatternNode);
};

/*! \brief Tuple of multiple Exprs */
class TuplePattern;
/*! \brief Tuple container */
Expand Down
3 changes: 3 additions & 0 deletions include/tvm/relay/dataflow_pattern_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
virtual R VisitDFPattern_(const DataTypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const DominatorPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const ExprPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const FunctionPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const ShapePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const TupleGetItemPatternNode* op,
Args... args) DFPATTERN_FUNCTOR_DEFAULT;
Expand All @@ -112,6 +113,7 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
RELAY_DFPATTERN_FUNCTOR_DISPATCH(DataTypePatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(DominatorPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(ExprPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(FunctionPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(ShapePatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode);
Expand All @@ -138,6 +140,7 @@ class DFPatternVisitor : public DFPatternFunctor<void(const DFPattern&)> {
void VisitDFPattern_(const DataTypePatternNode* op) override;
void VisitDFPattern_(const DominatorPatternNode* op) override;
void VisitDFPattern_(const ExprPatternNode* op) override;
void VisitDFPattern_(const FunctionPatternNode* op) override;
void VisitDFPattern_(const ShapePatternNode* op) override;
void VisitDFPattern_(const TupleGetItemPatternNode* op) override;
void VisitDFPattern_(const TuplePatternNode* op) override;
Expand Down
34 changes: 23 additions & 11 deletions python/tvm/relay/dataflow_pattern/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,24 +504,36 @@ class CallPattern(DFPattern):
args: List[realy.dataflow_pattern.DFPattern]
The arguments to the call.

attrs: Optional[tvm.ir.attrs.Attrs]
Attributes to the call, can be None

type_args: Optional[List[tvm.ir.type.Type]]
The additional type arguments, this is only
used in advanced usecase of template functions.
"""

def __init__(
self,
op: "DFPattern",
args: List["DFPattern"],
attrs: Optional[tvm.ir.attrs.Attrs] = None,
type_args: Optional[List[tvm.ir.type.Type]] = None,
):
if not type_args:
type_args = []
self.__init_handle_by_constructor__(ffi.CallPattern, op, args, attrs, type_args)
self.__init_handle_by_constructor__(ffi.CallPattern, op, args)


@register_df_node
class FunctionPattern(DFPattern):
"""A pattern matching a function node in Relay.

Parameters
----------
params: List[realy.dataflow_pattern.DFPattern]
The parameters to the Function.

body: realy.dataflow_pattern.DFPattern
The body fo the Function

"""

def __init__(
self,
params: List["DFPattern"],
body: "DFPattern",
):
self.__init_handle_by_constructor__(ffi.FunctionPattern, params, body)


@register_df_node
Expand Down
42 changes: 34 additions & 8 deletions src/relay/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Ex
bool VisitDFPattern_(const DataTypePatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const FunctionPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const ShapePatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
Expand Down Expand Up @@ -264,10 +265,8 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex
is_expr_op(call_node->args[1], "divide"))) {
bool out = false;
for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
auto div = CallPattern(op->op, {arg_node->args[arg_id], op->args[1]}, op->attrs,
op->type_args);
auto mul = CallPattern(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div},
arg_node->attrs, arg_node->type_args);
auto div = CallPattern(op->op, {arg_node->args[arg_id], op->args[1]});
auto mul = CallPattern(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div});
out = VisitDFPattern(mul, expr);
if (out) {
return true;
Expand All @@ -286,10 +285,8 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex
if (is_pattern_op(arg_node, "divide") && is_expr_op(expr, "divide") &&
(is_expr_op(call_node->args[0], "multiply") ||
is_expr_op(call_node->args[1], "multiply"))) {
auto mul = CallPattern(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]},
op->attrs, op->type_args);
auto div = CallPattern(arg_node->op, {mul, arg_node->args[1]}, arg_node->attrs,
arg_node->type_args);
auto mul = CallPattern(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]});
auto div = CallPattern(arg_node->op, {mul, arg_node->args[1]});
return VisitDFPattern(div, expr);
}
}
Expand Down Expand Up @@ -356,6 +353,26 @@ bool DFPatternMatcher::VisitDFPattern_(const ExprPatternNode* op, const Expr& ex
return StructuralEqual()(op->expr, expr);
}

bool DFPatternMatcher::VisitDFPattern_(const FunctionPatternNode* op, const Expr& expr) {
bool matches = false;
if (const auto* func = expr.as<FunctionNode>()) {
matches = true;
size_t i = 0;
if (op->params.size() == func->params.size()) {
while (matches && i < op->params.size()) {
matches &= VisitDFPattern(op->params[i], func->params[i]);
++i;
}
} else {
matches = false;
}
if (matches) {
matches &= VisitDFPattern(op->body, func->body);
}
}
return matches;
}

bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) {
bool matches = false;
if (const auto* tuple_get_item_node = expr.as<TupleGetItemNode>()) {
Expand Down Expand Up @@ -601,13 +618,22 @@ class PatternGrouper {
// Get fuzzy patterns
std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> fuzzy_matches;
for (auto node : pattern_graph_.topological_order_) {
// Don't treat fuzzy Dominator patterns input variables for partition
if (auto op = node->ref_.as<DominatorPatternNode>()) {
for (auto fuzzy_op : {op->parent, op->path}) {
for (auto match : node_map[fuzzy_op]) {
fuzzy_matches.insert(match);
}
}
}
// Don't treat Function params as input variables for partition
if (auto op = node->ref_.as<FunctionPatternNode>()) {
for (auto fuzzy_op : op->params) {
for (auto match : node_map[fuzzy_op]) {
fuzzy_matches.insert(match);
}
}
}
}

// Create input variables
Expand Down
30 changes: 22 additions & 8 deletions src/relay/ir/dataflow_pattern.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,27 +81,41 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "ConstantPattern()";
});

CallPattern::CallPattern(DFPattern op, Array<DFPattern> args, Attrs attrs, Array<Type> type_args) {
CallPattern::CallPattern(DFPattern op, Array<DFPattern> args) {
ObjectPtr<CallPatternNode> n = make_object<CallPatternNode>();
n->op = std::move(op);
n->args = std::move(args);
n->attrs = std::move(attrs);
n->type_args = std::move(type_args);
data_ = std::move(n);
}

TVM_REGISTER_NODE_TYPE(CallPatternNode);

TVM_REGISTER_GLOBAL("relay.dataflow_pattern.CallPattern")
.set_body_typed([](DFPattern op, Array<DFPattern> args, Attrs attrs, Array<Type> type_args) {
return CallPattern(op, args, attrs, type_args);
});
.set_body_typed([](DFPattern op, Array<DFPattern> args) { return CallPattern(op, args); });

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<CallPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const CallPatternNode*>(ref.get());
p->stream << "CallPatternNode(" << node->op << ", " << node->args << ", " << node->attrs
<< ", " << node->type_args << ")";
p->stream << "CallPatternNode(" << node->op << ", " << node->args << ")";
});

FunctionPattern::FunctionPattern(Array<DFPattern> params, DFPattern body) {
ObjectPtr<FunctionPatternNode> n = make_object<FunctionPatternNode>();
n->params = std::move(params);
n->body = std::move(body);
data_ = std::move(n);
}
TVM_REGISTER_NODE_TYPE(FunctionPatternNode);

TVM_REGISTER_GLOBAL("relay.dataflow_pattern.FunctionPattern")
.set_body_typed([](Array<DFPattern> params, DFPattern body) {
return FunctionPattern(params, body);
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<FunctionPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const FunctionPatternNode*>(ref.get());
p->stream << "FunctionPatternNode(" << node->params << ", " << node->body << ")";
});

TuplePattern::TuplePattern(tvm::Array<DFPattern> fields) {
Expand Down
7 changes: 7 additions & 0 deletions src/relay/ir/dataflow_pattern_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ void DFPatternVisitor::VisitDFPattern_(const DominatorPatternNode* op) {

void DFPatternVisitor::VisitDFPattern_(const ExprPatternNode* op) {}

void DFPatternVisitor::VisitDFPattern_(const FunctionPatternNode* op) {
for (auto param : op->params) {
VisitDFPattern(param);
}
VisitDFPattern(op->body);
}

void DFPatternVisitor::VisitDFPattern_(const ShapePatternNode* op) { VisitDFPattern(op->pattern); }

void DFPatternVisitor::VisitDFPattern_(const TupleGetItemPatternNode* op) {
Expand Down
7 changes: 7 additions & 0 deletions src/relay/ir/indexed_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,13 @@ IndexedGraph<DFPattern> CreateIndexedGraph(const DFPattern& pattern) {

void VisitDFPattern_(const ExprPatternNode* op, NodePtr parent) override {}

void VisitDFPattern_(const FunctionPatternNode* op, NodePtr parent) override {
for (auto param : op->params) {
VisitDFPattern(param, graph_.node_map_[GetRef<DFPattern>(op)]);
}
VisitDFPattern(op->body, graph_.node_map_[GetRef<DFPattern>(op)]);
}

void VisitDFPattern_(const ShapePatternNode* op, NodePtr parent) override {
VisitDFPattern(op->pattern, graph_.node_map_[GetRef<DFPattern>(op)]);
}
Expand Down
2 changes: 1 addition & 1 deletion src/relay/transforms/simplify_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class SimplifyReshape {
x_ = WildcardPattern(make_object<WildcardPatternNode>());
auto reshape1 = AltPattern(ExprPattern(reshape_op), ExprPattern(reverse_reshape_op));
auto reshape2 = AltPattern(ExprPattern(reshape_op), ExprPattern(reverse_reshape_op));
pattern_ = CallPattern(reshape1, {CallPattern(reshape2, {x_}, Attrs{}, {})}, Attrs{}, {});
pattern_ = CallPattern(reshape1, {CallPattern(reshape2, {x_})});
}

Expr callback(const Expr& pre, const Expr& post, const Map<DFPattern, Array<Expr>>& node_map) {
Expand Down
Loading