Autograd with multiple devices (#8124)
* add cross device autograd

* place device
piiswrong committed Oct 2, 2017
1 parent 44d72e7 commit 1637011
Showing 13 changed files with 298 additions and 155 deletions.
78 changes: 39 additions & 39 deletions include/mxnet/imperative.h
@@ -36,6 +36,45 @@ namespace mxnet {
/*! \brief runtime functions for NDArray */
class Imperative {
public:
/*! \brief per-node bookkeeping recorded for autograd (device, gradient request, saved outputs) */
class AGInfo {
public:
Context ctx;
OpReqType grad_req;
OpStatePtr state;
std::vector<NDArray> outputs;
std::vector<NDArray> out_grads;
bool fresh_out_grad;

AGInfo() :
grad_req(kNullOp), fresh_out_grad(false) {}

static void Clear(const nnvm::NodePtr& node) {
if (node == nullptr || node->info.empty()) return;
AGInfo& info = Get(node);
if (info.grad_req != kNullOp) return;
node->info.clear();
}

static AGInfo& Get(const nnvm::NodePtr& node) {
return dmlc::get<AGInfo>(node->info);
}

static AGInfo& Create(const nnvm::NodePtr& node) {
node->info.construct<AGInfo>();
return Get(node);
}

static bool IsNone(const NDArray& arr) {
return arr.entry_.node == nullptr || arr.entry_.node->info.empty();
}

static bool IsVariable(const nnvm::NodePtr& node) {
AGInfo& info = Get(node);
return info.grad_req != kNullOp && info.outputs.size() == 1
&& info.out_grads.size() == 1;
}
};
class CachedOp {
public:
explicit CachedOp(const nnvm::Symbol& sym);
@@ -141,44 +180,6 @@ class Imperative {

private:
friend class NDArray;
/*! \brief */
class AGInfo {
public:
OpReqType grad_req;
OpStatePtr state;
std::vector<NDArray> outputs;
std::vector<NDArray> out_grads;
bool fresh_out_grad;

AGInfo() :
grad_req(kNullOp), fresh_out_grad(false) {}

static void Clear(const nnvm::NodePtr& node) {
if (node == nullptr || node->info.empty()) return;
AGInfo& info = Get(node);
if (info.grad_req != kNullOp) return;
node->info.clear();
}

static AGInfo& Get(const nnvm::NodePtr& node) {
return dmlc::get<AGInfo>(node->info);
}

static AGInfo& Create(const nnvm::NodePtr& node) {
node->info.construct<AGInfo>();
return Get(node);
}

static bool IsNone(const NDArray& arr) {
return arr.entry_.node == nullptr || arr.entry_.node->info.empty();
}

static bool IsVariable(const nnvm::NodePtr& node) {
AGInfo& info = Get(node);
return info.grad_req != kNullOp && info.outputs.size() == 1
&& info.out_grads.size() == 1;
}
};
/*! \brief make constructor protected. */
Imperative() {}
/*! \brief find the input/output ndarrays that are needed for backward */
@@ -189,7 +190,6 @@ class Imperative {
std::vector<bool> *p_save_outputs);
void RunGraph(
const bool retain_graph,
const Context& default_ctx,
const nnvm::IndexedGraph& idx,
const std::vector<NDArray*> arrays,
size_t node_start, size_t node_end,
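The hunks above move AGInfo into the public section of Imperative and add a Context field, which is how the recorder now remembers which device each recorded node lives on. A minimal, hypothetical sketch of attaching and querying that record follows; it is illustrative only and not code from this commit, and the variable name, device, request type, and check are invented:

```cpp
// Illustrative sketch, not part of the commit: how the public AGInfo record
// ties a recorded node to a device, mirroring what MarkVariables/RecordOp do.
#include <mxnet/imperative.h>
#include <nnvm/symbolic.h>

using namespace mxnet;

int main() {
  // Create a variable node, as the recorder does internally.
  nnvm::NodePtr var = nnvm::Symbol::CreateVariable("x").outputs[0].node;

  // Attach autograd bookkeeping and remember the device the data lives on.
  Imperative::AGInfo& info = Imperative::AGInfo::Create(var);
  info.ctx = Context::GPU(0);   // the field this commit introduces
  info.grad_req = kWriteTo;

  // The same record is retrieved through the node later on.
  CHECK_EQ(Imperative::AGInfo::Get(var).ctx.dev_mask(), gpu::kDevMask);

  // Clear() is a no-op while grad_req != kNullOp, so the record survives
  // until the node stops participating in gradient computation.
  Imperative::AGInfo::Clear(var);
  return 0;
}
```

The imperative.cc hunks below show the real call sites: MarkVariables and RecordOp populate info.ctx, and Backward later uses the per-node contexts (via PlaceDevice) to place each backward node.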
13 changes: 13 additions & 0 deletions include/mxnet/ndarray.h
@@ -892,6 +892,19 @@ size_t num_aux_data(NDArrayStorageType stype);
*/
void CopyFromTo(const NDArray &from, NDArray *to, int priority = 0);

/*!
* \brief issue a copy operation from one NDArray to another;
* the two NDArrays can sit on different devices and
* the operation will be scheduled by the engine
*
* \param from the ndarray we want to copy data from
* \param to the target ndarray
* \param priority Priority of the action.
* \note The function name makes the order of from and to explicit,
* since copy functions may follow different argument-order conventions.
*/
void CopyFromTo(const NDArray &from, const NDArray& to, int priority = 0);

/*!
* \brief Perform elementwise sum over each data from source, store result into out.
* \param source the ndarray we want to sum
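To show how the new const-reference CopyFromTo overload declared above is meant to be used, here is a small hypothetical sketch of a cross-device copy; the shapes, variable names, and the GPU context are invented for the example, and the GPU path requires a CUDA build:

```cpp
// Hypothetical usage sketch for the const-reference CopyFromTo overload.
#include <mxnet/ndarray.h>

using namespace mxnet;

int main() {
  // Source on CPU, destination on GPU 0, both with the same shape and dtype.
  NDArray src(TShape({2, 3}), Context::CPU(), false, mshadow::kFloat32);
  NDArray dst(TShape({2, 3}), Context::GPU(0), false, mshadow::kFloat32);
  src = 1.0f;            // fill the source; scheduled asynchronously on the engine

  // The copy is scheduled by the engine and may cross the device boundary.
  CopyFromTo(src, dst);

  // Block until the transfer finishes before reading the destination's data.
  dst.WaitToRead();
  return 0;
}
```

Functionally it behaves like the existing pointer-based overload directly above it: the copy is queued on the engine, so the example calls WaitToRead before any direct read of the destination's data.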
10 changes: 6 additions & 4 deletions src/imperative/cached_op.cc
@@ -195,7 +195,8 @@ nnvm::Graph Imperative::CachedOp::GetForwardGraph(
bool match = true;
match &= CheckAndInferShape(&g, std::move(shape_inputs), true);
match &= CheckAndInferType(&g, std::move(dtype_inputs), true);
match &= CheckAndInferStorageType(&g, inputs[0]->ctx().dev_mask(),
exec::DevMaskVector dev_mask(g.indexed_graph().num_nodes(), inputs[0]->ctx().dev_mask());
match &= CheckAndInferStorageType(&g, std::move(dev_mask),
std::move(storage_type_inputs), true);

if (!match) {
@@ -282,7 +283,8 @@ nnvm::Graph Imperative::CachedOp::GetBackwardGraph(
node_range, entry_range);
match &= CheckAndInferType(&g, std::move(dtypes), false,
node_range, entry_range);
match &= CheckAndInferStorageType(&g, inputs[0]->ctx().dev_mask(), std::move(stypes),
exec::DevMaskVector dev_mask(idx.num_nodes(), inputs[0]->ctx().dev_mask());
match &= CheckAndInferStorageType(&g, std::move(dev_mask), std::move(stypes),
false, node_range, entry_range);

if (!match) {
@@ -352,7 +354,7 @@ OpStatePtr Imperative::CachedOp::Forward(const std::vector<NDArray*>& inputs,

const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
Imperative::Get()->RunGraph(
false, default_ctx, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs),
false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs),
std::move(ref_count), &states, dispatch_modes);

for (size_t i = 0; i < idx.num_node_entries(); ++i) {
@@ -422,7 +424,7 @@ void Imperative::CachedOp::Backward(

const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
Imperative::Get()->RunGraph(
retain_graph, default_ctx, idx, arrays, num_forward_nodes, idx.num_nodes(),
retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(),
std::move(array_reqs), std::move(ref_count), &states, dispatch_modes);

if (retain_graph) {
53 changes: 33 additions & 20 deletions src/imperative/imperative.cc
@@ -122,11 +122,13 @@ void Imperative::MarkVariables(
info.outputs.emplace_back(variables[i]->Detach());
info.out_grads.emplace_back(gradients[i]->Detach());
info.grad_req = static_cast<OpReqType>(grad_reqs[i]);
info.ctx = variables[i]->ctx();

gradients[i]->entry_ = nnvm::NodeEntry{
nnvm::Symbol::CreateVariable("grad" + str_c).outputs[0].node, 0, 0};
AGInfo& grad_info = AGInfo::Create(gradients[i]->entry_.node);
grad_info.outputs.emplace_back(gradients[i]->Detach());
grad_info.ctx = gradients[i]->ctx();
}
}

@@ -207,6 +209,7 @@ void Imperative::RecordOp(
node->attrs.name = "node_" + std::to_string(node_count_++);
AGInfo& info = AGInfo::Create(node);
info.state = state;
info.ctx = outputs[0]->ctx();

if (p_save_inputs == nullptr) {
p_save_inputs = &(local_buff->save_inputs);
@@ -225,6 +228,7 @@
nnvm::NodeEntry entry{nnvm::Symbol::CreateVariable(
"null" + std::to_string(variable_count_++)).outputs[0].node, 0, 0};
AGInfo& input_info = AGInfo::Create(entry.node);
input_info.ctx = inputs[i]->ctx();
if (save_inputs[i]) {
input_info.outputs.emplace_back(*inputs[i]);
} else {
@@ -263,7 +267,6 @@

void Imperative::RunGraph(
const bool retain_graph,
const Context& default_ctx,
const nnvm::IndexedGraph& idx,
const std::vector<NDArray*> arrays,
size_t node_start, size_t node_end,
@@ -288,22 +291,24 @@
for (size_t i = node_start; i < node_end; ++i) {
const nnvm::IndexedGraph::Node& node = idx[i];
if (node.source->op() == nullptr) continue;
auto num_outputs = node.source->num_outputs();
ndinputs.clear();
ndinputs.reserve(node.inputs.size());
for (const auto& j : node.inputs) {
ndinputs.emplace_back(arrays[idx.entry_id(j)]);
CHECK(!ndinputs.back()->is_none()) << idx[j.node_id].source->attrs.name << " " << j.index;
}
ndoutputs.clear();
ndoutputs.reserve(node.source->num_outputs());
ndoutputs.reserve(num_outputs);
req.clear();
req.reserve(node.source->num_outputs());
for (size_t j = 0; j < node.source->num_outputs(); ++j) {
req.reserve(num_outputs);
for (size_t j = 0; j < num_outputs; ++j) {
size_t eid = idx.entry_id(i, j);
ndoutputs.emplace_back(arrays[eid]);
req.push_back(array_reqs[eid]);
CHECK(!ndoutputs.back()->is_none());
}
const Context& ctx = ndoutputs[0]->ctx();
const DispatchMode dispatch_mode = dispatch_modes[i];
if (node.source->op() == bwd_cached_op) {
const auto& cached_op = dmlc::get<CachedOpPtr>(node.source->attrs.parsed);
@@ -320,19 +325,19 @@
arg_dtypes.emplace_back(ndinputs[i]->dtype());
}
states[i] = createop[node.source->op()](
node.source->attrs, default_ctx, arg_shapes, arg_dtypes);
InvokeOp(default_ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode, states[i]);
node.source->attrs, ctx, arg_shapes, arg_dtypes);
InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode, states[i]);
if (recording) RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, states[i]);
} else if (is_layer_backward.get(node.source->op(), false)) {
nnvm::Node* fwd_node = node.source->control_deps[0].get();
auto fwd_node_id = idx.node_id(fwd_node);
InvokeOp(default_ctx, node.source->attrs, ndinputs, ndoutputs,
InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs,
req, dispatch_mode, states[fwd_node_id]);
if (recording) {
RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, states[fwd_node_id]);
}
} else {
InvokeOp(default_ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode);
InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode);
if (recording) RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs);
}

@@ -378,6 +383,7 @@ std::vector<NDArray*> Imperative::Backward(
for (size_t i = 0; i < outputs.size(); ++i) {
ograd_entries.emplace_back(NodeEntry{Node::Create(), 0, 0});
AGInfo& info = AGInfo::Create(ograd_entries.back().node);
info.ctx = outputs[i]->ctx();
if (ograds[i] != nullptr) {
info.outputs.emplace_back(*ograds[i]);
} else {
@@ -495,7 +501,7 @@
}

// Assign context
Context default_ctx = outputs[0]->ctx();
auto vctx = PlaceDevice(idx);

// Infer shape type
{
@@ -518,9 +524,11 @@
StorageTypeVector stypes;
stypes.reserve(idx.num_node_entries());
for (const auto& i : arrays) stypes.emplace_back(i->storage_type());
CheckAndInferStorageType(
&graph, default_ctx.dev_mask(), std::move(stypes), false,
node_range, entry_range);
exec::DevMaskVector dev_mask;
dev_mask.reserve(idx.num_nodes());
for (const auto& i : vctx) dev_mask.emplace_back(i.dev_mask());
CheckAndInferStorageType(&graph, std::move(dev_mask), std::move(stypes), false,
node_range, entry_range);
}

// Calculate ref count
@@ -544,13 +552,18 @@
const auto& dtypes = graph.GetAttr<DTypeVector>("dtype");
const auto& stypes = graph.GetAttr<StorageTypeVector>("storage_type");
const auto& dispatch_modes = graph.GetAttr<DispatchModeVector>("dispatch_mode");
for (size_t i = num_forward_entries; i < arrays.size(); ++i) {
if (!arrays[i]->is_none()) continue;
if (stypes[i] == kDefaultStorage) {
*arrays[i] = NDArray(shapes[i], default_ctx, true, dtypes[i]);
} else {
*arrays[i] = NDArray(static_cast<NDArrayStorageType>(stypes[i]),
shapes[i], default_ctx, true, dtypes[i]);

for (size_t i = num_forward_nodes; i < idx.num_nodes(); ++i) {
auto num_outputs = idx[i].source->num_outputs();
for (size_t j = 0; j < num_outputs; ++j) {
auto eid = idx.entry_id(i, j);
if (!arrays[eid]->is_none()) continue;
if (stypes[eid] == kDefaultStorage) {
*arrays[eid] = NDArray(shapes[eid], vctx[i], true, dtypes[eid]);
} else {
*arrays[eid] = NDArray(static_cast<NDArrayStorageType>(stypes[eid]),
shapes[eid], vctx[i], true, dtypes[eid]);
}
}
}

@@ -559,7 +572,7 @@
bool prev_recording = set_is_recording(create_graph);
bool prev_training = set_is_training(is_train);

RunGraph(retain_graph, default_ctx, idx, arrays, num_forward_nodes, idx.num_nodes(),
RunGraph(retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(),
std::move(array_reqs), std::move(ref_count), &states, dispatch_modes);

set_is_recording(prev_recording);