Skip to content

Commit

Permalink
[EXEC] Enable inplace sum optimization (apache#3470)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored and piiswrong committed Dec 29, 2016
1 parent f9a672f commit 2ed210b
Show file tree
Hide file tree
Showing 10 changed files with 179 additions and 30 deletions.
18 changes: 18 additions & 0 deletions src/executor/exec_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,24 @@ Graph AttachOpExecs(Graph g);
*/
Graph AttachOpResources(Graph g);

/*!
* \brief Discover chance of inplace addto operators.
* i.e. z = plus(z, source_op), and encourage it to become z += source_op.
*
* This optimization is coupled with executor. This is helpful to reduce memory
* and computation for gradient aggregation of RNN.
*
* Require storage placement to be already finished.
*
* \param g input graph need to contain op_exec attribute.
*
* \return graph two new attributes, changes attribute "storage_id".
* - "addto_entry", std::vector<bool> size=g.num_node_entries()
* - addto_entry[eid] == 1, the corresponding op need to be performed using req=kAddTo
* - "skip_plus_node", std::vector<int> if set to 1, current op's execution is skiped.
*/
Graph DetectInplaceAddTo(Graph g);

} // namespace exec
} // namespace mxnet

Expand Down
60 changes: 49 additions & 11 deletions src/executor/graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,23 +70,48 @@ const std::vector<NDArray>& GraphExecutor::outputs() const {

nnvm::NodeEntry AggregateGradient(std::vector<nnvm::NodeEntry>&& v) {
using nnvm::Op;
using nnvm::Node;
static size_t inplace_sum_cap = dmlc::GetEnv("MXNET_EXEC_INPLACE_GRAD_SUM_CAP", 8);
static const Op* ewise_plus_op = Op::Get("_ewise_plus");
static const Op* ewise_sum_op = Op::Get("ElementWiseSum");
static const Op* identity_op = Op::Get("identity");

if (v.size() == 1) {
return std::move(v[0]);
} else if (v.size() == 0) {
// TODO(tqchen) should be zero node
nnvm::NodePtr ng = Node::Create();
nnvm::NodePtr ng = nnvm::Node::Create();
ng->attrs.op = Op::Get("_NoGradient");
ng->attrs.name = "NoGradient";
return nnvm::NodeEntry{ng, 0, 0};
} else {
nnvm::NodePtr sum_node = Node::Create();
sum_node->attrs.op = Op::Get("ElementWiseSum");
sum_node->attrs.name = "sum_grad";
sum_node->attrs.dict["num_args"] = std::to_string(v.size());
sum_node->attrs.op->attr_parser(&(sum_node->attrs));
sum_node->inputs = std::move(v);
return nnvm::NodeEntry{sum_node, 0, 0};
if (v.size() < inplace_sum_cap) {
nnvm::NodePtr sum_node = nnvm::Node::Create();
sum_node->attrs.op = ewise_sum_op;
sum_node->attrs.name = "sum_grad";
sum_node->attrs.dict["num_args"] = std::to_string(v.size());
sum_node->attrs.op->attr_parser(&(sum_node->attrs));
sum_node->inputs = std::move(v);
return nnvm::NodeEntry{sum_node, 0, 0};
} else {
// use a stream line of plus instead
nnvm::NodeEntry ret = v[0];
for (size_t i = 1; i < v.size(); ++i) {
std::ostringstream os;
os << "sum_grad_" << i;
nnvm::NodePtr x = nnvm::Node::Create();
x->attrs.op = ewise_plus_op;
x->attrs.name = os.str();
x->inputs = {ret, v[i]};
ret = nnvm::NodeEntry{x, 0, 0};
}
// identity node is used to avoid exposure of dummy plus node
// when its output get assigned to another space.
nnvm::NodePtr id_node = nnvm::Node::Create();
id_node->attrs.op = identity_op;
id_node->attrs.name = "sum_grad_final";
id_node->inputs = {ret};
return nnvm::NodeEntry{id_node, 0, 0};
}
}
}

Expand Down Expand Up @@ -331,6 +356,7 @@ Graph GraphExecutor::InitGraph(nnvm::Symbol symbol,
g = nnvm::pass::InferShape(g, arg_shapes, "__shape__");
g = nnvm::pass::InferType(g, arg_types);
g = nnvm::ApplyPass(g, "PlanMemory");
g = DetectInplaceAddTo(g);
return g;
}

Expand Down Expand Up @@ -438,12 +464,18 @@ void GraphExecutor::InitCachedOps() {
const auto& op_execs =
graph_.GetAttr<OpExecVector>("op_execs");
const auto& vctx = graph_.GetAttr<ContextVector>("context");
const auto& addto_entry = graph_.GetAttr<std::vector<int> >("addto_entry");
const auto& skip_plus_node = graph_.GetAttr<std::vector<int> >("skip_plus_node");

op_nodes_.resize(idx.num_nodes());
// setup the array and requirements.
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid];
if (inode.source->is_variable()) continue;
if (skip_plus_node.at(nid)) {
op_nodes_[nid].skip_exec_node = true; continue;
}

op_nodes_[nid].exec = op_execs[nid];
op_nodes_[nid].ctx = vctx[nid];
auto& exec = op_nodes_[nid].exec;
Expand All @@ -456,7 +488,10 @@ void GraphExecutor::InitCachedOps() {
for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) {
uint32_t eid = idx.entry_id(nid, index);
exec->out_array.push_back(data_entry_[eid]);
if (vstorage_inplace[eid] >= 0) {
if (addto_entry.at(eid) != 0) {
exec->req.push_back(kAddTo);

} else if (vstorage_inplace[eid] >= 0) {
exec->req.push_back(kWriteInplace);
} else if (vstorage_inplace[eid] == -2) {
// -2 indicate that the entry is never referenced.
Expand All @@ -466,20 +501,22 @@ void GraphExecutor::InitCachedOps() {
}
}
}
// Note that this modifies the requirment of kWriteInplace
for (size_t j = num_forward_outputs_; j < idx.outputs().size(); ++j) {
auto& e = idx.outputs()[j];
op_nodes_[e.node_id].exec->req[e.index] =
grad_store_[j - num_forward_outputs_].first;
}

for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid];
if (inode.source->is_variable()) continue;
if (op_nodes_[nid].skip_exec_node) continue;
auto& exec = op_nodes_[nid].exec;

std::vector<uint32_t> inplace_inputs;
for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) {
uint32_t eid = idx.entry_id(nid, index);
// must check exec->req because, vstorage_inplace is only a hint.
if (vstorage_inplace[eid] >= 0 && exec->req.at(index) == kWriteInplace) {
inplace_inputs.push_back(vstorage_inplace[eid]);
}
Expand Down Expand Up @@ -549,6 +586,7 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) {
for (size_t nid = topo_start; nid < topo_end; ++nid) {
const auto& inode = idx[nid];
if (inode.source->is_variable()) continue;
if (op_nodes_[nid].skip_exec_node) continue;
OpNode& opnode = op_nodes_[nid];
opnode.exec->op_ctx.is_train = is_train;
if (opnode.exec->exec_type() == Operator::kCrossDeviceCopy) {
Expand Down
2 changes: 2 additions & 0 deletions src/executor/graph_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ class GraphExecutor : public Executor {
Context ctx;
// The executor
std::shared_ptr<OpExecutor> exec;
// skip the execution of this node
bool skip_exec_node{false};
// cached operator handle
Engine::OprHandle cached_opr{nullptr};
};
Expand Down
63 changes: 63 additions & 0 deletions src/executor/inplace_addto_detect_pass.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*!
* Copyright (c) 2016 by Contributors
* \file inplace_addto_detect_pass.cc
* \brief Detect whether inplace addto operation is possible for certain op.
*/
#include <mxnet/base.h>
#include <mxnet/operator.h>
#include <mxnet/op_attr_types.h>
#include <nnvm/graph_attr_types.h>

#include "./exec_pass.h"

namespace mxnet {
namespace exec {

Graph DetectInplaceAddTo(Graph g) {
nnvm::StorageVector storage_id =
g.MoveCopyAttr<nnvm::StorageVector>("storage_id");
std::vector<int> storage_inplace_index =
g.MoveCopyAttr<std::vector<int> >("storage_inplace_index");
static const Op* ewise_plus_op = Op::Get("_ewise_plus");
auto& idx = g.indexed_graph();
// reference cont.
std::vector<int> ref_count(idx.num_node_entries(), 0);
std::vector<int> addto_entry(idx.num_node_entries(), 0);
std::vector<int> skip_plus_node(idx.num_nodes(), 0);

for (auto& e : idx.outputs()) {
++ref_count[idx.entry_id(e)];
}
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
for (auto &e : idx[nid].inputs) {
++ref_count[idx.entry_id(e)];
}
}

for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid];
if (inode.source->op() != ewise_plus_op) continue;
int sid = storage_id[idx.entry_id(inode.inputs[0])];
if (sid != storage_id[idx.entry_id(nid, 0)]) continue;
if (idx[inode.inputs[0].node_id].source->is_variable()) continue;
if (idx[inode.inputs[1].node_id].source->is_variable()) continue;
uint32_t eid_rhs = idx.entry_id(inode.inputs[1]);
if (ref_count[eid_rhs] != 1) continue;
if (inode.inputs[0].node_id >= inode.inputs[1].node_id) continue;
CHECK_NE(storage_id[eid_rhs], sid);
storage_id[eid_rhs] = sid;
addto_entry[eid_rhs] = 1;
storage_inplace_index[eid_rhs] = -1;
skip_plus_node[nid] = 1;
}

g.attrs["storage_id"] = std::make_shared<nnvm::any>(std::move(storage_id));
g.attrs["storage_inplace_index"] = std::make_shared<nnvm::any>(
std::move(storage_inplace_index));
g.attrs["addto_entry"] = std::make_shared<nnvm::any>(std::move(addto_entry));
g.attrs["skip_plus_node"] = std::make_shared<nnvm::any>(std::move(skip_plus_node));
return g;
}

} // namespace exec
} // namespace mxnet
29 changes: 15 additions & 14 deletions src/operator/cudnn_convolution-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class CuDNNConvolutionOp : public Operator {
}
Tensor<gpu, 1, DType> workspace =
ctx.requested[conv::kTempSpace].get_space_typed<gpu, 1, DType>(
mshadow::Shape1(forward_workspace_), s);
mshadow::Shape1(forward_workspace_), s);

