[MXNET-424] dtype option for multinomial (apache#10970)
* dtype option for multinomial

* Add missing test for uint8

* Add a check to ensure dtype has sufficient precision.

* Fix lint

* Error message for the dtype precision check

* Retrigger CI
asitstands authored and piiswrong committed Jun 1, 2018
1 parent fe153ac commit a27b52e
Showing 7 changed files with 89 additions and 42 deletions.
7 changes: 5 additions & 2 deletions python/mxnet/ndarray/random.py
@@ -389,7 +389,7 @@ def generalized_negative_binomial(mu=1, alpha=1, shape=_Null, dtype=_Null, ctx=N
[mu, alpha], shape, dtype, ctx, out, kwargs)


-def multinomial(data, shape=_Null, get_prob=False, out=None, **kwargs):
+def multinomial(data, shape=_Null, get_prob=False, out=None, dtype='int32', **kwargs):
"""Concurrent sampling from multiple multinomial distributions.
.. note:: The input distribution must be normalized, i.e. `data` must sum to
@@ -412,6 +412,9 @@ def multinomial(data, shape=_Null, get_prob=False, out=None, **kwargs):
reward as head gradient w.r.t. this array to estimate gradient.
out : NDArray
Store output to an existing NDArray.
+    dtype : str or numpy.dtype
+        Data type of the sample output array. The default is int32.
+        Note that the data type of the log likelihood array is the same with that of `data`.
Examples
--------
@@ -429,7 +432,7 @@
[-1.20397282 -1.60943794]
<NDArray 2 @cpu(0)>
"""
-    return _internal._sample_multinomial(data, shape, get_prob, out=out, **kwargs)
+    return _internal._sample_multinomial(data, shape, get_prob, out=out, dtype=dtype, **kwargs)


def shuffle(data, **kwargs):
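With this change, callers of the imperative API can pick the sample dtype. A minimal usage sketch, assuming a build that includes this commit (sampled values are random):

    import mxnet as mx

    probs = mx.nd.array([0.1, 0.2, 0.3, 0.4])  # normalized distribution over 4 categories
    s_int = mx.nd.random.multinomial(probs, shape=10)                   # default int32 samples
    s_f16 = mx.nd.random.multinomial(probs, shape=10, dtype='float16')  # same indices, stored as float16
    print(s_int.dtype, s_f16.dtype)  # <class 'numpy.int32'> <class 'numpy.float16'>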
7 changes: 5 additions & 2 deletions python/mxnet/symbol/random.py
@@ -224,7 +224,7 @@ def generalized_negative_binomial(mu=1, alpha=1, shape=_Null, dtype=_Null, **kwa
[mu, alpha], shape, dtype, kwargs)


-def multinomial(data, shape=_Null, get_prob=True, **kwargs):
+def multinomial(data, shape=_Null, get_prob=True, dtype='int32', **kwargs):
"""Concurrent sampling from multiple multinomial distributions.
.. note:: The input distribution must be normalized, i.e. `data` must sum to
@@ -245,8 +245,11 @@
samples will also be returned.
This is usually used for reinforcement learning, where you can provide
reward as head gradient w.r.t. this array to estimate gradient.
+    dtype : str or numpy.dtype
+        Data type of the sample output array. The default is int32.
+        Note that the data type of the log likelihood array is the same with that of `data`.
"""
-    return _internal._sample_multinomial(data, shape, get_prob, **kwargs)
+    return _internal._sample_multinomial(data, shape, get_prob, dtype=dtype, **kwargs)


def shuffle(data, **kwargs):
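The symbolic API mirrors the imperative one. A small sketch, assuming the same build; as I recall, `Symbol.eval` binds `data` by name and returns a list of output NDArrays:

    import mxnet as mx

    data = mx.sym.Variable('data')
    sample = mx.sym.random.multinomial(data, shape=5, get_prob=False, dtype='uint8')
    out = sample.eval(ctx=mx.cpu(), data=mx.nd.array([0.2, 0.3, 0.5]))[0]
    print(out.dtype)  # <class 'numpy.uint8'>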
16 changes: 16 additions & 0 deletions src/common/utils.h
@@ -43,6 +43,7 @@
#include <thread>
#include <algorithm>
#include <functional>
+#include <limits>

#include "../operator/mxnet_op.h"

@@ -617,6 +618,21 @@ FCompType GetFCompute(const nnvm::Op* op, const std::string& name,
}
}

+/*!
+ * \brief Return the max integer value representable in the type `T` without loss of precision.
+ */
+template <typename T>
+constexpr size_t MaxIntegerValue() {
+  return std::is_integral<T>::value ?
+    std::numeric_limits<T>::max():
+    size_t(2) << (std::numeric_limits<T>::digits - 1);
+}
+
+template <>
+constexpr size_t MaxIntegerValue<mshadow::half::half_t>() {
+  return size_t(2) << 10;
+}

} // namespace common
} // namespace mxnet
#endif // MXNET_COMMON_UTILS_H_
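MaxIntegerValue returns the largest count the type can hold exactly: the type's max for integral types, otherwise 2^digits, since a floating-point type with d significand digits represents every integer up to 2^d exactly. The half_t specialization hard-codes 2^11 = 2048 because std::numeric_limits is not specialized for it. A quick numpy check of the same bounds (illustrative only, not part of the commit):

    import numpy as np

    for t, digits in [(np.float16, 11), (np.float32, 24), (np.float64, 53)]:
        bound = 2 ** digits             # the same value as size_t(2) << (digits - 1)
        assert t(bound) == bound        # exactly representable
        assert t(bound + 1) == bound    # bound + 1 rounds back down, so precision is lost
        print(t.__name__, bound)        # float16 2048, float32 16777216, float64 9007199254740992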
3 changes: 2 additions & 1 deletion src/operator/random/sample_multinomial_op.cc
@@ -97,7 +97,8 @@ struct SampleMultinomialBackwardCPUKernel {
DType* ograd, DType* dist, IType* out,
DType* igrad) {
for (index_t j = 0; j < M; ++j) {
-      igrad[i*K + out[i*M + j]] += ograd[i*M + j] / dist[i*K + out[i*M + j]];
+      igrad[i*K + static_cast<size_t>(out[i*M + j])] +=
+        ograd[i*M + j] / dist[i*K + static_cast<size_t>(out[i*M + j])];
}
}
};
3 changes: 2 additions & 1 deletion src/operator/random/sample_multinomial_op.cu
@@ -37,7 +37,8 @@ struct SampleMultinomialBackwardGPUKernel {
DType* ograd, DType* dist, IType* out,
DType* igrad) {
for (index_t j = 0; j < M; ++j) {
-      atomicAdd(&igrad[i*K + out[i*M + j]], ograd[i*M + j] / dist[i*K + out[i*M + j]]);
+      atomicAdd(&igrad[i*K + static_cast<size_t>(out[i*M + j])],
+                ograd[i*M + j] / dist[i*K + static_cast<size_t>(out[i*M + j])]);
}
}
};
28 changes: 20 additions & 8 deletions src/operator/random/sample_multinomial_op.h
@@ -49,10 +49,13 @@ struct SampleMultinomialParam : public dmlc::Parameter<SampleMultinomialParam> {
"result. This is usually used for differentiating through "
"stochastic variables, e.g. in reinforcement learning.");
DMLC_DECLARE_FIELD(dtype)
.add_enum("uint8", mshadow::kUint8)
.add_enum("int32", mshadow::kInt32)
.add_enum("float16", mshadow::kFloat16)
.add_enum("float32", mshadow::kFloat32)
.add_enum("float64", mshadow::kFloat64)
.set_default(mshadow::kInt32)
.describe("DType of the output in case this can't be inferred. "
"Only support int32 for now.");
.describe("DType of the output in case this can't be inferred.");
}
};

@@ -67,6 +70,11 @@ inline bool SampleMultinomialOpShape(const nnvm::NodeAttrs& attrs,
const TShape& ishape = (*in_attrs)[0];
if (!ishape.ndim()) return false;

+  MSHADOW_TYPE_SWITCH(param.dtype, DType, {
+    CHECK_LE(ishape[ishape.ndim() - 1], mxnet::common::MaxIntegerValue<DType>())
+      << "'dtype' does not have a sufficient precision to represent the indices of the input array.";
+  });
+
if (ishape.ndim() == 1) {
if (param.shape.ndim()) {
SHAPE_ASSIGN_CHECK(*out_attrs, 0, param.shape);
@@ -155,9 +163,11 @@ void SampleMultinomialForward(const nnvm::NodeAttrs& attrs,
Tensor<xpu, 1, float> uniform =
ctx.requested[1].get_space_typed<xpu, 1, float>(Shape1(N*M), s);
prnd->SampleUniform(&uniform, 0, 1);
-    Kernel<SampleMultinomialKernel, xpu>::Launch(
-      s, N, K, M, inputs[0].dptr<DType>(), uniform.dptr_, outputs[0].dptr<int>(),
-      param.get_prob ? outputs[1].dptr<DType>() : nullptr);
+    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, IType, {
+      Kernel<SampleMultinomialKernel, xpu>::Launch(
+        s, N, K, M, inputs[0].dptr<DType>(), uniform.dptr_, outputs[0].dptr<IType>(),
+        param.get_prob ? outputs[1].dptr<DType>() : nullptr);
+    });
});
}

@@ -182,9 +192,11 @@ void SampleMultinomialBackward(const nnvm::NodeAttrs& attrs,
Tensor<xpu, 1, DType> out = outputs[0].FlatTo1D<xpu, DType>(s);
out = 0;
}
-    Kernel<kernel, xpu>::Launch(
-      s, N, K, M, inputs[0].dptr<DType>(), inputs[1].dptr<DType>(),
-      inputs[2].dptr<int>(), outputs[0].dptr<DType>());
+    MSHADOW_TYPE_SWITCH(inputs[2].type_flag_, IType, {
+      Kernel<kernel, xpu>::Launch(
+        s, N, K, M, inputs[0].dptr<DType>(), inputs[1].dptr<DType>(),
+        inputs[2].dptr<IType>(), outputs[0].dptr<DType>());
+    });
});
}

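The practical effect of the new shape-time check: the number of categories K, i.e. the last axis of the input, must not exceed MaxIntegerValue of the chosen dtype (255 for uint8, 2048 for float16, and so on). A hedged sketch of the user-visible failure, with an illustrative category count:

    import mxnet as mx

    probs = mx.nd.ones((3000,)) / 3000   # 3000 categories exceeds the float16 bound of 2048
    try:
        mx.nd.random.multinomial(probs, dtype='float16')
    except mx.MXNetError:
        print('float16 lacks the precision to index 3000 categories')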
67 changes: 39 additions & 28 deletions tests/python/unittest/test_random.py
@@ -377,34 +377,45 @@ def test_parallel_random_seed_setting_for_context():

@with_seed()
def test_sample_multinomial():
-    for x in [mx.nd.array([[0,1,2,3,4],[4,3,2,1,0]])/10.0, mx.nd.array([0,1,2,3,4])/10.0]:
-        dx = mx.nd.ones_like(x)
-        mx.contrib.autograd.mark_variables([x], [dx])
-        # Adding rtol and increasing samples needed to pass with seed 2951820647
-        samples = 5000
-        with mx.autograd.record():
-            y, prob = mx.nd.random.multinomial(x, shape=samples, get_prob=True)
-            r = prob * 5
-            r.backward()
-
-        y = y.asnumpy()
-        x = x.asnumpy()
-        dx = dx.asnumpy()
-        if len(x.shape) is 1:
-            x = x.reshape((1, x.shape[0]))
-            dx = dx.reshape(1, dx.shape[0])
-            y = y.reshape((1, y.shape[0]))
-            prob = prob.reshape((1, prob.shape[0]))
-        for i in range(x.shape[0]):
-            freq = np.bincount(y[i,:], minlength=5)/np.float32(samples)*x[i,:].sum()
-            mx.test_utils.assert_almost_equal(freq, x[i], rtol=0.20)
-            rprob = x[i][y[i]]/x[i].sum()
-            mx.test_utils.assert_almost_equal(np.log(rprob), prob.asnumpy()[i], atol=1e-5)
-
-            real_dx = np.zeros((5,))
-            for j in range(samples):
-                real_dx[y[i][j]] += 5.0 / rprob[j]
-            mx.test_utils.assert_almost_equal(real_dx, dx[i, :], rtol=1e-4, atol=1e-5)
+    for dtype in ['uint8', 'int32', 'float16', 'float32', 'float64']: # output array types
+        for x in [mx.nd.array([[0,1,2,3,4],[4,3,2,1,0]])/10.0, mx.nd.array([0,1,2,3,4])/10.0]:
+            dx = mx.nd.ones_like(x)
+            mx.contrib.autograd.mark_variables([x], [dx])
+            # Adding rtol and increasing samples needed to pass with seed 2951820647
+            samples = 5000
+            with mx.autograd.record():
+                y, prob = mx.nd.random.multinomial(x, shape=samples, get_prob=True, dtype=dtype)
+                r = prob * 5
+                r.backward()
+
+            assert(np.dtype(dtype) == y.dtype)
+            y = y.asnumpy()
+            x = x.asnumpy()
+            dx = dx.asnumpy()
+            if len(x.shape) is 1:
+                x = x.reshape((1, x.shape[0]))
+                dx = dx.reshape(1, dx.shape[0])
+                y = y.reshape((1, y.shape[0]))
+                prob = prob.reshape((1, prob.shape[0]))
+            for i in range(x.shape[0]):
+                freq = np.bincount(y[i,:].astype('int32'), minlength=5)/np.float32(samples)*x[i,:].sum()
+                mx.test_utils.assert_almost_equal(freq, x[i], rtol=0.20)
+                rprob = x[i][y[i].astype('int32')]/x[i].sum()
+                mx.test_utils.assert_almost_equal(np.log(rprob), prob.asnumpy()[i], atol=1e-5)
+
+                real_dx = np.zeros((5,))
+                for j in range(samples):
+                    real_dx[int(y[i][j])] += 5.0 / rprob[j]
+                mx.test_utils.assert_almost_equal(real_dx, dx[i, :], rtol=1e-4, atol=1e-5)
+    for dtype in ['uint8', 'float16', 'float32']:
+        # Bound check for the output data types. 'int32' and 'float64' require large memory so are skipped.
+        x = mx.nd.zeros(2 ** 25)  # Larger than the max integer in float32 without precision loss.
+        bound_check = False
+        try:
+            y = mx.nd.random.multinomial(x, dtype=dtype)
+        except mx.MXNetError as e:
+            bound_check = True
+        assert bound_check

# Test the generators with the chi-square testing
@with_seed()
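A detail worth noting in the updated test: np.bincount rejects floating-point input, so now that samples may come back as floats they are cast with .astype('int32') before histogramming. Illustrative:

    import numpy as np

    y = np.array([0., 1., 1., 3.], dtype=np.float16)    # float-typed category samples
    # np.bincount(y, minlength=5) would raise TypeError for a float array
    print(np.bincount(y.astype('int32'), minlength=5))  # [1 2 0 1 0]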
