Skip to content

Commit

Permalink
[PASS] Improve graph fusion (#286)
Browse files Browse the repository at this point in the history
* [PASS] Improve graph fusion

* Change fusion center to segment head

* Use 'master' to identity the schedule node

* Make things compact

* Fix
  • Loading branch information
ZihengJiang authored Aug 1, 2017
1 parent 7e82eb6 commit 50ddb76
Showing 1 changed file with 51 additions and 18 deletions.
69 changes: 51 additions & 18 deletions apps/graph_executor/src/graph_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ using nnvm::IndexedGraph;
// The single fuse rule.
enum class FuseRule {
kUknown,
kFuseToParent,
kFuseToMaster,
kRealize
};

Expand Down Expand Up @@ -57,10 +57,16 @@ nnvm::Graph GraphPartition(nnvm::Graph g) {
++ref_count[e.node_id];
}
}
for (const auto& e : idx.outputs()) {
// this line will realize all the outputs
ref_count[e.node_id] += 2;
}
// Pattern fo the subgraph
std::vector<TOpPattern> pattern_vec(idx.num_nodes(), kExtern);
// Whether node can be fused to parent.
std::vector<FuseRule> fuse_vec(idx.num_nodes(), FuseRule::kUknown);
// Master node id of fusion segment.
std::vector<int> master_vec(idx.num_nodes(), -1);
// Operator pattern
static auto& op_pattern = nnvm::Op::GetAttr<TOpPattern>("TOpPattern");

Expand All @@ -70,38 +76,58 @@ nnvm::Graph GraphPartition(nnvm::Graph g) {
fuse_vec[nid] = FuseRule::kRealize; continue;
}
TOpPattern pt = op_pattern.get(inode.source->op(), kExtern);

if (pt <= kBroadcast) {
// Looking for fusable bcast pattern
int chosen_master = -1;
bool ewise = inode.source->num_outputs() == 1;
for (const auto& e : inode.inputs) {
if (fuse_vec[e.node_id] == FuseRule::kUknown) {
if (pattern_vec[e.node_id] == kBroadcast) {
ewise = false;
fuse_vec[e.node_id] = FuseRule::kFuseToParent;
} else if (pattern_vec[e.node_id] == kElemWise) {
fuse_vec[e.node_id] = FuseRule::kFuseToParent;
TOpPattern ipt = pattern_vec[e.node_id];
if (ipt != kElemWise) ewise = false;
if (ipt <= kBroadcast) {
fuse_vec[e.node_id] = FuseRule::kFuseToMaster;
} else if (ipt == kComplex && chosen_master == -1 &&
shape_vec[idx.entry_id(nid, 0)] == shape_vec[idx.entry_id(e)]) {
chosen_master = master_vec[e.node_id];
fuse_vec[e.node_id] = FuseRule::kFuseToMaster;
} else {
fuse_vec[e.node_id] = FuseRule::kRealize;
}
}
if (ewise) {
TShape oshape = shape_vec[idx.entry_id(nid, 0)];
if (oshape != shape_vec[idx.entry_id(e)]) ewise = false;
if (shape_vec[idx.entry_id(nid, 0)] != shape_vec[idx.entry_id(e)]) {
ewise = false;
}
}
}
pt = ewise ? kElemWise : kBroadcast;
} else if (pt == kComplex) {
master_vec[nid] = chosen_master;
if (chosen_master != -1) {
pt = kComplex;
} else {
pt = ewise ? kElemWise : kBroadcast;
}
} else {
master_vec[nid] = nid;
for (const auto& e : inode.inputs) {
if (fuse_vec[e.node_id] == FuseRule::kUknown) {
if (pattern_vec[e.node_id] <= kBroadcast) {
fuse_vec[e.node_id] = FuseRule::kFuseToParent;
fuse_vec[e.node_id] = FuseRule::kRealize;
if (master_vec[e.node_id] == -1) {
master_vec[e.node_id] = e.node_id;
}
}
}
}

pattern_vec[nid] = pt;
if (ref_count[nid] > 1) {
fuse_vec[nid] = FuseRule::kRealize;
if (master_vec[nid] == -1) {
master_vec[nid] = nid;
}
}
}


// point to the group root id of each node
std::vector<int> group_vec(idx.num_nodes(), -1);
for (uint32_t i = idx.num_nodes(); i != 0; --i) {
Expand All @@ -112,14 +138,15 @@ nnvm::Graph GraphPartition(nnvm::Graph g) {
}
// propagate the group id.
for (const auto& e : inode.inputs) {
if (fuse_vec[e.node_id] == FuseRule::kFuseToParent) {
if (fuse_vec[e.node_id] == FuseRule::kFuseToMaster) {
CHECK(group_vec[e.node_id] == -1||
group_vec[e.node_id] == group_vec[nid]);
group_vec[e.node_id] = group_vec[nid];
}
}
}
g.attrs["group_root"] = std::make_shared<any>(std::move(group_vec));
g.attrs["group_master"] = std::make_shared<any>(std::move(master_vec));
g.attrs["pattern"] = std::make_shared<any>(std::move(pattern_vec));
g.attrs["dltype"] = std::make_shared<any>(std::move(dltype_vec));
return g;
Expand Down Expand Up @@ -172,6 +199,7 @@ nnvm::Graph GraphFuse(nnvm::Graph g) {
const DLTypeVector& dltype_vec = g.GetAttr<DLTypeVector>("dltype");
const DTypeVector& dtype_vec = g.GetAttr<DTypeVector>("dtype");
const std::vector<int>& group_vec = g.GetAttr<std::vector<int> >("group_root");
const std::vector<int>& master_vec = g.GetAttr<std::vector<int> >("group_master");
const std::vector<TOpPattern>& pattern_vec =
g.GetAttr<std::vector<TOpPattern> >("pattern");
std::string target = g.GetAttr<std::string>("target");
Expand Down Expand Up @@ -239,15 +267,18 @@ nnvm::Graph GraphFuse(nnvm::Graph g) {
Array<Tensor> out = fcompute[inode.source->op()](
inode.source->attrs, inputs);
CHECK_EQ(out.size(), inode.source->num_outputs());

// schedule on root node, and use master's schedule
if (nid != root_id) {
for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) {
uint32_t eid = idx.entry_id(nid, index);
tensor_vec[eid] = out[index];
}
} else {
// Work on schedule
fe.outputs = out;
fe.schedule = fschedule[inode.source->op()](
int master = master_vec[root_id];
CHECK_GE(master, 0);
fe.schedule = fschedule[idx[master].source->op()](
inode.source->attrs, fe.outputs, target);
std::ostringstream os;
os << inode.source->attrs.name + "_id" << nid;
Expand Down Expand Up @@ -307,10 +338,12 @@ nnvm::Graph GraphFuse(nnvm::Graph g) {
old_new[nid] = np;
}
}

nnvm::Graph ret;
for (const auto& e : idx.outputs()) {
auto it = old_new.find(e.node_id);
CHECK(it != old_new.end());
auto it = old_new.find(group_vec[e.node_id]);
CHECK(it != old_new.end())
<< "cannot find node_id=" << e.node_id;
ret.outputs.emplace_back(
nnvm::NodeEntry{it->second, e.index, e.version});
}
Expand Down

0 comments on commit 50ddb76

Please sign in to comment.