Skip to content

Commit

Permalink
Fix arange (#8268)
Browse files Browse the repository at this point in the history
* fix arange

* fix

* trigger test
  • Loading branch information
piiswrong committed Oct 14, 2017
1 parent 46ec178 commit ffa6e45
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 35 deletions.
6 changes: 3 additions & 3 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1869,11 +1869,11 @@ def arange(start, stop=None, step=1.0, repeat=1, ctx=None, dtype=mx_real_t):
Parameters
----------
start : float, optional
start : number, optional
Start of interval. The default start value is 0.
stop : float
stop : number
End of interval.
step : float, optional
step : number, optional
Spacing between values. The default step size is 1.
repeat : int, optional
Number of times to repeat each element. The default repeat count is 1.
Expand Down
43 changes: 24 additions & 19 deletions src/operator/tensor/init_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,17 +62,17 @@ struct InitOpParam : public dmlc::Parameter<InitOpParam> {
};

struct RangeParam : public dmlc::Parameter<RangeParam> {
real_t start;
dmlc::optional<real_t> stop;
real_t step;
double start;
dmlc::optional<double> stop;
double step;
int repeat;
std::string ctx;
int dtype;
DMLC_DECLARE_PARAMETER(RangeParam) {
DMLC_DECLARE_FIELD(start)
.describe("Start of interval. The interval includes this value. The default start value is 0.");
DMLC_DECLARE_FIELD(stop)
.set_default(dmlc::optional<real_t>())
.set_default(dmlc::optional<double>())
.describe("End of interval. The interval does not include this value,"
" except in some cases where step is not an integer and"
" floating point round-off affects the length of out.");
Expand Down Expand Up @@ -281,22 +281,27 @@ void FillComputeZerosEx(const nnvm::NodeAttrs& attrs,
}
}

// Forward kernel for the `arange` operator.
// Writes out[idx] = start + (idx / repeat) * step, so each value of the
// sequence appears `repeat` consecutive times in the output.
struct range_fwd {
  template<typename DType>
  MSHADOW_XINLINE static void Map(int idx, int n_repeat, DType first, DType stride,
                                  int req, DType* out) {
    // Integer division groups `n_repeat` consecutive indices onto one value.
    const int value_idx = idx / n_repeat;
    KERNEL_ASSIGN(out[idx], req, first + value_idx * stride);
  }
};

template<typename xpu>
void RangeCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mshadow::expr;
using namespace mxnet_op;
Stream<xpu> *s = ctx.get_stream<xpu>();
const RangeParam& param = nnvm::get<RangeParam>(attrs.parsed);
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
Tensor<xpu, 1, DType> out = outputs[0].FlatTo1D<xpu, DType>(s);
ASSIGN_DISPATCH(out, req[0], range<DType>(param.start,
param.stop.value(),
param.step,
param.repeat));
Kernel<range_fwd, xpu>::Launch(s, outputs[0].Size(),
static_cast<int>(param.repeat), static_cast<DType>(param.start),
static_cast<DType>(param.step), req[0], outputs[0].dptr<DType>());
});
}

Expand All @@ -307,24 +312,24 @@ inline bool RangeShape(const nnvm::NodeAttrs& attrs,
const RangeParam& param = nnvm::get<RangeParam>(attrs.parsed);
CHECK_EQ(in_attrs->size(), 0U);
CHECK_EQ(out_attrs->size(), 1U);
CHECK_NE(param.step, 0U)
CHECK_NE(param.step, 0)
<< "Range does not support step=0, received " << param.step;
CHECK(param.repeat > 0)
<< "Range only supports repeat > 0, received " << param.repeat;
if (param.step > 0) {
CHECK(param.start < param.stop.value())
<< "Range does not support (start, stop, step) = "
<< "Invalid range (start, stop, step) = "
<< "(" << param.start << "," << param.stop.value() << "," << param.step << ")";
} else {
CHECK(param.start > param.stop.value())
<< "Range does not support (start, stop, step)= "
<< "Invalid range (start, stop, step)= "
<< "(" << param.start << "," << param.stop.value() << "," << param.step << ")";
}
SHAPE_ASSIGN_CHECK(*out_attrs, 0,
mshadow::Shape1(mshadow::expr::RangeOutSize(param.start,
param.stop.value(),
param.step,
param.repeat)));
MSHADOW_TYPE_SWITCH(param.dtype, DType, {
double out_size = std::ceil((param.stop.value() - param.start) / param.step)
* param.repeat;
SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape({static_cast<nnvm::dim_t>(out_size)}));
});
return true;
}

Expand Down
4 changes: 4 additions & 0 deletions tests/python/unittest/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,10 @@ def test_arange():
gt = np.broadcast_to(gt.reshape((gt.shape[0], 1)), shape=(gt.shape[0], repeat)).ravel()
pred = mx.nd.arange(start=start, stop=stop, step=step, repeat=repeat).asnumpy()
assert_almost_equal(pred, gt)
gt = np.arange(start=0, stop=10000**2, step=10001, dtype=np.int32)
pred = mx.nd.arange(start=0, stop=10000**2, step=10001,
dtype="int32").asnumpy()
assert_almost_equal(pred, gt)

def test_order(ctx=default_context()):
def gt_topk(dat, axis, ret_typ, k, is_ascend):
Expand Down
27 changes: 14 additions & 13 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2508,18 +2508,20 @@ def test_basic_val_init(sym_func, np_func, shape, dtype):
assert exe.outputs[0].asnumpy().dtype == dtype

def test_arange():
for i in range(5):
start = np.random.rand() * 10
stop = start + np.random.rand() * 100
step = np.random.rand() * 4
repeat = int(np.random.rand() * 5) + 1
gt = np.arange(start=start, stop=stop, step=step)
gt = np.broadcast_to(gt.reshape((gt.shape[0], 1)), shape=(gt.shape[0], repeat)).ravel()
x = mx.sym.arange(start=start, stop=stop, step=step, repeat=repeat)
exe = x.simple_bind(ctx=default_context())
assert len(exe.grad_arrays) == 0
pred = exe.forward(is_train=False)[0].asnumpy()
assert_almost_equal(pred, gt)
# General Random Tests
dtype_list = [np.float32, np.float64, np.int32, np.uint8]
config_list = [(10,),
(0, 10),
(5, 100, 4),
(50, -50, -2),
(1.3, 456.6, 1.3)]
for dtype in dtype_list:
for config in config_list:
repeats = random.choice([1, 3])
np_out = np.repeat(np.arange(*config, dtype=dtype), repeats)
nd_out = mx.nd.arange(*config, repeat=repeats, dtype=dtype)
assert_almost_equal(np_out, nd_out.asnumpy())

test_basic_val_init(mx.sym.zeros, np.zeros, (3, 4), np.float32)
test_basic_val_init(mx.sym.ones, np.ones, 3, np.int32)
test_basic_val_init(mx.sym.ones, np.ones, (2, 2, 3), np.float16)
Expand Down Expand Up @@ -4266,7 +4268,6 @@ def check(data, idx):

assert (mx.nd.scatter_nd(data, idx, shape=(2, 2)).asnumpy() == [[0, 0], [2, 3]]).all()


if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit ffa6e45

Please sign in to comment.