Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Add test
Browse files Browse the repository at this point in the history
  • Loading branch information
leezu committed May 1, 2018
1 parent ca7e4a9 commit c84d427
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions tests/python/gpu/test_kvstore_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,21 @@ def check_rsp_pull(kv, count, ctxs, is_same_rowid=False, use_slice=False):
check_rsp_push_pull('device')
check_rsp_push_pull('device', is_push_cpu=False)


def test_kvstore_row_sparse_pull_single_device():
kvstore = mx.kv.create('device')
grad = mx.nd.random_normal(shape=(10,30),ctx=mx.gpu(0))
grad = grad.tostype("row_sparse")
copy = grad.copy()

i = 0
kvstore.init(i, grad)
kvstore.push(i, grad)
kvstore.row_sparse_pull(i, out=grad, row_ids=grad.indices)

assert_almost_equal(grad.asnumpy(), copy.asnumpy())


def test_rsp_push_pull_large_rowid():
num_rows = 793470
val = mx.nd.ones((num_rows, 1)).tostype('row_sparse').copyto(mx.gpu())
Expand Down

0 comments on commit c84d427

Please sign in to comment.