-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[PASS] Improve graph fusion #286
Changes from 1 commit
d9aad57
02cb6b3
eae2217
b121e6e
26f03d9
1d6de5a
a64f1f6
251f583
0c2f51a
e9aef2a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,7 +21,7 @@ using nnvm::IndexedGraph; | |
// The single fuse rule. | ||
enum class FuseRule { | ||
kUknown, | ||
kFuse, | ||
kFuseToMaster, | ||
kRealize | ||
}; | ||
|
||
|
@@ -68,6 +68,7 @@ nnvm::Graph GraphPartition(nnvm::Graph g) { | |
std::vector<int> master_vec(idx.num_nodes(), -1); | ||
// Operator pattern | ||
static auto& op_pattern = nnvm::Op::GetAttr<TOpPattern>("TOpPattern"); | ||
auto same_shape = [&] (uint32_t leid, uint32_t reid) { return shape_vec[leid] == shape_vec[reid]; }; | ||
|
||
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { | ||
const auto& inode = idx[nid]; | ||
|
@@ -77,43 +78,42 @@ nnvm::Graph GraphPartition(nnvm::Graph g) { | |
TOpPattern pt = op_pattern.get(inode.source->op(), kExtern); | ||
|
||
if (pt <= kBroadcast) { | ||
// multiple complex inputs, fuse current node into one of them; | ||
for (const auto& e : inode.inputs) { | ||
if (master_vec[e.node_id] != -1) { | ||
master_vec[nid] = master_vec[e.node_id]; | ||
break; | ||
} | ||
} | ||
int chosen_master = -1; | ||
bool ewise = inode.source->num_outputs() == 1; | ||
for (const auto& e : inode.inputs) { | ||
if (fuse_vec[e.node_id] == FuseRule::kUknown) { | ||
if (master_vec[e.node_id] != -1 && | ||
master_vec[e.node_id] != master_vec[nid]) { | ||
ewise = false; | ||
fuse_vec[e.node_id] = FuseRule::kRealize; | ||
} else { | ||
TOpPattern ipt = pattern_vec[e.node_id]; | ||
if (ipt != kElemWise) { | ||
ewise = false; | ||
} | ||
if (ipt != kExtern) { | ||
fuse_vec[e.node_id] = FuseRule::kFuse; | ||
} | ||
TOpPattern ipt = pattern_vec[e.node_id]; | ||
if (ipt != kElemWise) ewise = false; | ||
switch (ipt) { | ||
case kElemWise: | ||
case kBroadcast: | ||
if (master_vec[e.node_id] == -1) { | ||
fuse_vec[e.node_id] = FuseRule::kFuseToMaster; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do we need to rely on master flag here> |
||
break; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this will result in a bug when we first have input that is ewise, then another that is complex There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sorry, ignore my previous comment, this seems to be fine, as the input can also be fused, please add comment here to indicate this case There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. always put break outside if block |
||
} | ||
case kComplex: | ||
if (chosen_master == -1 && same_shape(idx.entry_id(nid, 0), idx.entry_id(e))) { | ||
chosen_master = master_vec[e.node_id]; | ||
fuse_vec[e.node_id] = FuseRule::kFuseToMaster; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If the pattern is complex, the final output pattern need to be kComplex |
||
break; | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This relies on pass through behavior, which is not quite intuitive. Always put break in outside, and add an else case |
||
default: | ||
fuse_vec[e.node_id] = FuseRule::kRealize; | ||
} | ||
} | ||
if (ewise) { | ||
TShape oshape = shape_vec[idx.entry_id(nid, 0)]; | ||
if (oshape != shape_vec[idx.entry_id(e)]) ewise = false; | ||
if (!same_shape(idx.entry_id(nid, 0), idx.entry_id(e))) ewise = false; | ||
} | ||
} | ||
pt = ewise ? kElemWise : kBroadcast; | ||
master_vec[nid] = chosen_master; | ||
} else { | ||
master_vec[nid] = nid; | ||
for (const auto& e : inode.inputs) { | ||
if (fuse_vec[e.node_id] == FuseRule::kUknown) { | ||
fuse_vec[e.node_id] = FuseRule::kRealize; | ||
if (master_vec[e.node_id] == -1) { | ||
master_vec[e.node_id] = nid; | ||
master_vec[e.node_id] = e.node_id; | ||
} | ||
} | ||
} | ||
|
@@ -139,7 +139,7 @@ nnvm::Graph GraphPartition(nnvm::Graph g) { | |
} | ||
// propagate the group id. | ||
for (const auto& e : inode.inputs) { | ||
if (fuse_vec[e.node_id] == FuseRule::kFuse) { | ||
if (fuse_vec[e.node_id] == FuseRule::kFuseToMaster) { | ||
CHECK(group_vec[e.node_id] == -1|| | ||
group_vec[e.node_id] == group_vec[nid]); | ||
group_vec[e.node_id] = group_vec[nid]; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is too short that maybe we don't need this lambda