Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PASS] Improve graph fusion #286

Merged
merged 10 commits into from
Aug 1, 2017
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 24 additions & 24 deletions apps/graph_executor/src/graph_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ using nnvm::IndexedGraph;
// The single fuse rule.
enum class FuseRule {
kUknown,
kFuse,
kFuseToMaster,
kRealize
};

Expand Down Expand Up @@ -68,6 +68,7 @@ nnvm::Graph GraphPartition(nnvm::Graph g) {
std::vector<int> master_vec(idx.num_nodes(), -1);
// Operator pattern
static auto& op_pattern = nnvm::Op::GetAttr<TOpPattern>("TOpPattern");
auto same_shape = [&] (uint32_t leid, uint32_t reid) { return shape_vec[leid] == shape_vec[reid]; };
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This lambda is so short that we probably don't need it.


for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid];
Expand All @@ -77,43 +78,42 @@ nnvm::Graph GraphPartition(nnvm::Graph g) {
TOpPattern pt = op_pattern.get(inode.source->op(), kExtern);

if (pt <= kBroadcast) {
// multiple complex inputs, fuse current node into one of them;
for (const auto& e : inode.inputs) {
if (master_vec[e.node_id] != -1) {
master_vec[nid] = master_vec[e.node_id];
break;
}
}
int chosen_master = -1;
bool ewise = inode.source->num_outputs() == 1;
for (const auto& e : inode.inputs) {
if (fuse_vec[e.node_id] == FuseRule::kUknown) {
if (master_vec[e.node_id] != -1 &&
master_vec[e.node_id] != master_vec[nid]) {
ewise = false;
fuse_vec[e.node_id] = FuseRule::kRealize;
} else {
TOpPattern ipt = pattern_vec[e.node_id];
if (ipt != kElemWise) {
ewise = false;
}
if (ipt != kExtern) {
fuse_vec[e.node_id] = FuseRule::kFuse;
}
TOpPattern ipt = pattern_vec[e.node_id];
if (ipt != kElemWise) ewise = false;
switch (ipt) {
case kElemWise:
case kBroadcast:
if (master_vec[e.node_id] == -1) {
fuse_vec[e.node_id] = FuseRule::kFuseToMaster;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to rely on the master flag here?

break;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will result in a bug when we first have an input that is ewise, and then another that is complex.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, ignore my previous comment. This seems to be fine, as the input can also be fused. Please add a comment here to indicate this case.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Always put the break outside the if block.

}
case kComplex:
if (chosen_master == -1 && same_shape(idx.entry_id(nid, 0), idx.entry_id(e))) {
chosen_master = master_vec[e.node_id];
fuse_vec[e.node_id] = FuseRule::kFuseToMaster;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the pattern is complex, the final output pattern needs to be kComplex.

break;
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This relies on fall-through behavior, which is not very intuitive. Always put the break outside the if block, and add an else case.

default:
fuse_vec[e.node_id] = FuseRule::kRealize;
}
}
if (ewise) {
TShape oshape = shape_vec[idx.entry_id(nid, 0)];
if (oshape != shape_vec[idx.entry_id(e)]) ewise = false;
if (!same_shape(idx.entry_id(nid, 0), idx.entry_id(e))) ewise = false;
}
}
pt = ewise ? kElemWise : kBroadcast;
master_vec[nid] = chosen_master;
} else {
master_vec[nid] = nid;
for (const auto& e : inode.inputs) {
if (fuse_vec[e.node_id] == FuseRule::kUknown) {
fuse_vec[e.node_id] = FuseRule::kRealize;
if (master_vec[e.node_id] == -1) {
master_vec[e.node_id] = nid;
master_vec[e.node_id] = e.node_id;
}
}
}
Expand All @@ -139,7 +139,7 @@ nnvm::Graph GraphPartition(nnvm::Graph g) {
}
// propagate the group id.
for (const auto& e : inode.inputs) {
if (fuse_vec[e.node_id] == FuseRule::kFuse) {
if (fuse_vec[e.node_id] == FuseRule::kFuseToMaster) {
CHECK(group_vec[e.node_id] == -1||
group_vec[e.node_id] == group_vec[nid]);
group_vec[e.node_id] = group_vec[nid];
Expand Down