Skip to content

Commit

Permalink
Fix row_sparse_pull with single gpu (apache#10772)
Browse files Browse the repository at this point in the history
* Fix row_sparse_pull with single gpu

* Add test

* Fix row_sparse_pull with single gpu

* Add test

* fix sparse retain in comm.h

* remove dedup var

* update test
  • Loading branch information
leezu authored and anirudh2290 committed May 7, 2018
1 parent ef1991d commit 1d053b9
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 9 deletions.
25 changes: 16 additions & 9 deletions src/kvstore/comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -223,9 +223,12 @@ class CommCPU : public Comm {
CHECK_EQ(row_id.ctx().dev_mask(), Context::kCPU)
<< "BroadcastRowSparse with row_indices on gpu context not supported";
// retain according to unique indices
const bool is_to_gpu = out->ctx().dev_mask() == Context::kGPU;
NDArray retained_cpu = is_to_gpu ? NDArray(kRowSparseStorage, src.shape(),
src.ctx(), true, src.dtype(), src.aux_types()) : *out;
const bool is_same_ctx = out->ctx() == src.ctx();
const bool is_diff_var = out->var() != src.var();
NDArray retained_cpu = (is_same_ctx && is_diff_var) ? *out :
NDArray(kRowSparseStorage, src.shape(), src.ctx(), true,
src.dtype(), src.aux_types());

Engine::Get()->PushAsync(
[=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
const TBlob& indices = row_id.data();
Expand Down Expand Up @@ -565,14 +568,18 @@ class CommDevice : public Comm {
<< "BroadcastRowSparse expects row_sparse dst NDArray";
CHECK_EQ(row_id.ctx(), src.ctx())
<< "row_id and src are expected to be on the same context";

// retain according to indices
const bool is_diff_ctx = out->ctx() != src.ctx();
NDArray out_gpu = is_diff_ctx? NDArray(kRowSparseStorage, out->shape(),
src.ctx(), true, out->dtype(), out->aux_types()) : *out;
const bool is_same_ctx = out->ctx() == src.ctx();
const bool is_diff_var = out->var() != src.var();
NDArray retained_gpu = (is_same_ctx && is_diff_var) ? *out :
NDArray(kRowSparseStorage, out->shape(), src.ctx(), true,
out->dtype(), out->aux_types());

Engine::Get()->PushAsync([=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
const TBlob& indices = row_id.data();
using namespace mxnet::common;
NDArray temp = out_gpu;
NDArray temp = retained_gpu;
switch (temp.ctx().dev_mask()) {
case cpu::kDevMask: {
SparseRetainOpForwardRspWrapper<cpu>(rctx.get_stream<cpu>(),
Expand All @@ -591,9 +598,9 @@ class CommDevice : public Comm {
default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
}
on_complete();
}, out_gpu.ctx(), {src.var(), row_id.var()}, {out_gpu.var()},
}, retained_gpu.ctx(), {src.var(), row_id.var()}, {retained_gpu.var()},
FnProperty::kNormal, priority, "KVStoreSparseRetain");
CopyFromTo(out_gpu, out, priority);
CopyFromTo(retained_gpu, out, priority);
}
}

Expand Down
15 changes: 15 additions & 0 deletions tests/python/gpu/test_kvstore_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,21 @@ def check_rsp_pull(kv, count, ctxs, is_same_rowid=False, use_slice=False):
check_rsp_push_pull('device')
check_rsp_push_pull('device', is_push_cpu=False)


def test_row_sparse_pull_single_device():
kvstore = mx.kv.create('device')
copy = mx.nd.random_normal(shape=(4,4), ctx=mx.gpu(0))
grad = copy.tostype("row_sparse")

key = 0
kvstore.init(key, grad)
idx = grad.indices
kvstore.push(key, grad)
kvstore.row_sparse_pull(key, out=grad, row_ids=idx)

assert_almost_equal(grad.asnumpy(), copy.asnumpy())


def test_rsp_push_pull_large_rowid():
num_rows = 793470
val = mx.nd.ones((num_rows, 1)).tostype('row_sparse').copyto(mx.gpu())
Expand Down

0 comments on commit 1d053b9

Please sign in to comment.