[MXNET-406] support init/pull dense weight, push row_sparse grad in kvstore (apache#10845)

* +sparse merged buff

* fix comm cpu

* fix bug

* fix ctx bug

* add kvstore cpu tests

* gpu tests

* add warning

* lint

* better err msg
eric-haibin-lin authored May 11, 2018
1 parent ca08c1a commit 5569cc4
Showing 5 changed files with 202 additions and 117 deletions.
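The net effect of this commit: kv.init() and pull() operate on a dense weight, while push() may deliver a row_sparse gradient that touches only a subset of rows. A minimal toy sketch of that contract (standalone C++; ToyKVStore and PushRowSparse are invented names rather than MXNet's actual KVStore API, and a real push feeds an optimizer update instead of adding into the weight directly):

#include <cstdio>
#include <unordered_map>
#include <utility>
#include <vector>

// Toy model: the stored value is dense, but a push may be row-sparse,
// carrying (row index, value) pairs for only the rows it touches.
struct ToyKVStore {
  std::unordered_map<int, std::vector<float>> store_;  // dense weights

  void Init(int key, std::vector<float> dense) {
    store_[key] = std::move(dense);                    // init/pull side is dense
  }

  void PushRowSparse(int key, const std::vector<std::pair<int, float>>& grad) {
    auto& value = store_.at(key);
    for (const auto& rv : grad) value[rv.first] += rv.second;  // touched rows only
  }

  const std::vector<float>& Pull(int key) const { return store_.at(key); }
};

int main() {
  ToyKVStore kv;
  kv.Init(0, {1.f, 1.f, 1.f, 1.f});                    // init dense weight
  kv.PushRowSparse(0, {{1, 2.f}, {3, 4.f}});           // push row_sparse grad
  for (float v : kv.Pull(0)) std::printf("%.1f ", v);  // 1.0 3.0 1.0 5.0
  std::printf("\n");
  return 0;
}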
167 changes: 120 additions & 47 deletions src/kvstore/comm.h
@@ -112,41 +112,51 @@ class CommCPU : public Comm {

void Init(int key, const NDArrayStorageType stype, const TShape& shape,
int type = mshadow::kFloat32) override {
if (stype == kDefaultStorage) {
merge_buf_[key].merged = NDArray(shape, pinned_ctx_, false, type);
} else {
merge_buf_[key].merged = NDArray(stype, shape, pinned_ctx_, true, type);
}
// Delayed allocation - the dense merged buffer might not be used at all if push()
// only sees sparse arrays
bool delay_alloc = true;
merge_buf_[key].merged = NDArray(shape, pinned_ctx_, delay_alloc, type);
}
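Note the delay_alloc flag: the merged buffer is now always declared dense, but its memory is only committed on first write, so nothing is wasted when every push turns out to be row_sparse. A toy illustration of that deferred-allocation idea (plain C++, not NDArray's real mechanism):

#include <cstddef>
#include <cstdio>
#include <memory>

// Toy stand-in for NDArray's delay_alloc: the buffer knows its size up front
// but commits memory only when first touched.
class LazyBuffer {
 public:
  explicit LazyBuffer(size_t n) : size_(n) {}
  bool allocated() const { return data_ != nullptr; }
  float* data() {                        // first touch triggers allocation
    if (!data_) data_ = std::make_unique<float[]>(size_);
    return data_.get();
  }
 private:
  size_t size_;
  std::unique_ptr<float[]> data_;
};

int main() {
  LazyBuffer merged(1 << 20);            // "Init": size known, no memory yet
  std::printf("after init: %d\n", merged.allocated());    // prints 0
  merged.data()[0] = 1.f;                // a dense reduce would write here
  std::printf("after write: %d\n", merged.allocated());   // prints 1
  return 0;
}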

const NDArray& Reduce(int key, const std::vector<NDArray>& src,
int priority) override {
auto& buf = merge_buf_[key];
const auto stype = src[0].storage_type();
// avoid extra copy for single device, but it may bring problems for
// abnormal usage of kvstore
if (src.size() == 1) {
if (src[0].storage_type() == kDefaultStorage) {
if (stype == kDefaultStorage) {
return src[0];
} else { // if sparse and only one GPU, always update weight on CPU
CopyFromTo(src[0], &buf.merged, priority);
return buf.merged;
} else {
// With 'local' kvstore, we could store the weight on CPU while computing
// the gradient on GPU when the weight is extremely large.
// To avoid copying the weight to the same context as the gradient,
// we always copy the gradient to the merged buffer instead.
NDArray& merged = buf.merged_buf(stype);
CopyFromTo(src[0], &merged, priority);
return merged;
}
}

if (buf.merged.storage_type() == kDefaultStorage) {
NDArray& buf_merged = buf.merged_buf(stype);
// normal dense reduce
if (stype == kDefaultStorage) {
std::vector<Engine::VarHandle> const_vars(src.size() - 1);
std::vector<NDArray> reduce(src.size());
CopyFromTo(src[0], &buf.merged, priority);
reduce[0] = buf.merged;
CopyFromTo(src[0], &buf_merged, priority);
reduce[0] = buf_merged;

if (buf.copy_buf.empty()) {
buf.copy_buf.resize(src.size()-1);
for (size_t j = 0; j < src.size() - 1; ++j) {
// allocate NDArray based on storage type
// allocate copy buffer
buf.copy_buf[j] = NDArray(
src[0].shape(), pinned_ctx_, false, src[0].dtype());
}
}
CHECK(stype == buf.copy_buf[0].storage_type())
<< "Storage type mismatch detected. " << stype << "(src) vs. "
<< buf.copy_buf[0].storage_type() << "(buf.copy_buf)";
for (size_t i = 1; i < src.size(); ++i) {
CopyFromTo(src[i], &(buf.copy_buf[i-1]), priority);
reduce[i] = buf.copy_buf[i-1];
@@ -161,7 +171,7 @@ class CommCPU : public Comm {
FnProperty::kCPUPrioritized, priority, "KVStoreReduce");

} else {
// buf.merged is a sparse ndarray.
// sparse reduce
std::vector<Engine::VarHandle> const_vars(src.size());
std::vector<NDArray> reduce(src.size());

@@ -172,26 +182,28 @@ class CommCPU : public Comm {
src[0].storage_type(), src[0].shape(), pinned_ctx_, true, src[0].dtype());
}
}
CHECK(stype == buf.copy_buf[0].storage_type())
<< "Storage type mismatch detected. " << stype << "(src) vs. "
<< buf.copy_buf[0].storage_type() << "(buf.copy_buf)";
for (size_t i = 0; i < src.size(); ++i) {
CopyFromTo(src[i], &(buf.copy_buf[i]), priority);
reduce[i] = buf.copy_buf[i];
const_vars[i] = reduce[i].var();
}
NDArray result = buf.merged;
Resource rsc = ResourceManager::Get()->Request(result.ctx(),
Resource rsc = ResourceManager::Get()->Request(buf_merged.ctx(),
ResourceRequest(ResourceRequest::kTempSpace));
Engine::Get()->PushAsync(
[reduce, result, rsc, this](RunContext rctx, Engine::CallbackOnComplete on_complete) {
NDArray out = result;
[reduce, buf_merged, rsc, this](RunContext rctx, Engine::CallbackOnComplete on_complete) {
NDArray out = buf_merged;
is_serial_push_?
ReduceSumCPUExSerial(reduce, &out)
: mxnet::ndarray::ElementwiseSum(rctx.get_stream<cpu>(), rsc, reduce, &out);
on_complete();
}, Context::CPU(), const_vars, {result.var(), rsc.var},
}, Context::CPU(), const_vars, {buf_merged.var(), rsc.var},
FnProperty::kCPUPrioritized, priority, "KVStoreReduce");
}

return buf.merged;
return buf_merged;
}
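The sparse branch ends by summing row_sparse operands into the merged buffer. Conceptually the reduction produces the union of all row indices, with values summed where rows collide; a toy sketch of that merge for one-column tensors (ReduceRowSparse is an invented helper standing in for ReduceSumCPUExSerial / mxnet::ndarray::ElementwiseSum):

#include <cstdio>
#include <map>
#include <utility>
#include <vector>

// A toy row-sparse array: sorted (row index, value) pairs of a 1-column tensor.
using RowSparse = std::vector<std::pair<int, float>>;

// Merge: union of row indices, values summed where rows collide.
RowSparse ReduceRowSparse(const std::vector<RowSparse>& parts) {
  std::map<int, float> acc;              // ordered, like sorted row indices
  for (const auto& part : parts)
    for (const auto& rv : part) acc[rv.first] += rv.second;
  return RowSparse(acc.begin(), acc.end());
}

int main() {
  std::vector<RowSparse> grads = {{{0, 1.f}, {4, 2.f}}, {{4, 3.f}, {7, 1.f}}};
  for (const auto& rv : ReduceRowSparse(grads))
    std::printf("row %d -> %.1f\n", rv.first, rv.second);
  // row 0 -> 1.0, row 4 -> 5.0, row 7 -> 1.0
  return 0;
}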

void Broadcast(int key, const NDArray& src,
@@ -200,10 +212,14 @@ class CommCPU : public Comm {
if (mask == Context::kCPU) {
for (auto d : dst) CopyFromTo(src, d, priority);
} else {
// first copy data to cpu, then broadcast
auto& buf = merge_buf_[key];
CopyFromTo(src, &buf.merged, priority);
for (auto d : dst) CopyFromTo(buf.merged, d, priority);
// First copy data to pinned_ctx, then broadcast.
// Note that kv.init initializes the data on pinned_ctx.
// This branch indicates that push() was called with ndarrays on GPUs,
// so the source resides in a GPU context.
// It also indicates that the buffers were already initialized during push().
auto& buf = merge_buf_[key].merged_buf(src.storage_type());
CopyFromTo(src, &buf, priority);
for (auto d : dst) CopyFromTo(buf, d, priority);
}
}

@@ -228,7 +244,14 @@ class CommCPU : public Comm {
NDArray retained_cpu = (is_same_ctx && is_diff_var) ? *out :
NDArray(kRowSparseStorage, src.shape(), src.ctx(), true,
src.dtype(), src.aux_types());

if (!is_diff_var) {
common::LogOnce("The output of row_sparse_pull() on key " + std::to_string(key) +
" refers to the same NDArray as the one stored in KVStore. "
"Performing row_sparse_pull() with such an output will change the "
"data stored in KVStore, and an incorrect result may be generated "
"the next time row_sparse_pull() is called. To avoid this issue, "
"consider creating a new NDArray buffer to store the output.");
}
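common::LogOnce is used instead of LOG(INFO) so this warning fires once per distinct message rather than on every training iteration. A minimal stand-in showing the intended behavior (toy code; it assumes deduplication by message string and ignores thread safety):

#include <cstdio>
#include <string>
#include <unordered_set>

// Toy stand-in for common::LogOnce: emit each distinct message only once.
void LogOnce(const std::string& msg) {
  static std::unordered_set<std::string> seen;   // demo only, not thread-safe
  if (seen.insert(msg).second) std::fprintf(stderr, "%s\n", msg.c_str());
}

int main() {
  for (int i = 0; i < 3; ++i)
    LogOnce("output aliases the NDArray stored in KVStore");  // printed once
  return 0;
}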
Engine::Get()->PushAsync(
[=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
const TBlob& indices = row_id.data();
@@ -392,6 +415,24 @@ class CommCPU : public Comm {
NDArray merged;
/// \brief the cpu buffer for gpu data
std::vector<NDArray> copy_buf;
/// \brief the merged buffer for the given storage type
inline NDArray& merged_buf(NDArrayStorageType stype) {
if (stype == kDefaultStorage) {
return merged;
}
CHECK(stype == kRowSparseStorage) << "unexpected storage type " << stype;
// check if sparse_merged is initialized
if (sparse_merged.is_none()) {
CHECK(!merged.is_none());
sparse_merged = NDArray(kRowSparseStorage, merged.shape(), merged.ctx(),
true, merged.dtype());
}
return sparse_merged;
}

private:
/// \brief the sparse merged value
NDArray sparse_merged;
};
std::unordered_map<int, BufferEntry> merge_buf_;
size_t bigarray_bound_;
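This reworked BufferEntry is the core data structure of the change: the dense merged buffer is created at init time, while the private sparse_merged companion is materialized by merged_buf() only when the first row_sparse push arrives, and reused afterwards. A toy version of the lazy dual-buffer pattern (ToyBufferEntry is an invented name):

#include <cassert>
#include <cstdio>
#include <memory>
#include <vector>

struct Dense  { std::vector<float> data; };
struct Sparse { std::vector<int> rows; std::vector<float> vals; };

// Toy version of BufferEntry: dense buffer eager, sparse companion lazy.
struct ToyBufferEntry {
  Dense merged;                          // created by Init()
  Sparse& sparse_buf() {                 // like merged_buf(kRowSparseStorage)
    if (!sparse_) sparse_ = std::make_unique<Sparse>();
    return *sparse_;
  }
 private:
  std::unique_ptr<Sparse> sparse_;       // empty until a sparse push arrives
};

int main() {
  ToyBufferEntry buf;
  buf.merged.data.assign(8, 0.f);        // dense init
  Sparse& s1 = buf.sparse_buf();         // first sparse push materializes it
  Sparse& s2 = buf.sparse_buf();         // later pushes reuse the same buffer
  assert(&s1 == &s2);
  std::printf("same sparse buffer: %s\n", &s1 == &s2 ? "yes" : "no");
  return 0;
}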
@@ -417,7 +458,7 @@ class CommDevice : public Comm {

void Init(int key, const NDArrayStorageType stype, const TShape& shape,
int dtype = mshadow::kFloat32) override {
sorted_key_attrs_.emplace_back(key, shape, dtype, stype);
sorted_key_attrs_.emplace_back(key, shape, dtype);
}

void InitBuffersAndComm(const std::vector<NDArray>& src) {
@@ -451,10 +492,12 @@
auto& buf = merge_buf_[key];
std::vector<NDArray> reduce(src.size());

const NDArrayStorageType stype = buf.merged.storage_type();
const NDArrayStorageType stype = src[0].storage_type();
NDArray& buf_merged = buf.merged_buf(stype);
// normal dense reduce
if (stype == kDefaultStorage) {
CopyFromTo(src[0], &(buf.merged), priority);
reduce[0] = buf.merged;
CopyFromTo(src[0], &buf_merged, priority);
reduce[0] = buf_merged;

if (buf.copy_buf.empty()) {
// TODO(mli) this results in large device memory usage for huge ndarray,
@@ -464,29 +507,32 @@
buf.copy_buf.resize(src.size()-1);
for (size_t i = 0; i < src.size()-1; ++i) {
buf.copy_buf[i] = NDArray(
buf.merged.shape(), buf.merged.ctx(), false, buf.merged.dtype());
buf_merged.shape(), buf_merged.ctx(), false, buf_merged.dtype());
}
}
for (size_t i = 0; i < src.size()-1; ++i) {
CopyFromTo(src[i+1], &(buf.copy_buf[i]), priority);
reduce[i+1] = buf.copy_buf[i];
}
} else {
// sparse reduce
if (buf.copy_buf.empty()) {
// initialize buffer for copying during reduce
buf.copy_buf.resize(src.size());
for (size_t j = 0; j < src.size(); ++j) {
buf.copy_buf[j] = NDArray(
buf.merged.storage_type(), buf.merged.shape(), buf.merged.ctx(),
true, buf.merged.dtype());
buf.copy_buf[j] = NDArray(stype, src[0].shape(), buf_merged.ctx(), true, src[0].dtype());
}
}
CHECK(src[0].storage_type() == buf.copy_buf[0].storage_type())
<< "Storage type mismatch detected. " << src[0].storage_type() << "(src) vs. "
<< buf.copy_buf[0].storage_type() << "(buf.copy_buf)";
for (size_t i = 0; i < src.size(); ++i) {
CopyFromTo(src[i], &(buf.copy_buf[i]), priority);
reduce[i] = buf.copy_buf[i];
}
}
ElementwiseSum(reduce, &buf.merged, priority);
return buf.merged;
ElementwiseSum(reduce, &buf_merged, priority);
return buf_merged;
}

const NDArray& ReduceCompressed(int key, const std::vector<NDArray>& src,
@@ -547,10 +593,10 @@
}
}
} else {
auto& buf = merge_buf_[key];
CopyFromTo(src, &buf.merged, priority);
auto& buf_merged = merge_buf_[key].merged_buf(src.storage_type());
CopyFromTo(src, &buf_merged, priority);
for (auto d : dst) {
CopyFromTo(buf.merged, d, priority);
CopyFromTo(buf_merged, d, priority);
}
}
}
@@ -575,6 +621,14 @@
NDArray retained_gpu = (is_same_ctx && is_diff_var) ? *out :
NDArray(kRowSparseStorage, out->shape(), src.ctx(), true,
out->dtype(), out->aux_types());
if (!is_diff_var) {
common::LogOnce("The output of row_sparse_pull() on key " + std::to_string(key) +
" refers to the same NDArray as the one stored in KVStore. "
"Performing row_sparse_pull() with such an output will change the "
"data stored in KVStore, and an incorrect result may be generated "
"the next time row_sparse_pull() is called. To avoid this issue, "
"consider creating a new NDArray buffer to store the output.");
}

Engine::Get()->PushAsync([=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
const TBlob& indices = row_id.data();
@@ -647,7 +701,7 @@ class CommDevice : public Comm {
#endif
}

using KeyAttrs = std::tuple<int, TShape, int, NDArrayStorageType>;
using KeyAttrs = std::tuple<int, TShape, int>;
// try to allocate buff on device evenly
void InitMergeBuffer(const std::vector<Context>& devs) {
std::sort(sorted_key_attrs_.begin(), sorted_key_attrs_.end(), [](
@@ -659,11 +713,11 @@
for (auto d : devs) {
ctx_info[d.dev_id] = std::make_pair(d, 0);
}

for (size_t i = 0; i < sorted_key_attrs_.size(); ++i) {
const int key = std::get<0>(sorted_key_attrs_[i]);
const TShape& shape = std::get<1>(sorted_key_attrs_[i]);
const int type = std::get<2>(sorted_key_attrs_[i]);
const NDArrayStorageType stype = std::get<3>(sorted_key_attrs_[i]);
auto& buf = merge_buf_[key];
Context ctx;
size_t min_size = std::numeric_limits<size_t>::max();
@@ -674,11 +728,10 @@
min_size = size;
}
}
if (stype == kDefaultStorage) {
buf.merged = NDArray(shape, ctx, false, type);
} else {
buf.merged = NDArray(stype, shape, ctx, true, type);
}
// Delayed allocation - as the dense merged buffer might not be used at all if push()
// only sees sparse arrays
bool delay_alloc = true;
buf.merged = NDArray(shape, ctx, delay_alloc, type);
ctx_info[ctx.dev_id].second += shape.Size();
}
inited_ = true;
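Since the merge buffer is now always created dense, KeyAttrs no longer carries a storage type; InitMergeBuffer simply sorts keys by tensor size (largest first) and places each buffer on the device holding the least data so far. A standalone sketch of that greedy placement policy:

#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

int main() {
  // (key, shape.Size()) pairs, like the simplified sorted_key_attrs_
  std::vector<std::pair<int, size_t>> keys = {
      {0, 400}, {1, 1000}, {2, 250}, {3, 900}};
  std::sort(keys.begin(), keys.end(),
            [](const auto& a, const auto& b) { return a.second > b.second; });

  std::vector<size_t> load(2, 0);        // elements already placed per device
  for (const auto& kv : keys) {
    size_t dev = std::min_element(load.begin(), load.end()) - load.begin();
    load[dev] += kv.second;              // place buffer on least-loaded device
    std::printf("key %d (%zu elems) -> device %zu\n", kv.first, kv.second, dev);
  }
  // keys 1 and 2 land on device 0, keys 3 and 0 on device 1
  return 0;
}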
@@ -687,16 +740,36 @@
std::vector<KeyAttrs> sorted_key_attrs_;
/// \brief temporal space for pushing and pulling
struct BufferEntry {
/// \brief the merged value
/// \brief the dense merged value for reduce and broadcast operations
NDArray merged;
/// \brief the gpu buffer
/// \brief the gpu buffer for copy during reduce operation
std::vector<NDArray> copy_buf;
/// \brief the residual buffer for gradient compression
std::vector<NDArray> residual;
/// \brief the small buffer for compressed data in sender
std::vector<NDArray> compressed_send_buf;
/// \brief the small buffer for compressed data in receiver
std::vector<NDArray> compressed_recv_buf;

/// \brief the merged buffer for the given storage type (could be either dense or row_sparse)
inline NDArray& merged_buf(NDArrayStorageType stype) {
if (stype == kDefaultStorage) {
CHECK(!merged.is_none()) << "uninitialized merge buffer detected";
return merged;
}
CHECK(stype == kRowSparseStorage) << "unexpected storage type " << stype;
// check if sparse_merged is initialized
if (sparse_merged.is_none()) {
CHECK(!merged.is_none());
sparse_merged = NDArray(kRowSparseStorage, merged.shape(), merged.ctx(),
true, merged.dtype());
}
return sparse_merged;
}

private:
/// \brief the sparse merged value for reduce and rowsparse broadcast operations
NDArray sparse_merged;
};
std::unordered_map<int, BufferEntry> merge_buf_;
bool inited_;
4 changes: 2 additions & 2 deletions src/kvstore/kvstore_local.h
@@ -276,8 +276,8 @@ class KVStoreLocal : public KVStore {
// invalid, print warning messages once
if (this->warnings_printed_.find(key) == this->warnings_printed_.end()) {
LOG(INFO) << "Warning: non-default weights detected during kvstore pull. "
<< "This call has been ignored. "
<< "Please make sure to use row_sparse_pull with row_ids.";
"This call has been ignored. Please make sure to use"
"kv.row_sparse_pull() or module.prepare() with row_ids.";
this->warnings_printed_.insert(key);
}
return false;
2 changes: 1 addition & 1 deletion src/ndarray/ndarray.cc
@@ -1304,7 +1304,7 @@ void ElementwiseSum(const std::vector<NDArray> &source, NDArray *out, int priori
CHECK_EQ(source[i].ctx().dev_mask(), Context::kCPU)
<< "operands context mismatch";
} else {
CHECK(source[i].ctx() == out->ctx())
CHECK_EQ(source[i].ctx(), out->ctx())
<< "operands context mismatch";
}
}
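The CHECK to CHECK_EQ change is about diagnostics: dmlc's CHECK_EQ reports both operand values on failure (which works for Context because MXNet makes it streamable), whereas plain CHECK only echoes the expression. Toy stand-ins mimicking the difference (simplified; the real macros also accept extra streamed context such as the "operands context mismatch" suffix):

#include <iostream>
#include <sstream>
#include <stdexcept>

// Toy stand-ins for dmlc's CHECK / CHECK_EQ failure messages.
#define TOY_CHECK(cond) \
  if (!(cond)) throw std::runtime_error("Check failed: " #cond)

#define TOY_CHECK_EQ(a, b)                                          \
  if (!((a) == (b))) {                                              \
    std::ostringstream os;                                          \
    os << "Check failed: " #a " == " #b " (" << (a) << " vs. "      \
       << (b) << ")";                                               \
    throw std::runtime_error(os.str());                             \
  }

int main() {
  try { TOY_CHECK(3 == 4); } catch (const std::exception& e) {
    std::cout << e.what() << "\n";   // Check failed: 3 == 4
  }
  try { TOY_CHECK_EQ(3, 4); } catch (const std::exception& e) {
    std::cout << e.what() << "\n";   // Check failed: 3 == 4 (3 vs. 4)
  }
  return 0;
}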