if (param_.kernel.ndim() == 2) {
Tensor<gpu, 4, DType> data = in_data[conv::kData].get<gpu, 4, DType>(s);
Expand All @@ -98,6 +98,7 @@ class CuDNNConvolutionOp : public Operator {
for (uint32_t g = 0; g < param_.num_group; ++g) {
typename DataType<DType>::ScaleType alpha = 1.0f;
typename DataType<DType>::ScaleType beta = 0.0f;
typename DataType<DType>::ScaleType beta_add = 1.0f;
CHECK_EQ(cudnnConvolutionForward(s->dnn_handle_,
&alpha,
in_desc_,
Expand All @@ -108,28 +109,27 @@ class CuDNNConvolutionOp : public Operator {
algo_,
workspace.dptr_,
forward_workspace_byte_,
&beta,
req[conv::kOut] == kAddTo? &beta_add : &beta,
out_desc_,
out_ptr + out_offset_ * g), CUDNN_STATUS_SUCCESS);
if (!param_.no_bias) {
beta = 1.0f;
Tensor<gpu, 1, DType> bias = in_data[conv::kBias].get<gpu, 1, DType>(s);
#if CUDNN_MAJOR >= 4
CHECK_EQ(cudnnAddTensor(s->dnn_handle_,
&alpha,
bias_desc_,
bias.dptr_ + bias_offset_ * g,
&beta,
out_desc_,
out_ptr + out_offset_ * g), CUDNN_STATUS_SUCCESS);
&alpha,
bias_desc_,
bias.dptr_ + bias_offset_ * g,
&beta_add,
out_desc_,
out_ptr + out_offset_ * g), CUDNN_STATUS_SUCCESS);
#endif
#if CUDNN_MAJOR == 3
CHECK_EQ(cudnnAddTensor(s->dnn_handle_,
CUDNN_ADD_SAME_C,
&alpha,
bias_desc_,
bias.dptr_ + bias_offset_ * g,
&beta,
&beta_add,
out_desc_,
out_ptr + out_offset_ * g), CUDNN_STATUS_SUCCESS);
#endif
Expand Down Expand Up @@ -191,7 +191,7 @@ class CuDNNConvolutionOp : public Operator {
&alpha,
out_desc_,
grad_ptr + out_offset_ * g,
req[conv::kBias] == kWriteTo ? &beta : &beta_add,
req[conv::kBias] == kAddTo ? &beta_add : &beta,
bias_desc_,
gbias.dptr_ + bias_offset_ * g),
CUDNN_STATUS_SUCCESS);
Expand All @@ -207,7 +207,7 @@ class CuDNNConvolutionOp : public Operator {
back_algo_w_,
workspace.dptr_,
backward_workspace_byte_,
req[conv::kWeight] == kWriteTo? &beta : &beta_add,
req[conv::kWeight] == kAddTo? &beta_add : &beta,
filter_desc_,
gwmat_ptr + weight_offset_ * g), CUDNN_STATUS_SUCCESS);
#elif CUDNN_MAJOR == 5
Expand All @@ -221,7 +221,7 @@ class CuDNNConvolutionOp : public Operator {
back_algo_w_,
workspace.dptr_,
backward_workspace_byte_,
req[conv::kWeight] == kWriteTo? &beta : &beta_add,
req[conv::kWeight] == kAddTo? &beta_add : &beta,
filter_desc_,
gwmat_ptr + weight_offset_ * g));
#endif
Expand All @@ -236,6 +236,7 @@ class CuDNNConvolutionOp : public Operator {
back_algo_,
workspace.dptr_,
backward_workspace_byte_,
req[conv::kData] == kAddTo? &beta_add : &beta,
&beta,
in_desc_,
gdata_ptr + data_offset_ * g), CUDNN_STATUS_SUCCESS);
Expand All @@ -250,7 +251,7 @@ class CuDNNConvolutionOp : public Operator {
back_algo_,
workspace.dptr_,
backward_workspace_byte_,
&beta,
req[conv::kData] == kAddTo? &beta_add : &beta,
in_desc_,
gdata_ptr + data_offset_ * g), CUDNN_STATUS_SUCCESS);
#endif
Expand Down
4 changes: 4 additions & 0 deletions src/operator/tensor/elemwise_binary_broadcast_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_plus)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow::op::plus>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_plus"});

// specialized to elementwise plus, currently only used for gradient aggregation
MXNET_OPERATOR_REGISTER_BINARY(_ewise_plus)
.set_attr<FCompute>("FCompute<cpu>", BinaryCompute<cpu, mshadow::op::plus>);

NNVM_REGISTER_OP(_backward_plus)
.set_num_inputs(1)
.set_num_outputs(2)
Expand Down
8 changes: 6 additions & 2 deletions src/operator/tensor/elemwise_binary_broadcast_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,12 @@ NNVM_REGISTER_OP(_plus)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow::op::plus>);

NNVM_REGISTER_OP(_backward_plus)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastBackwardUseNone<gpu, mshadow_op::identity,
mshadow_op::identity>);
.set_attr<FCompute>("FCompute<gpu>",
BinaryBroadcastBackwardUseNone<gpu,
mshadow_op::identity, mshadow_op::identity>);

NNVM_REGISTER_OP(_ewise_plus)
.set_attr<FCompute>("FCompute<gpu>", BinaryCompute<gpu, mshadow::op::plus>);

NNVM_REGISTER_OP(_minus)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow::op::minus>);
Expand Down
6 changes: 4 additions & 2 deletions src/operator/tensor/elemwise_unary_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@

namespace mxnet {
namespace op {

// copy
MXNET_OPERATOR_REGISTER_UNARY(_copy)
.MXNET_DESCRIBE("Copy src to output")
.set_attr<FCompute>("FCompute<cpu>", UnaryCompute<cpu, mshadow_op::identity>)
.MXNET_DESCRIBE("Identity mapping, copy src to output")
.add_alias("identity")
.set_attr<FCompute>("FCompute<cpu>", IdentityCompute<cpu>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_copy"});

// negative
Expand Down
2 changes: 1 addition & 1 deletion src/operator/tensor/elemwise_unary_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace mxnet {
namespace op {
// copy
NNVM_REGISTER_OP(_copy)
.set_attr<FCompute>("FCompute<gpu>", UnaryCompute<gpu, mshadow_op::identity>);
.set_attr<FCompute>("FCompute<gpu>", IdentityCompute<gpu>);

// negative
NNVM_REGISTER_OP(negative)
Expand Down
17 changes: 17 additions & 0 deletions src/operator/tensor/elemwise_unary_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,23 @@ void UnaryCompute(const nnvm::NodeAttrs& attrs,
});
}


template<typename xpu>
void IdentityCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mshadow::expr;
Stream<xpu> *s = ctx.get_stream<xpu>();
if (req[0] == kNullOp || req[0] == kWriteInplace) return;
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
Tensor<xpu, 1, DType> out = outputs[0].FlatTo1D<xpu, DType>(s);
ASSIGN_DISPATCH(out, req[0], F<mshadow_op::identity>(inputs[0].FlatTo1D<xpu, DType>(s)));
});
}

#define MXNET_OPERATOR_REGISTER_UNARY(name) \
NNVM_REGISTER_OP(name) \
.set_num_inputs(1) \
Expand Down

0 comments on commit 2ed210b

Please sign in to comment.