This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit 2328fa9

refactor save_inputs

piiswrong committed Aug 24, 2017
1 parent 4970caf
Showing 6 changed files with 121 additions and 87 deletions.
2 changes: 2 additions & 0 deletions src/c_api/c_api_common.h
@@ -84,6 +84,8 @@ struct MXAPIThreadLocalEntry {
std::vector<const mx_uint*> arg_shape_data, out_shape_data, aux_shape_data;
/*! \brief uint32_t buffer for returning shape pointer */
std::vector<uint32_t> arg_shape_buffer, out_shape_buffer, aux_shape_buffer;
/*! \brief bool buffer */
std::vector<bool> save_inputs, save_outputs;
// helper function to setup return value of shape array
inline static void SetupShapeArrayReturnWithBuffer(
const std::vector<TShape> &shapes,
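
The two vectors added above act as per-thread scratch space: when AutogradRuntime::RecordOp (changed below) is called without caller-supplied masks, it falls back to these buffers instead of allocating fresh vectors for every recorded operator. A minimal sketch of that fallback pattern, using plain thread_local storage instead of MXNet's MXAPIThreadLocalStore; ScratchBuffers, GetScratch and record_op are illustrative names, not part of the MXNet API:

#include <cstddef>
#include <cstdio>
#include <vector>

// Per-thread scratch, analogous to the save_inputs/save_outputs members
// added to MXAPIThreadLocalEntry.
struct ScratchBuffers {
  std::vector<bool> save_inputs;
  std::vector<bool> save_outputs;
};

ScratchBuffers* GetScratch() {
  static thread_local ScratchBuffers buffers;  // one instance per thread
  return &buffers;
}

// Illustrative recorder: when the caller does not supply masks, reuse the
// thread-local scratch space so no per-call allocation is needed.
void record_op(size_t num_inputs, size_t num_outputs,
               std::vector<bool>* p_save_inputs = nullptr,
               std::vector<bool>* p_save_outputs = nullptr) {
  if (p_save_inputs == nullptr) {
    p_save_inputs = &GetScratch()->save_inputs;
    p_save_outputs = &GetScratch()->save_outputs;
  }
  p_save_inputs->assign(num_inputs, false);
  p_save_outputs->assign(num_outputs, false);
  std::printf("recorded op with %zu inputs, %zu outputs\n",
              num_inputs, num_outputs);
}

int main() {
  record_op(2, 1);                       // uses the thread-local scratch
  std::vector<bool> in_mask, out_mask;
  record_op(2, 1, &in_mask, &out_mask);  // uses caller-provided masks
  return 0;
}
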
4 changes: 2 additions & 2 deletions src/c_api/c_api_function.cc
@@ -188,8 +188,8 @@ int MXCustomFunctionRecord(int num_inputs, NDArrayHandle *inputs,
attrs.parsed = params;
// TODO(piiswrong): remove state by using FComputeEx
auto state = OpStatePtr::Create<CustomFunctionParam>(params);
AutogradRuntime::Get()->RecordImperativeOperator(
state, attrs.op, attrs, &ndinputs, &ndoutputs);
AutogradRuntime::Get()->RecordOp(
std::move(attrs), &ndinputs, &ndoutputs, state);

for (size_t i = 0; i < ndoutputs.size(); ++i) {
*reinterpret_cast<NDArray*>(outputs[i]) = ndoutputs[i];
56 changes: 40 additions & 16 deletions src/c_api/c_api_ndarray.cc
@@ -484,9 +484,11 @@ void PushOperator(const OpStatePtr& state,
}

void ImperativeInvokeImpl(const Context& default_ctx,
const nnvm::NodeAttrs& attrs,
nnvm::NodeAttrs&& attrs,
std::vector<NDArray>* p_ndinputs,
std::vector<NDArray>* p_ndoutputs) {
std::vector<NDArray>* p_ndoutputs,
std::vector<bool>* p_save_inputs = nullptr,
std::vector<bool>* p_save_outputs = nullptr) {
static auto& ndfunc = nnvm::Op::GetAttr<FNDArrayFunction>("FNDArrayFunction");
static auto& createop = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
@@ -514,29 +516,32 @@ void ImperativeInvokeImpl(const Context& default_ctx,
FCompute fn = common::GetFCompute<FCompute>(op, "FCompute", ctx);
FComputeEx fn_ex = common::GetFCompute<FComputeEx>(op, "FComputeEx", ctx);
if (fn_ex && stype != kDefaultStorage) {
if (AutogradRuntime::Get()->IsRecording()) {
AutogradRuntime::Get()->RecordImperativeFCompute(op,
attrs, &ndinputs, &ndoutputs);
}
PushFComputeEx(fn_ex, op, attrs, ctx, read_vars, write_vars,
requested, ndinputs, ndoutputs);
} else if (fn) {
if (AutogradRuntime::Get()->IsRecording()) {
AutogradRuntime::Get()->RecordImperativeFCompute(op,
attrs, &ndinputs, &ndoutputs);
AutogradRuntime::Get()->RecordOp(
std::move(attrs), &ndinputs, &ndoutputs, OpStatePtr(),
p_save_inputs, p_save_outputs);
}
} else if (fn) {
PushFCompute(fn, op, attrs, ctx, read_vars, write_vars,
requested, ndinputs, ndoutputs, mutate_idx);
if (AutogradRuntime::Get()->IsRecording()) {
AutogradRuntime::Get()->RecordOp(
std::move(attrs), &ndinputs, &ndoutputs, OpStatePtr(),
p_save_inputs, p_save_outputs);
}
} else if (createop.count(op)) {
auto state =
createop[op](attrs, ctx, ret->arg_shapes, ret->arg_types);
if (AutogradRuntime::Get()->IsRecording()) {
AutogradRuntime::Get()->RecordImperativeOperator(state, op,
attrs, &ndinputs, &ndoutputs);
}
write_vars.push_back(state.get_var());
PushOperator(state, op, attrs, ctx, read_vars, write_vars,
requested, ndinputs, ndoutputs, mutate_idx);
if (AutogradRuntime::Get()->IsRecording()) {
AutogradRuntime::Get()->RecordOp(
std::move(attrs), &ndinputs, &ndoutputs, state,
p_save_inputs, p_save_outputs);
}
} else {
LOG(FATAL)
<< "Operator " << op->name << " is not implemented for "
@@ -569,7 +574,7 @@ int MXImperativeInvoke(AtomicSymbolCreator creator,
SetNDInputsOutputs(op, &ndinputs, &ndoutputs, num_inputs, inputs,
num_outputs, infered_num_outputs, num_visible_outputs, outarray);

ImperativeInvokeImpl(Context::CPU(), attrs, &ndinputs, &ndoutputs);
ImperativeInvokeImpl(Context::CPU(), std::move(attrs), &ndinputs, &ndoutputs);

if (outarray == nullptr) {
ret->ret_handles.clear();
@@ -618,6 +623,20 @@ int MXCreateCachedOp(SymbolHandle handle,
auto vars = sym->ListInputs(nnvm::Symbol::kAll);
CHECK_GE(vars.size(), 1) << "CachedOp must have at least 1 input.";
g->attrs["vars"] = std::make_shared<dmlc::any>(std::move(vars));

const nnvm::IndexedGraph& idx = g->indexed_graph();
std::vector<std::vector<bool> > save_inputs(idx.num_nodes());
std::vector<std::vector<bool> > save_outputs(idx.num_nodes());
for (size_t i = 0; i < idx.num_nodes(); ++i) {
nnvm::NodePtr node = nnvm::Node::Create();
node->attrs = idx[i].source->attrs;
AutogradRuntime::Get()->GetBackwardDependency(
node, idx[i].source->num_inputs(), idx[i].source->num_outputs(),
&save_inputs[i], &save_outputs[i]);
}
g->attrs["save_inputs"] = std::make_shared<dmlc::any>(std::move(save_inputs));
g->attrs["save_outputs"] = std::make_shared<dmlc::any>(std::move(save_outputs));

*out = g;
API_END();
}
@@ -640,7 +659,11 @@ int MXInvokeCachedOp(CachedOpHandle handle,

API_BEGIN();
const std::vector<nnvm::NodePtr>& vars =
g->GetAttr<std::vector<nnvm::NodePtr> >("vars");
g->GetAttr<std::vector<nnvm::NodePtr> >("vars");
std::vector<std::vector<bool> > save_inputs =
g->GetAttr<std::vector<std::vector<bool> > >("save_inputs");
std::vector<std::vector<bool> > save_outputs =
g->GetAttr<std::vector<std::vector<bool> > >("save_outputs");
const nnvm::IndexedGraph& idx = g->indexed_graph();
CHECK_EQ(static_cast<size_t>(num_inputs), vars.size())
<< "Actually number of inputs differs from expected number of inputs";
@@ -661,7 +684,8 @@ int MXInvokeCachedOp(CachedOpHandle handle,
in.emplace_back(buff[idx.entry_id(j)]);
}
std::vector<NDArray> out(node.source->num_outputs());
ImperativeInvokeImpl(default_ctx, node.source->attrs, &in, &out);
ImperativeInvokeImpl(default_ctx, nnvm::NodeAttrs(node.source->attrs), &in, &out,
&save_inputs[i], &save_outputs[i]);

for (size_t j = 0; j < node.source->num_outputs(); ++j) {
buff[idx.entry_id(i, j)] = std::move(out[j]);
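
With this change, MXCreateCachedOp walks the indexed graph once and records, for every node, which inputs and outputs its backward pass will need, storing the masks as the "save_inputs"/"save_outputs" graph attributes; MXInvokeCachedOp then only reads the masks back and forwards &save_inputs[i], &save_outputs[i] to ImperativeInvokeImpl, so the FGradient analysis is not repeated on every invocation. A rough sketch of that create-once, reuse-many-times split, with a stand-in DependencyFn in place of the real GetBackwardDependency (CachedOp, CreateCachedOp and InvokeCachedOp here are illustrative, not the MXNet types):

#include <cstddef>
#include <functional>
#include <iostream>
#include <vector>

// Masks computed once per node when the cached "graph" is created.
struct CachedOp {
  std::vector<std::vector<bool> > save_inputs;   // [node][input index]
  std::vector<std::vector<bool> > save_outputs;  // [node][output index]
};

// Stand-in for AutogradRuntime::GetBackwardDependency.
using DependencyFn = std::function<void(size_t node_id,
                                        std::vector<bool>* save_in,
                                        std::vector<bool>* save_out)>;

CachedOp CreateCachedOp(size_t num_nodes, const DependencyFn& get_dependency) {
  CachedOp op;
  op.save_inputs.resize(num_nodes);
  op.save_outputs.resize(num_nodes);
  for (size_t i = 0; i < num_nodes; ++i) {
    // The expensive backward-dependency analysis runs exactly once per node,
    // at creation time.
    get_dependency(i, &op.save_inputs[i], &op.save_outputs[i]);
  }
  return op;
}

void InvokeCachedOp(const CachedOp& op) {
  for (size_t i = 0; i < op.save_inputs.size(); ++i) {
    // At invoke time the precomputed masks are simply passed along, mirroring
    // the &save_inputs[i], &save_outputs[i] arguments in MXInvokeCachedOp.
    std::cout << "node " << i << ": " << op.save_inputs[i].size()
              << " input flags, " << op.save_outputs[i].size()
              << " output flags\n";
  }
}

int main() {
  DependencyFn dep = [](size_t, std::vector<bool>* in, std::vector<bool>* out) {
    *in = {true, false};  // pretend only the first input is needed for backward
    *out = {false};
  };
  CachedOp op = CreateCachedOp(3, dep);
  InvokeCachedOp(op);
  return 0;
}
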
110 changes: 62 additions & 48 deletions src/ndarray/autograd.cc
@@ -29,6 +29,7 @@
#include <iostream>
#include "../executor/graph_executor.h"
#include "./autograd.h"
#include "../c_api/c_api_common.h"

namespace mxnet {
namespace autograd {
@@ -101,21 +102,6 @@ void AutogradRuntime::MarkVariables(
}
}

void AutogradRuntime::RecordImperativeFCompute(const nnvm::Op* op,
const nnvm::NodeAttrs& attrs,
std::vector<NDArray> *p_inputs,
std::vector<NDArray> *p_outputs) {
RecordOp(op, attrs, p_inputs, p_outputs, OpStatePtr());
}

void AutogradRuntime::RecordImperativeOperator(const OpStatePtr& state,
const nnvm::Op* op,
const nnvm::NodeAttrs& attrs,
std::vector<NDArray> *p_inputs,
std::vector<NDArray> *p_outputs) {
RecordOp(op, attrs, p_inputs, p_outputs, state);
}

std::shared_ptr<AutogradRuntime> AutogradRuntime::_GetSharedRef() {
static std::shared_ptr<AutogradRuntime> inst(new AutogradRuntime());
return inst;
@@ -126,12 +112,58 @@ AutogradRuntime* AutogradRuntime::Get() {
return ptr;
}

void AutogradRuntime::RecordOp(const nnvm::Op* op,
const nnvm::NodeAttrs& attrs,
void AutogradRuntime::GetBackwardDependency(const nnvm::NodePtr& node,
uint32_t num_inputs, uint32_t num_outputs,
std::vector<bool> *p_save_inputs,
std::vector<bool> *p_save_outputs) {
static auto& fgradient = nnvm::Op::GetAttr<nnvm::FGradient>("FGradient");
std::vector<bool>& save_inputs = *p_save_inputs;
std::vector<bool>& save_outputs = *p_save_outputs;
save_inputs.resize(num_inputs);
save_outputs.resize(num_outputs);
std::fill(save_inputs.begin(), save_inputs.end(), false);
std::fill(save_outputs.begin(), save_outputs.end(), false);

node->inputs.clear();
node->inputs.reserve(num_inputs);
for (uint32_t i = 0; i < num_inputs; ++i) {
node->inputs.emplace_back(NodeEntry{nullptr, i, 0});
}

if (fgradient.count(node->op())) {
std::vector<NodeEntry> ograd_entries;
ograd_entries.reserve(num_outputs);
for (uint32_t i = 0; i < num_outputs; ++i) {
ograd_entries.emplace_back(NodeEntry{nullptr, i, 1});
}
auto igrad_entries = fgradient[node->op()](node, ograd_entries);
for (const auto& i : igrad_entries) {
if (i.node == nullptr && i.version == 0) {
save_inputs[i.index] = true;
} else if (i.node == node) {
save_outputs[i.index] = true;
}
}
DFSVisit(igrad_entries, [&](const NodePtr& gnode) {
if (!gnode || gnode == node) return;
for (const auto& i : gnode->inputs) {
if (i.node == nullptr && i.version == 0) {
save_inputs[i.index] = true;
} else if (i.node == node) {
save_outputs[i.index] = true;
}
}
});
}
}

void AutogradRuntime::RecordOp(nnvm::NodeAttrs&& attrs,
std::vector<NDArray> *p_inputs,
std::vector<NDArray> *p_outputs,
const OpStatePtr& state) {
static auto& fgradient = nnvm::Op::GetAttr<nnvm::FGradient>("FGradient");
const OpStatePtr& state,
std::vector<bool>* p_save_inputs,
std::vector<bool>* p_save_outputs) {
MXAPIThreadLocalEntry *local_buff = MXAPIThreadLocalStore::Get();
std::vector<NDArray>& inputs = *p_inputs;
std::vector<NDArray>& outputs = *p_outputs;

@@ -154,39 +186,21 @@ void AutogradRuntime::RecordOp(const nnvm::Op* op,
if (!need_grad) return;

NodePtr nn_node = Node::Create();
nn_node->attrs = attrs;
nn_node->attrs = std::move(attrs);
nn_node->attrs.name = "node_" + std::to_string(node_count_++);

// Get backward dependency
std::vector<bool> save_inputs(inputs.size()), save_outputs(outputs.size());
for (uint32_t i = 0; i < inputs.size(); ++i) {
nn_node->inputs.emplace_back(NodeEntry{nullptr, i, 0});
}
if (fgradient.count(attrs.op)) {
std::vector<NodeEntry> ograd_entries;
for (uint32_t i = 0; i < outputs.size(); ++i) {
ograd_entries.emplace_back(NodeEntry{nullptr, i, 1});
}
auto igrad_entries = fgradient[nn_node->op()](nn_node, ograd_entries);
for (const auto& i : igrad_entries) {
if (i.node == nullptr && i.version == 0) {
save_inputs[i.index] = true;
} else if (i.node == nn_node) {
save_outputs[i.index] = true;
}
}
DFSVisit(igrad_entries, [&](const NodePtr& node) {
if (!node || node == nn_node) return;
for (const auto& i : node->inputs) {
if (i.node == nullptr && i.version == 0) {
save_inputs[i.index] = true;
} else if (i.node == nn_node) {
save_outputs[i.index] = true;
}
}
});
if (p_save_inputs == nullptr) {
p_save_inputs = &(local_buff->save_inputs);
p_save_outputs = &(local_buff->save_outputs);
GetBackwardDependency(
nn_node, inputs.size(), outputs.size(), p_save_inputs, p_save_outputs);
} else {
nn_node->inputs.resize(inputs.size());
}

std::vector<bool>& save_inputs = *p_save_inputs;
std::vector<bool>& save_outputs = *p_save_outputs;

AGNodePtr ag_node = AGNode::Create(nn_node);
ag_node->state = state;

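
The new GetBackwardDependency works by building a dummy node whose inputs are placeholder entries (node == nullptr, version 0) and whose incoming output gradients are placeholders with version 1, running the op's FGradient on it, and scanning the resulting gradient subgraph: a reference back to an input placeholder means that forward input must be saved, and a reference to the dummy node itself means that forward output must be saved. A toy model of the final marking step, independent of nnvm (GradRef, MarkBackwardDependency and the hand-written mul_grad description are made up for illustration):

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

// Toy stand-in for an nnvm::NodeEntry seen inside the gradient subgraph:
// kInput refers to the i-th forward input, kOutput to the i-th forward output.
struct GradRef {
  enum Kind { kInput, kOutput, kOther } kind;
  size_t index;
};

// Scan the gradient's references and mark which forward inputs/outputs it
// needs, mirroring the save_inputs/save_outputs logic in GetBackwardDependency.
void MarkBackwardDependency(const std::vector<GradRef>& grad_refs,
                            std::vector<bool>* save_inputs,
                            std::vector<bool>* save_outputs) {
  std::fill(save_inputs->begin(), save_inputs->end(), false);
  std::fill(save_outputs->begin(), save_outputs->end(), false);
  for (const GradRef& ref : grad_refs) {
    if (ref.kind == GradRef::kInput) {
      (*save_inputs)[ref.index] = true;    // gradient reads this forward input
    } else if (ref.kind == GradRef::kOutput) {
      (*save_outputs)[ref.index] = true;   // gradient reads this forward output
    }
  }
}

int main() {
  // c = a * b: da = ograd * b and db = ograd * a, so both inputs are needed
  // for backward while the output c is not.
  std::vector<GradRef> mul_grad = {{GradRef::kInput, 0}, {GradRef::kInput, 1}};
  std::vector<bool> save_inputs(2), save_outputs(1);
  MarkBackwardDependency(mul_grad, &save_inputs, &save_outputs);
  std::cout << "save a: " << save_inputs[0] << ", save b: " << save_inputs[1]
            << ", save c: " << save_outputs[0] << "\n";
  return 0;
}
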
30 changes: 13 additions & 17 deletions src/ndarray/autograd.h
@@ -95,17 +95,19 @@ class AutogradRuntime {
void MarkVariables(const std::vector<NDArray*>& variables,
const std::vector<mx_uint>& grad_reqs,
const std::vector<NDArray*>& gradients);
/*! \brief record imperative operator which is executed by fcompute. */
void RecordImperativeFCompute(const nnvm::Op* op,
const nnvm::NodeAttrs& attrs,
std::vector<NDArray>* p_inputs,
std::vector<NDArray>* p_outputs);
/*! \brief record imperative operator which is executed by operator. */
void RecordImperativeOperator(const OpStatePtr& state,
const nnvm::Op* op,
const nnvm::NodeAttrs& attrs,
std::vector<NDArray>* p_inputs,
std::vector<NDArray>* p_outputs);
/*! \brief find the input/output ndarrays that are needed for backward */
void GetBackwardDependency(
const nnvm::NodePtr& node,
uint32_t num_inputs, uint32_t num_outputs,
std::vector<bool> *p_save_inputs,
std::vector<bool> *p_save_outputs);
/*! \brief to record operator, return corresponding node. */
void RecordOp(nnvm::NodeAttrs&& attrs,
std::vector<NDArray>* p_inputs,
std::vector<NDArray>* p_outputs,
const OpStatePtr& state = OpStatePtr(),
std::vector<bool>* p_save_inputs = nullptr,
std::vector<bool>* p_save_outputs = nullptr);
/*! \brief compute the gradient of outputs w.r.t variables. */
void ComputeGradient(const std::vector<NDArray>& outputs,
const std::vector<NDArray>& ograds,
@@ -126,12 +128,6 @@
AutogradRuntime();

private:
/*! \brief to record operator, return corresponding node. */
void RecordOp(const nnvm::Op* op,
const nnvm::NodeAttrs& attrs,
std::vector<NDArray>* p_inputs,
std::vector<NDArray>* p_outputs,
const OpStatePtr& state);
/*! \brief AutogradRuntime singleton. */
static AutogradRuntime* instance_;
/*! \brief indicate whether is training. */
6 changes: 2 additions & 4 deletions src/ndarray/ndarray.cc
@@ -75,8 +75,7 @@ NDArray NDArray::Reshape(const TShape &shape) const {
std::vector<NDArray> inputs, outputs;
inputs.emplace_back(*this);
outputs.emplace_back(std::move(ret));
AutogradRuntime::Get()->RecordImperativeFCompute(
op, attrs, &inputs, &outputs);
AutogradRuntime::Get()->RecordOp(std::move(attrs), &inputs, &outputs);
return outputs[0];
} else {
CHECK_GE(shape_.Size(), shape.Size())
@@ -115,8 +114,7 @@ NDArray NDArray::Slice(index_t begin, index_t end) const {
std::vector<NDArray> inputs, outputs;
inputs.emplace_back(*this);
outputs.emplace_back(std::move(ret));
AutogradRuntime::Get()->RecordImperativeFCompute(
op, attrs, &inputs, &outputs);
AutogradRuntime::Get()->RecordOp(std::move(attrs), &inputs, &outputs);
return outputs[0];
} else {
return ret;
