Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Test_take, add additional axis (#20532)
Browse files Browse the repository at this point in the history
  • Loading branch information
mozga-intel authored Aug 19, 2021
1 parent 3ae7ddf commit 9a4dcf4
Showing 1 changed file with 36 additions and 36 deletions.
72 changes: 36 additions & 36 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4188,47 +4188,47 @@ def grad_helper(grad_in, axis, idx):
for _ in range(idx_ndim):
idx_shape += (np.random.randint(low=1, high=5), )

data = mx.sym.Variable('a')
idx = mx.sym.Variable('indices')
idx = mx.sym.BlockGrad(idx)
result = mx.sym.take(a=data, indices=idx, axis=axis, mode=mode)
exe = result._simple_bind(default_context(), a=data_shape,
indices=idx_shape)
data_real = np.random.normal(size=data_shape).astype('float32')
if out_of_range:
idx_real = np.random.randint(low=-data_shape[axis], high=data_shape[axis], size=idx_shape)
if mode == 'raise':
idx_real[idx_real == 0] = 1
idx_real *= data_shape[axis]
else:
idx_real = np.random.randint(low=0, high=data_shape[axis], size=idx_shape)
if axis < 0:
axis += len(data_shape)
data = mx.sym.Variable('a')
idx = mx.sym.Variable('indices')
idx = mx.sym.BlockGrad(idx)
result = mx.sym.take(a=data, indices=idx, axis=axis, mode=mode)
exe = result._simple_bind(default_context(), a=data_shape,
indices=idx_shape)
data_real = np.random.normal(size=data_shape).astype('float32')
if out_of_range:
idx_real = np.random.randint(low=-data_shape[axis], high=data_shape[axis], size=idx_shape)
if mode == 'raise':
idx_real[idx_real == 0] = 1
idx_real *= data_shape[axis]
else:
idx_real = np.random.randint(low=0, high=data_shape[axis], size=idx_shape)
if axis < 0:
axis += len(data_shape)

grad_out = np.ones((data_shape[0:axis] if axis > 0 else ()) + idx_shape + (data_shape[axis+1:] if axis < len(data_shape) - 1 else ()), dtype='float32')
grad_in = np.zeros(data_shape, dtype='float32')
grad_out = np.ones((data_shape[0:axis] if axis > 0 else ()) + idx_shape + (data_shape[axis+1:] if axis < len(data_shape) - 1 else ()), dtype='float32')
grad_in = np.zeros(data_shape, dtype='float32')

exe.arg_dict['a'][:] = mx.nd.array(data_real)
exe.arg_dict['indices'][:] = mx.nd.array(idx_real)
exe.forward(is_train=True)
if out_of_range and mode == 'raise':
try:
mx_out = exe.outputs[0].asnumpy()
except MXNetError as e:
return
else:
# Did not raise exception
assert False, "did not raise %s" % MXNetError.__name__
exe.arg_dict['a'][:] = mx.nd.array(data_real)
exe.arg_dict['indices'][:] = mx.nd.array(idx_real)
exe.forward(is_train=True)
if out_of_range and mode == 'raise':
try:
mx_out = exe.outputs[0].asnumpy()
except MXNetError as e:
return
else:
# Did not raise exception
assert False, "did not raise %s" % MXNetError.__name__

assert_almost_equal(exe.outputs[0], np.take(data_real, idx_real, axis=axis, mode=mode))
assert_almost_equal(exe.outputs[0], np.take(data_real, idx_real, axis=axis, mode=mode))

for i in np.nditer(idx_real):
if mode == 'clip':
i = np.clip(i, 0, data_shape[axis])
grad_helper(grad_in, axis, i)
for i in np.nditer(idx_real):
if mode == 'clip':
i = np.clip(i, 0, data_shape[axis])
grad_helper(grad_in, axis, i)

exe.backward([mx.nd.array(grad_out)])
assert_almost_equal(exe.grad_dict['a'], grad_in)
exe.backward([mx.nd.array(grad_out)])
assert_almost_equal(exe.grad_dict['a'], grad_in)


def test_grid_generator():
Expand Down

0 comments on commit 9a4dcf4

Please sign in to comment.