This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-424] dtype option for multinomial #10970

Merged · 6 commits · Jun 1, 2018
Changes from 1 commit
dtype option for multinomial
asitstands committed May 16, 2018
commit 9d191be38ec5fbefba94d12d2d9ca709f23ce504
7 changes: 5 additions & 2 deletions python/mxnet/ndarray/random.py
@@ -389,7 +389,7 @@ def generalized_negative_binomial(mu=1, alpha=1, shape=_Null, dtype=_Null, ctx=N
[mu, alpha], shape, dtype, ctx, out, kwargs)


def multinomial(data, shape=_Null, get_prob=False, out=None, **kwargs):
def multinomial(data, shape=_Null, get_prob=False, out=None, dtype='int32', **kwargs):
"""Concurrent sampling from multiple multinomial distributions.

.. note:: The input distribution must be normalized, i.e. `data` must sum to
@@ -412,6 +412,9 @@ def multinomial(data, shape=_Null, get_prob=False, out=None, **kwargs):
reward as head gradient w.r.t. this array to estimate gradient.
out : NDArray
Store output to an existing NDArray.
dtype : str or numpy.dtype
Data type of the sample output array. The default is int32.
        Note that the data type of the log likelihood array is the same as that of `data`.

Examples
--------
@@ -429,7 +432,7 @@ def multinomial(data, shape=_Null, get_prob=False, out=None, **kwargs):
[-1.20397282 -1.60943794]
<NDArray 2 @cpu(0)>
"""
return _internal._sample_multinomial(data, shape, get_prob, out=out, **kwargs)
return _internal._sample_multinomial(data, shape, get_prob, out=out, dtype=dtype, **kwargs)


def shuffle(data, **kwargs):
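A short usage sketch of the new parameter (pieced together from this diff and the updated unit test, not taken verbatim from the PR):

    import mxnet as mx

    probs = mx.nd.array([[0.0, 0.1, 0.2, 0.3, 0.4],
                         [0.4, 0.3, 0.2, 0.1, 0.0]])
    # Draw 10 samples per distribution; request float32 samples instead of the int32 default.
    samples, log_lik = mx.nd.random.multinomial(probs, shape=10, get_prob=True, dtype='float32')
    print(samples.dtype)   # numpy.float32
    print(log_lik.dtype)   # matches the dtype of probs (float32), independent of the dtype argument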
7 changes: 5 additions & 2 deletions python/mxnet/symbol/random.py
@@ -224,7 +224,7 @@ def generalized_negative_binomial(mu=1, alpha=1, shape=_Null, dtype=_Null, **kwa
[mu, alpha], shape, dtype, kwargs)


def multinomial(data, shape=_Null, get_prob=True, **kwargs):
def multinomial(data, shape=_Null, get_prob=True, dtype='int32', **kwargs):
"""Concurrent sampling from multiple multinomial distributions.

.. note:: The input distribution must be normalized, i.e. `data` must sum to
@@ -245,8 +245,11 @@ def multinomial(data, shape=_Null, get_prob=True, **kwargs):
samples will also be returned.
This is usually used for reinforcement learning, where you can provide
reward as head gradient w.r.t. this array to estimate gradient.
dtype : str or numpy.dtype
Data type of the sample output array. The default is int32.
        Note that the data type of the log likelihood array is the same as that of `data`.
"""
return _internal._sample_multinomial(data, shape, get_prob, **kwargs)
return _internal._sample_multinomial(data, shape, get_prob, dtype=dtype, **kwargs)


def shuffle(data, **kwargs):
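The symbol version takes the same argument. A sketch (assuming the usual infer_type behaviour, not from the PR) that checks the inferred output types:

    import mxnet as mx
    import numpy as np

    p = mx.sym.Variable('p')
    s = mx.sym.random.multinomial(data=p, shape=4, get_prob=True, dtype='float32')
    # With get_prob=True the op has two outputs: the samples and their log-likelihoods.
    _, out_types, _ = s.infer_type(p=np.float32)
    print(out_types)   # expected: [numpy.float32, numpy.float32]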
3 changes: 2 additions & 1 deletion src/operator/random/sample_multinomial_op.cc
@@ -97,7 +97,8 @@ struct SampleMultinomialBackwardCPUKernel {
DType* ograd, DType* dist, IType* out,
DType* igrad) {
for (index_t j = 0; j < M; ++j) {
igrad[i*K + out[i*M + j]] += ograd[i*M + j] / dist[i*K + out[i*M + j]];
igrad[i*K + static_cast<size_t>(out[i*M + j])] +=
ograd[i*M + j] / dist[i*K + static_cast<size_t>(out[i*M + j])];
}
}
};
3 changes: 2 additions & 1 deletion src/operator/random/sample_multinomial_op.cu
@@ -37,7 +37,8 @@ struct SampleMultinomialBackwardGPUKernel {
DType* ograd, DType* dist, IType* out,
DType* igrad) {
for (index_t j = 0; j < M; ++j) {
atomicAdd(&igrad[i*K + out[i*M + j]], ograd[i*M + j] / dist[i*K + out[i*M + j]]);
atomicAdd(&igrad[i*K + static_cast<size_t>(out[i*M + j])],
ograd[i*M + j] / dist[i*K + static_cast<size_t>(out[i*M + j])]);
}
}
};
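Context for the static_cast<size_t> added in the two kernels above: with the new dtype option, the sampled indices in `out` may be stored as a floating-point type, and floating-point values cannot be used directly as array indices. A rough Python analogue of the same issue (illustration only, not from the PR):

    import numpy as np

    idx = np.array([1.0, 3.0], dtype=np.float32)   # sampled category indices stored as float32
    grad = np.zeros(5, dtype=np.float32)
    # grad[idx] += 1.0                  # IndexError: indices must be of integer type
    grad[idx.astype(np.int64)] += 1.0   # cast first, as the kernels do with static_cast<size_t>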
23 changes: 15 additions & 8 deletions src/operator/random/sample_multinomial_op.h
@@ -49,10 +49,13 @@ struct SampleMultinomialParam : public dmlc::Parameter<SampleMultinomialParam> {
"result. This is usually used for differentiating through "
"stochastic variables, e.g. in reinforcement learning.");
DMLC_DECLARE_FIELD(dtype)
.add_enum("uint8", mshadow::kUint8)
.add_enum("int32", mshadow::kInt32)
.add_enum("float16", mshadow::kFloat16)
.add_enum("float32", mshadow::kFloat32)
.add_enum("float64", mshadow::kFloat64)
.set_default(mshadow::kInt32)
.describe("DType of the output in case this can't be inferred. "
"Only support int32 for now.");
.describe("DType of the output in case this can't be inferred.");
}
};

@@ -155,9 +158,11 @@ void SampleMultinomialForward(const nnvm::NodeAttrs& attrs,
Tensor<xpu, 1, float> uniform =
ctx.requested[1].get_space_typed<xpu, 1, float>(Shape1(N*M), s);
prnd->SampleUniform(&uniform, 0, 1);
Kernel<SampleMultinomialKernel, xpu>::Launch(
s, N, K, M, inputs[0].dptr<DType>(), uniform.dptr_, outputs[0].dptr<int>(),
param.get_prob ? outputs[1].dptr<DType>() : nullptr);
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, IType, {
Contributor:

This kind of two-layer switch is very slow to compile. Why do we need type support for the output?

Contributor Author (@asitstands, May 18, 2018):

Sometimes the multinomial samples need further processing in floating-point arithmetic, so the samples have to be copied into a new array of a floating-point type. The copy slows down training. For example, in an RBM the samples are fed to linalg.gemm, which supports only floating-point arrays.

Contributor:

A simple cast shouldn't cost that much? This kind of nested switch is really slow to compile and makes the binary much bigger. We need to make sure it really justifies the cost.

Contributor Author (@asitstands):

The binary size increases by about 0.1% for both the shared and static libraries (CUDA, cuDNN, MKL). Compiling MXNet already takes quite a long time, so the relative increase in compile time is also tiny.

I'm working with some variants of RBM, and the use of .astype('float32') in several places increases training time by over 20%. For a basic RBM it adds about 10% to training time in my MNIST test. Of course this depends on the hyperparameters and data, but in general I think the cost cannot be ignored for applications that use heavy Monte Carlo sampling of discrete states.
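A minimal sketch (not part of this thread) of the pattern @asitstands describes: sampling directly into float32 avoids the astype copy before a float-only operator such as linalg.gemm2. The shapes and the softmax normalization here are arbitrary assumptions for illustration.

    import mxnet as mx

    probs = mx.nd.softmax(mx.nd.random.uniform(shape=(64, 16)))  # 64 normalized distributions
    weights = mx.nd.random.normal(shape=(8, 4))

    # Before this PR: samples come back as int32, so an extra copy is needed before
    # linalg.gemm2, which accepts only floating-point arrays.
    s_int = mx.nd.random.multinomial(probs, shape=8)                   # int32, shape (64, 8)
    out_old = mx.nd.linalg.gemm2(s_int.astype('float32'), weights)

    # With this PR: sample directly into float32 and skip the cast.
    s_f32 = mx.nd.random.multinomial(probs, shape=8, dtype='float32')  # float32, shape (64, 8)
    out_new = mx.nd.linalg.gemm2(s_f32, weights)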

Kernel<SampleMultinomialKernel, xpu>::Launch(
s, N, K, M, inputs[0].dptr<DType>(), uniform.dptr_, outputs[0].dptr<IType>(),
param.get_prob ? outputs[1].dptr<DType>() : nullptr);
});
});
}

@@ -182,9 +187,11 @@ void SampleMultinomialBackward(const nnvm::NodeAttrs& attrs,
Tensor<xpu, 1, DType> out = outputs[0].FlatTo1D<xpu, DType>(s);
out = 0;
}
Kernel<kernel, xpu>::Launch(
s, N, K, M, inputs[0].dptr<DType>(), inputs[1].dptr<DType>(),
inputs[2].dptr<int>(), outputs[0].dptr<DType>());
MSHADOW_TYPE_SWITCH(inputs[2].type_flag_, IType, {
Kernel<kernel, xpu>::Launch(
s, N, K, M, inputs[0].dptr<DType>(), inputs[1].dptr<DType>(),
inputs[2].dptr<IType>(), outputs[0].dptr<DType>());
});
});
}

58 changes: 30 additions & 28 deletions tests/python/unittest/test_random.py
@@ -377,34 +377,36 @@ def test_parallel_random_seed_setting_for_context():

@with_seed()
def test_sample_multinomial():
for x in [mx.nd.array([[0,1,2,3,4],[4,3,2,1,0]])/10.0, mx.nd.array([0,1,2,3,4])/10.0]:
dx = mx.nd.ones_like(x)
mx.contrib.autograd.mark_variables([x], [dx])
# Adding rtol and increasing samples needed to pass with seed 2951820647
samples = 5000
with mx.autograd.record():
y, prob = mx.nd.random.multinomial(x, shape=samples, get_prob=True)
r = prob * 5
r.backward()

y = y.asnumpy()
x = x.asnumpy()
dx = dx.asnumpy()
if len(x.shape) is 1:
x = x.reshape((1, x.shape[0]))
dx = dx.reshape(1, dx.shape[0])
y = y.reshape((1, y.shape[0]))
prob = prob.reshape((1, prob.shape[0]))
for i in range(x.shape[0]):
freq = np.bincount(y[i,:], minlength=5)/np.float32(samples)*x[i,:].sum()
mx.test_utils.assert_almost_equal(freq, x[i], rtol=0.20)
rprob = x[i][y[i]]/x[i].sum()
mx.test_utils.assert_almost_equal(np.log(rprob), prob.asnumpy()[i], atol=1e-5)

real_dx = np.zeros((5,))
for j in range(samples):
real_dx[y[i][j]] += 5.0 / rprob[j]
mx.test_utils.assert_almost_equal(real_dx, dx[i, :], rtol=1e-4, atol=1e-5)
for dtype in ['int32', 'float16', 'float32', 'float64']: # output array types
for x in [mx.nd.array([[0,1,2,3,4],[4,3,2,1,0]])/10.0, mx.nd.array([0,1,2,3,4])/10.0]:
dx = mx.nd.ones_like(x)
mx.contrib.autograd.mark_variables([x], [dx])
# Adding rtol and increasing samples needed to pass with seed 2951820647
samples = 5000
with mx.autograd.record():
y, prob = mx.nd.random.multinomial(x, shape=samples, get_prob=True, dtype=dtype)
r = prob * 5
r.backward()

assert(np.dtype(dtype) == y.dtype)
y = y.asnumpy()
x = x.asnumpy()
dx = dx.asnumpy()
if len(x.shape) is 1:
x = x.reshape((1, x.shape[0]))
dx = dx.reshape(1, dx.shape[0])
y = y.reshape((1, y.shape[0]))
prob = prob.reshape((1, prob.shape[0]))
for i in range(x.shape[0]):
freq = np.bincount(y[i,:].astype('int32'), minlength=5)/np.float32(samples)*x[i,:].sum()
mx.test_utils.assert_almost_equal(freq, x[i], rtol=0.20)
rprob = x[i][y[i].astype('int32')]/x[i].sum()
mx.test_utils.assert_almost_equal(np.log(rprob), prob.asnumpy()[i], atol=1e-5)

real_dx = np.zeros((5,))
for j in range(samples):
real_dx[int(y[i][j])] += 5.0 / rprob[j]
mx.test_utils.assert_almost_equal(real_dx, dx[i, :], rtol=1e-4, atol=1e-5)

# Test the generators with the chi-square testing
@with_seed()