Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PASS] Improve graph fusion #286

Merged
merged 10 commits into from
Aug 1, 2017
Merged

[PASS] Improve graph fusion #286

merged 10 commits into from
Aug 1, 2017

Conversation

ZihengJiang
Copy link
Contributor

@ZihengJiang ZihengJiang commented Jul 30, 2017

  • ref_count should consider graph outputs. If an intermediate tensor is an input of another node also the graph output, it should not be fused.
  • For schedule, use master's schedule instead of group output node's schedule

@@ -218,6 +221,7 @@ nnvm::Graph GraphFuse(nnvm::Graph g) {
nnvm::Op::GetAttr<FTVMCompute>("FTVMCompute");
static auto& fschedule =
nnvm::Op::GetAttr<FTVMSchedule>("FTVMSchedule");
std::unordered_map<uint32_t, std::vector<Operation>> group_ops;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

space between >> for vs compilers

@@ -68,6 +68,7 @@ nnvm::Graph GraphPartition(nnvm::Graph g) {
std::vector<int> master_vec(idx.num_nodes(), -1);
// Operator pattern
static auto& op_pattern = nnvm::Op::GetAttr<TOpPattern>("TOpPattern");
auto same_shape = [&] (uint32_t leid, uint32_t reid) { return shape_vec[leid] == shape_vec[reid]; };
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is too short that maybe we don't need this lambda

case kBroadcast:
if (master_vec[e.node_id] == -1) {
fuse_vec[e.node_id] = FuseRule::kFuseToMaster;
break;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this will result in a bug when we first have input that is ewise, then another that is complex

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry, ignore my previous comment, this seems to be fine, as the input can also be fused, please add comment here to indicate this case

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

always put break outside if block

chosen_master = master_vec[e.node_id];
fuse_vec[e.node_id] = FuseRule::kFuseToMaster;
break;
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This relies on pass through behavior, which is not quite intuitive. Always put break in outside, and add an else case

case kComplex:
if (chosen_master == -1 && same_shape(idx.entry_id(nid, 0), idx.entry_id(e))) {
chosen_master = master_vec[e.node_id];
fuse_vec[e.node_id] = FuseRule::kFuseToMaster;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the pattern is complex, the final output pattern need to be kComplex

case kElemWise:
case kBroadcast:
if (master_vec[e.node_id] == -1) {
fuse_vec[e.node_id] = FuseRule::kFuseToMaster;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need to rely on master flag here>

@ZihengJiang ZihengJiang merged commit 50ddb76 into apache:master Aug 1, 2017
@ZihengJiang ZihengJiang deleted the dev branch August 1, 2017 05:28
tqchen pushed a commit to tqchen/tvm that referenced this pull request May 26, 2018
tqchen pushed a commit that referenced this pull request May 29, 2018
tqchen pushed a commit to tqchen/tvm that referenced this pull request Jul 6, 2018
sergei-mironov pushed a commit to sergei-mironov/tvm that referenced this pull request Aug 8, 2018
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants