diff --git a/cpp-package/include/mxnet-cpp/ndarray.hpp b/cpp-package/include/mxnet-cpp/ndarray.hpp
index 8998c0be..3c3b85d3 100644
--- a/cpp-package/include/mxnet-cpp/ndarray.hpp
+++ b/cpp-package/include/mxnet-cpp/ndarray.hpp
@@ -240,10 +240,10 @@ inline void NDArray::WaitToWrite() {
 }
 inline void NDArray::WaitAll() { CHECK_EQ(MXNDArrayWaitAll(), 0); }
 inline void NDArray::SampleGaussian(mx_float mu, mx_float sigma, NDArray *out) {
-  Operator("_sample_normal")(mu, sigma).Invoke(*out);
+  Operator("_random_normal")(mu, sigma).Invoke(*out);
 }
 inline void NDArray::SampleUniform(mx_float begin, mx_float end, NDArray *out) {
-  Operator("_sample_uniform")(begin, end).Invoke(*out);
+  Operator("_random_uniform")(begin, end).Invoke(*out);
 }
 inline void NDArray::Load(const std::string &file_name,
                           std::vector<NDArray> *array_list,
diff --git a/perl-package/AI-MXNet/lib/AI/MXNet/Random.pm b/perl-package/AI-MXNet/lib/AI/MXNet/Random.pm
index 9ca013c6..8f7b6d3f 100644
--- a/perl-package/AI-MXNet/lib/AI/MXNet/Random.pm
+++ b/perl-package/AI-MXNet/lib/AI/MXNet/Random.pm
@@ -58,13 +58,13 @@ method seed(Int $seed_state)
 }
 
 for my $method (
-        [qw/_sample_uniform uniform/],
-        [qw/_sample_normal normal/],
-        [qw/_sample_gamma gamma/],
-        [qw/_sample_exponential exponential/],
-        [qw/_sample_poisson poisson/],
-        [qw/_sample_negbinomial negative_binomial/],
-        [qw/_sample_gennegbinomial generalized_negative_binomial/],
+        [qw/_random_uniform uniform/],
+        [qw/_random_normal normal/],
+        [qw/_random_gamma gamma/],
+        [qw/_random_exponential exponential/],
+        [qw/_random_poisson poisson/],
+        [qw/_random_negbinomial negative_binomial/],
+        [qw/_random_gennegbinomial generalized_negative_binomial/],
 )
 {
     my ($nd_method_name, $rnd_method_name) = @{$method};
diff --git a/python/mxnet/base.py b/python/mxnet/base.py
index fc07853b..fe80bd69 100644
--- a/python/mxnet/base.py
+++ b/python/mxnet/base.py
@@ -365,7 +365,7 @@ def _as_list(obj):
         return [obj]
 
 
-_OP_NAME_PREFIX_LIST = ['_contrib_', '_linalg_', '_random_', '_sparse_']
+_OP_NAME_PREFIX_LIST = ['_contrib_', '_linalg_', '_sparse_']
 
 
 def _get_op_name_prefix(op_name):
diff --git a/python/mxnet/ndarray/__init__.py b/python/mxnet/ndarray/__init__.py
index 43ec961a..fdd4375d 100644
--- a/python/mxnet/ndarray/__init__.py
+++ b/python/mxnet/ndarray/__init__.py
@@ -17,7 +17,7 @@
 
 """NDArray API of MXNet."""
 
-from . import _internal, contrib, linalg, random, sparse
+from . import _internal, contrib, linalg, sparse, random
 # pylint: disable=wildcard-import, redefined-builtin
 from .op import *
 from .ndarray import *
diff --git a/python/mxnet/ndarray/random.py b/python/mxnet/ndarray/random.py
index 0ec4578b..19cab734 100644
--- a/python/mxnet/ndarray/random.py
+++ b/python/mxnet/ndarray/random.py
@@ -16,4 +16,428 @@
 # under the License.
 """Random distribution generator NDArray API of MXNet."""
-__all__ = []
+
+from ..base import numeric_types, _Null
+from ..context import current_context
+from . import _internal
+from .ndarray import NDArray
+
+
+__all__ = ['uniform', 'normal', 'poisson', 'exponential', 'gamma', 'multinomial',
+           'negative_binomial', 'generalized_negative_binomial']
+
+
+def _random_helper(random, sampler, params, shape, dtype, ctx, out, kwargs):
+    """Helper function for random generators."""
+    if isinstance(params[0], NDArray):
+        for i in params[1:]:
+            assert isinstance(i, NDArray), \
+                "Distribution parameters must all have the same type, but got " \
+                "both %s and %s."%(type(params[0]), type(i))
+        return sampler(*params, shape=shape, dtype=dtype, out=out, **kwargs)
+    elif isinstance(params[0], numeric_types):
+        if ctx is None:
+            ctx = current_context()
+        if shape is _Null and out is None:
+            shape = 1
+        for i in params[1:]:
+            assert isinstance(i, numeric_types), \
+                "Distribution parameters must all have the same type, but got " \
+                "both %s and %s."%(type(params[0]), type(i))
+        return random(*params, shape=shape, dtype=dtype, ctx=ctx, out=out, **kwargs)
+
+    raise ValueError("Distribution parameters must be either NDArray or numbers, "
+                     "but got %s."%type(params[0]))
+
+
+def uniform(low=0, high=1, shape=_Null, dtype=_Null, ctx=None, out=None, **kwargs):
+    """Draw random samples from a uniform distribution.
+
+    Samples are uniformly distributed over the half-open interval *[low, high)*
+    (includes *low*, but excludes *high*).
+
+    Parameters
+    ----------
+    low : float or NDArray
+        Lower boundary of the output interval. All values generated will be
+        greater than or equal to low. The default value is 0.
+    high : float or NDArray
+        Upper boundary of the output interval. All values generated will be
+        less than high. The default value is 1.0.
+    shape : int or tuple of ints
+        The number of samples to draw. If shape is, e.g., `(m, n)` and `low` and
+        `high` are scalars, output shape will be `(m, n)`. If `low` and `high`
+        are NDArrays with shape, e.g., `(x, y)`, then output will have shape
+        `(x, y, m, n)`, where `m*n` samples are drawn for each `[low, high)` pair.
+    dtype : {'float16', 'float32', 'float64'}
+        Data type of output samples. Default is 'float32'.
+    ctx : Context
+        Device context of output. Default is current context. Overridden by
+        `low.context` when `low` is an NDArray.
+    out : NDArray
+        Store output to an existing NDArray.
+
+
+    Examples
+    --------
+    >>> mx.nd.random.uniform(0, 1)
+    [ 0.54881352]
+    >>> mx.nd.random.uniform(0, 1, ctx=mx.gpu(0))
+    [ 0.92514056]
+
+    >>> mx.nd.random.uniform(-1, 1, shape=(2,))
+    [ 0.71589124  0.08976638]
+
+    >>> low = mx.nd.array([1,2,3])
+    >>> high = mx.nd.array([2,3,4])
+    >>> mx.nd.random.uniform(low, high, shape=2)
+    [[ 1.78653979  1.93707538]
+     [ 2.01311183  2.37081361]
+     [ 3.30491424  3.69977832]]
+
+    """
+    return _random_helper(_internal._random_uniform, _internal._sample_uniform,
+                          [low, high], shape, dtype, ctx, out, kwargs)
+
+
+def normal(loc=0, scale=1, shape=_Null, dtype=_Null, ctx=None, out=None, **kwargs):
+    """Draw random samples from a normal (Gaussian) distribution.
+
+    Samples are distributed according to a normal distribution parametrized
+    by *loc* (mean) and *scale* (standard deviation).
+
+
+    Parameters
+    ----------
+    loc : float or NDArray
+        Mean (centre) of the distribution.
+    scale : float or NDArray
+        Standard deviation (spread or width) of the distribution.
+    shape : int or tuple of ints
+        The number of samples to draw. If shape is, e.g., `(m, n)` and `loc` and
+        `scale` are scalars, output shape will be `(m, n)`. If `loc` and `scale`
+        are NDArrays with shape, e.g., `(x, y)`, then output will have shape
+        `(x, y, m, n)`, where `m*n` samples are drawn for each `(loc, scale)` pair.
+    dtype : {'float16', 'float32', 'float64'}
+        Data type of output samples. Default is 'float32'.
+    ctx : Context
+        Device context of output. Default is current context. Overridden by
+        `loc.context` when `loc` is an NDArray.
+    out : NDArray
+        Store output to an existing NDArray.
+
+
+    Examples
+    --------
+    >>> mx.nd.random.normal(0, 1)
+    [ 2.21220636]
+
+    >>> mx.nd.random.normal(0, 1, ctx=mx.gpu(0))
+    [ 0.29253659]
+
+    >>> mx.nd.random.normal(-1, 1, shape=(2,))
+    [-0.2259962  -0.51619542]
+
+    >>> loc = mx.nd.array([1,2,3])
+    >>> scale = mx.nd.array([2,3,4])
+    >>> mx.nd.random.normal(loc, scale, shape=2)
+    [[ 0.55912292  3.19566321]
+     [ 1.91728961  2.47706747]
+     [ 2.79666662  5.44254589]]
+
+    """
+    return _random_helper(_internal._random_normal, _internal._sample_normal,
+                          [loc, scale], shape, dtype, ctx, out, kwargs)
+
+
+def poisson(lam=1, shape=_Null, dtype=_Null, ctx=None, out=None, **kwargs):
+    """Draw random samples from a Poisson distribution.
+
+    Samples are distributed according to a Poisson distribution parametrized
+    by *lambda* (rate). Samples will always be returned as a floating point data type.
+
+    .. note:: poisson is not implemented for GPU yet.
+
+    Parameters
+    ----------
+    lam : float or NDArray
+        Expectation of interval, should be >= 0.
+    shape : int or tuple of ints
+        The number of samples to draw. If shape is, e.g., `(m, n)` and `lam` is
+        a scalar, output shape will be `(m, n)`. If `lam`
+        is an NDArray with shape, e.g., `(x, y)`, then output will have shape
+        `(x, y, m, n)`, where `m*n` samples are drawn for each entry in `lam`.
+    dtype : {'float16', 'float32', 'float64'}
+        Data type of output samples. Default is 'float32'.
+    ctx : Context
+        Device context of output. Default is current context. Overridden by
+        `lam.context` when `lam` is an NDArray.
+    out : NDArray
+        Store output to an existing NDArray.
+
+
+    Examples
+    --------
+    >>> mx.nd.random.poisson(1)
+    [ 1.]
+
+    >>> mx.nd.random.poisson(1, shape=(2,))
+    [ 0.  2.]
+
+    >>> lam = mx.nd.array([1,2,3])
+    >>> mx.nd.random.poisson(lam, shape=2)
+    [[ 1.  3.]
+     [ 3.  2.]
+     [ 2.  3.]]
+
+    """
+    return _random_helper(_internal._random_poisson, _internal._sample_poisson,
+                          [lam], shape, dtype, ctx, out, kwargs)
+
+
+def exponential(scale=1, shape=_Null, dtype=_Null, ctx=None, out=None, **kwargs):
+    r"""Draw samples from an exponential distribution.
+
+    Its probability density function is
+
+        f(x; \frac{1}{\beta}) = \frac{1}{\beta} \exp(-\frac{x}{\beta}),
+
+    for x > 0 and 0 elsewhere. \beta is the scale parameter, which is the
+    inverse of the rate parameter \lambda = 1/\beta.
+
+    .. note:: exponential is not implemented for GPU yet.
+
+    Parameters
+    ----------
+    scale : float or NDArray
+        The scale parameter, \beta = 1/\lambda.
+    shape : int or tuple of ints
+        The number of samples to draw. If shape is, e.g., `(m, n)` and `scale` is
+        a scalar, output shape will be `(m, n)`. If `scale`
+        is an NDArray with shape, e.g., `(x, y)`, then output will have shape
+        `(x, y, m, n)`, where `m*n` samples are drawn for each entry in `scale`.
+    dtype : {'float16', 'float32', 'float64'}
+        Data type of output samples. Default is 'float32'.
+    ctx : Context
+        Device context of output. Default is current context. Overridden by
+        `scale.context` when `scale` is an NDArray.
+    out : NDArray
+        Store output to an existing NDArray.
+
+
+    Examples
+    --------
+    >>> mx.nd.random.exponential(1)
+    [ 0.79587454]
+
+    >>> mx.nd.random.exponential(1, shape=(2,))
+    [ 0.89856035  1.25593066]
+
+    >>> scale = mx.nd.array([1,2,3])
+    >>> mx.nd.random.exponential(scale, shape=2)
+    [[  0.41063145   0.42140478]
+     [  2.59407091  10.12439728]
+     [  2.42544937   1.14260709]]
+
+    """
+    return _random_helper(_internal._random_exponential, _internal._sample_exponential,
+                          [1.0/scale], shape, dtype, ctx, out, kwargs)
+
+
+def gamma(alpha=1, beta=1, shape=_Null, dtype=_Null, ctx=None, out=None, **kwargs):
+    """Draw random samples from a gamma distribution.
+
+    Samples are distributed according to a gamma distribution parametrized
+    by *alpha* (shape) and *beta* (scale).
+
+    .. note:: gamma is not implemented for GPU yet.
+
+    Parameters
+    ----------
+    alpha : float or NDArray
+        The shape of the gamma distribution. Should be greater than zero.
+    beta : float or NDArray
+        The scale of the gamma distribution. Should be greater than zero.
+        Default is equal to 1.
+    shape : int or tuple of ints
+        The number of samples to draw. If shape is, e.g., `(m, n)` and `alpha` and
+        `beta` are scalars, output shape will be `(m, n)`. If `alpha` and `beta`
+        are NDArrays with shape, e.g., `(x, y)`, then output will have shape
+        `(x, y, m, n)`, where `m*n` samples are drawn for each `(alpha, beta)` pair.
+    dtype : {'float16', 'float32', 'float64'}
+        Data type of output samples. Default is 'float32'.
+    ctx : Context
+        Device context of output. Default is current context. Overridden by
+        `alpha.context` when `alpha` is an NDArray.
+    out : NDArray
+        Store output to an existing NDArray.
+
+
+    Examples
+    --------
+    >>> mx.nd.random.gamma(1, 1)
+    [ 1.93308783]
+
+    >>> mx.nd.random.gamma(1, 1, shape=(2,))
+    [ 0.48216391  2.09890771]
+
+    >>> alpha = mx.nd.array([1,2,3])
+    >>> beta = mx.nd.array([2,3,4])
+    >>> mx.nd.random.gamma(alpha, beta, shape=2)
+    [[  3.24343276   0.94137681]
+     [  3.52734375   0.45568955]
+     [ 14.26264095  14.0170126 ]]
+
+    """
+    return _random_helper(_internal._random_gamma, _internal._sample_gamma,
+                          [alpha, beta], shape, dtype, ctx, out, kwargs)
+
+
+def negative_binomial(k=1, p=1, shape=_Null, dtype=_Null, ctx=None,
+                      out=None, **kwargs):
+    """Draw random samples from a negative binomial distribution.
+
+    Samples are distributed according to a negative binomial distribution
+    parametrized by *k* (limit of unsuccessful experiments) and *p* (failure
+    probability in each experiment). Samples will always be returned as a
+    floating point data type.
+
+    .. note:: negative_binomial is not implemented for GPU yet.
+
+    Parameters
+    ----------
+    k : float or NDArray
+        Limit of unsuccessful experiments, > 0.
+    p : float or NDArray
+        Failure probability in each experiment, >= 0 and <= 1.
+    shape : int or tuple of ints
+        The number of samples to draw. If shape is, e.g., `(m, n)` and `k` and
+        `p` are scalars, output shape will be `(m, n)`. If `k` and `p`
+        are NDArrays with shape, e.g., `(x, y)`, then output will have shape
+        `(x, y, m, n)`, where `m*n` samples are drawn for each `(k, p)` pair.
+    dtype : {'float16', 'float32', 'float64'}
+        Data type of output samples. Default is 'float32'.
+    ctx : Context
+        Device context of output. Default is current context. Overridden by
+        `k.context` when `k` is an NDArray.
+    out : NDArray
+        Store output to an existing NDArray.
+
+
+    Examples
+    --------
+    >>> mx.nd.random.negative_binomial(10, 0.5)
+    [ 4.]
+
+    >>> mx.nd.random.negative_binomial(10, 0.5, shape=(2,))
+    [ 3.  4.]
+
+    >>> k = mx.nd.array([1,2,3])
+    >>> p = mx.nd.array([0.2,0.4,0.6])
+    >>> mx.nd.random.negative_binomial(k, p, shape=2)
+    [[ 3.  2.]
+     [ 4.  4.]
+     [ 0.  5.]]
+
+    """
+    return _random_helper(_internal._random_negative_binomial,
+                          _internal._sample_negative_binomial,
+                          [k, p], shape, dtype, ctx, out, kwargs)
+
+
+def generalized_negative_binomial(mu=1, alpha=1, shape=_Null, dtype=_Null, ctx=None,
+                                  out=None, **kwargs):
+    """Draw random samples from a generalized negative binomial distribution.
+
+    Samples are distributed according to a generalized negative binomial
+    distribution parametrized by *mu* (mean) and *alpha* (dispersion).
+    *alpha* is defined as *1/k* where *k* is the failure limit of the
+    number of unsuccessful experiments (generalized to real numbers).
+    Samples will always be returned as a floating point data type.
+
+    .. note:: generalized_negative_binomial is not implemented for GPU yet.
+
+    Parameters
+    ----------
+    mu : float or NDArray
+        Mean of the negative binomial distribution.
+    alpha : float or NDArray
+        Alpha (dispersion) parameter of the negative binomial distribution.
+    shape : int or tuple of ints
+        The number of samples to draw. If shape is, e.g., `(m, n)` and `mu` and
+        `alpha` are scalars, output shape will be `(m, n)`. If `mu` and `alpha`
+        are NDArrays with shape, e.g., `(x, y)`, then output will have shape
+        `(x, y, m, n)`, where `m*n` samples are drawn for each `(mu, alpha)` pair.
+    dtype : {'float16', 'float32', 'float64'}
+        Data type of output samples. Default is 'float32'.
+    ctx : Context
+        Device context of output. Default is current context. Overridden by
+        `mu.context` when `mu` is an NDArray.
+    out : NDArray
+        Store output to an existing NDArray.
+
+
+    Examples
+    --------
+    >>> mx.nd.random.generalized_negative_binomial(10, 0.5)
+    [ 19.]
+
+    >>> mx.nd.random.generalized_negative_binomial(10, 0.5, shape=(2,))
+    [ 30.  21.]
+
+    >>> mu = mx.nd.array([1,2,3])
+    >>> alpha = mx.nd.array([0.2,0.4,0.6])
+    >>> mx.nd.random.generalized_negative_binomial(mu, alpha, shape=2)
+    [[ 4.  0.]
+     [ 3.  2.]
+     [ 6.  2.]]
+
+    """
+    return _random_helper(_internal._random_generalized_negative_binomial,
+                          _internal._sample_generalized_negative_binomial,
+                          [mu, alpha], shape, dtype, ctx, out, kwargs)
+
+
+def multinomial(data, shape=_Null, get_prob=False, out=None, **kwargs):
+    """Concurrent sampling from multiple multinomial distributions.
+
+    .. note:: The input distribution must be normalized, i.e. `data` must sum to
+              1 along its last dimension.
+
+    Parameters
+    ----------
+    data : NDArray
+        An *n* dimensional array whose last dimension has length `k`, where
+        `k` is the number of possible outcomes of each multinomial distribution.
+        For example, data with shape `(m, n, k)` specifies `m*n` multinomial
+        distributions each with `k` possible outcomes.
+    shape : int or tuple of ints
+        The number of samples to draw from each distribution. If shape is empty
+        one sample will be drawn from each distribution.
+    get_prob : bool
+        If true, a second array containing log likelihood of the drawn
+        samples will also be returned.
+        This is usually used for reinforcement learning, where you can provide
+        reward as head gradient w.r.t. this array to estimate gradient.
+    out : NDArray
+        Store output to an existing NDArray.
+
+    Examples
+    --------
+    >>> probs = mx.nd.array([[0, 0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1, 0]])
+    >>> mx.nd.random.multinomial(probs)
+    [3 1]
+
+    >>> mx.nd.random.multinomial(probs, shape=2)
+    [[4 4]
+     [1 2]]
+
+    >>> mx.nd.random.multinomial(probs, get_prob=True)
+    [3 2]
+
+    [-1.20397282 -1.60943794]
+
+    """
+    return _internal._sample_multinomial(data, shape, get_prob, out=out, **kwargs)
diff --git a/python/mxnet/random.py b/python/mxnet/random.py
index 5754304b..3a13b3d7 100644
--- a/python/mxnet/random.py
+++ b/python/mxnet/random.py
@@ -17,18 +17,14 @@
 # coding: utf-8
 # pylint: disable=no-member, protected-access, unused-import, no-name-in-module
+# pylint: disable=wildcard-import, unused-wildcard-import
 """Random number interface of MXNet."""
 from __future__ import absolute_import
 
 import ctypes
 from .base import _LIB, check_call
-from .ndarray._internal import _sample_uniform as uniform
-from .ndarray._internal import _sample_normal as normal
-from .ndarray._internal import _sample_gamma as gamma
-from .ndarray._internal import _sample_exponential as exponential
-from .ndarray._internal import _sample_poisson as poisson
-from .ndarray._internal import _sample_negbinomial as negative_binomial
-from .ndarray._internal import _sample_gennegbinomial as generalized_negative_binomial
+from .ndarray.random import *
+
 
 def seed(seed_state):
     """Seeds the random number generators in MXNet.
diff --git a/python/mxnet/symbol/__init__.py b/python/mxnet/symbol/__init__.py
index 2694b4e5..da395b71 100644
--- a/python/mxnet/symbol/__init__.py
+++ b/python/mxnet/symbol/__init__.py
@@ -17,7 +17,7 @@
 
 """Symbol API of MXNet."""
 
-from . import _internal, contrib, linalg, random, sparse
+from . import _internal, contrib, linalg, sparse, random
 # pylint: disable=wildcard-import, redefined-builtin
 from .op import *
 from .symbol import *
diff --git a/python/mxnet/symbol/random.py b/python/mxnet/symbol/random.py
index 75ff7ede..2348801b 100644
--- a/python/mxnet/symbol/random.py
+++ b/python/mxnet/symbol/random.py
@@ -15,5 +15,245 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""Random Distribution Generator Symbol API of MXNet."""
-__all__ = []
+"""Random distribution generator Symbol API of MXNet."""
+
+from ..base import numeric_types, _Null
+from . import _internal
+from .symbol import Symbol
+
+
+__all__ = ['uniform', 'normal', 'poisson', 'exponential', 'gamma', 'multinomial',
+           'negative_binomial', 'generalized_negative_binomial']
+
+
+def _random_helper(random, sampler, params, shape, dtype, kwargs):
+    """Helper function for random generators."""
+    if isinstance(params[0], Symbol):
+        for i in params[1:]:
+            assert isinstance(i, Symbol), \
+                "Distribution parameters must all have the same type, but got " \
+                "both %s and %s."%(type(params[0]), type(i))
+        return sampler(*params, shape=shape, dtype=dtype, **kwargs)
+    elif isinstance(params[0], numeric_types):
+        for i in params[1:]:
+            assert isinstance(i, numeric_types), \
+                "Distribution parameters must all have the same type, but got " \
+                "both %s and %s."%(type(params[0]), type(i))
+        return random(*params, shape=shape, dtype=dtype, **kwargs)
+
+    raise ValueError("Distribution parameters must be either Symbol or numbers, "
+                     "but got %s."%type(params[0]))
+
+
+def uniform(low=0, high=1, shape=_Null, dtype=_Null, **kwargs):
+    """Draw random samples from a uniform distribution.
+
+    Samples are uniformly distributed over the half-open interval *[low, high)*
+    (includes *low*, but excludes *high*).
+
+    Parameters
+    ----------
+    low : float or Symbol
+        Lower boundary of the output interval. All values generated will be
+        greater than or equal to low. The default value is 0.
+    high : float or Symbol
+        Upper boundary of the output interval. All values generated will be
+        less than high. The default value is 1.0.
+    shape : int or tuple of ints
+        The number of samples to draw. If shape is, e.g., `(m, n)` and `low` and
+        `high` are scalars, output shape will be `(m, n)`. If `low` and `high`
+        are Symbols with shape, e.g., `(x, y)`, then output will have shape
+        `(x, y, m, n)`, where `m*n` samples are drawn for each `[low, high)` pair.
+    dtype : {'float16', 'float32', 'float64'}
+        Data type of output samples. Default is 'float32'.
+    """
+    return _random_helper(_internal._random_uniform, _internal._sample_uniform,
+                          [low, high], shape, dtype, kwargs)
+
+
+def normal(loc=0, scale=1, shape=_Null, dtype=_Null, **kwargs):
+    """Draw random samples from a normal (Gaussian) distribution.
+
+    Samples are distributed according to a normal distribution parametrized
+    by *loc* (mean) and *scale* (standard deviation).
+
+
+    Parameters
+    ----------
+    loc : float or Symbol
+        Mean (centre) of the distribution.
+    scale : float or Symbol
+        Standard deviation (spread or width) of the distribution.
+    shape : int or tuple of ints
+        The number of samples to draw. If shape is, e.g., `(m, n)` and `loc` and
+        `scale` are scalars, output shape will be `(m, n)`. If `loc` and `scale`
+        are Symbols with shape, e.g., `(x, y)`, then output will have shape
+        `(x, y, m, n)`, where `m*n` samples are drawn for each `(loc, scale)` pair.
+    dtype : {'float16', 'float32', 'float64'}
+        Data type of output samples. Default is 'float32'.
+    """
+    return _random_helper(_internal._random_normal, _internal._sample_normal,
+                          [loc, scale], shape, dtype, kwargs)
+
+
+def poisson(lam=1, shape=_Null, dtype=_Null, **kwargs):
+    """Draw random samples from a Poisson distribution.
+
+    Samples are distributed according to a Poisson distribution parametrized
+    by *lambda* (rate). Samples will always be returned as a floating point data type.
+
+    .. note:: poisson is not implemented for GPU yet.
+
+    Parameters
+    ----------
+    lam : float or Symbol
+        Expectation of interval, should be >= 0.
+    shape : int or tuple of ints
+        The number of samples to draw. If shape is, e.g., `(m, n)` and `lam` is
+        a scalar, output shape will be `(m, n)`. If `lam`
+        is a Symbol with shape, e.g., `(x, y)`, then output will have shape
+        `(x, y, m, n)`, where `m*n` samples are drawn for each entry in `lam`.
+    dtype : {'float16', 'float32', 'float64'}
+        Data type of output samples. Default is 'float32'.
+    """
+    return _random_helper(_internal._random_poisson, _internal._sample_poisson,
+                          [lam], shape, dtype, kwargs)
+
+
+def exponential(scale=1, shape=_Null, dtype=_Null, **kwargs):
+    r"""Draw samples from an exponential distribution.
+
+    Its probability density function is
+
+        f(x; \frac{1}{\beta}) = \frac{1}{\beta} \exp(-\frac{x}{\beta}),
+
+    for x > 0 and 0 elsewhere. \beta is the scale parameter, which is the
+    inverse of the rate parameter \lambda = 1/\beta.
+
+    .. note:: exponential is not implemented for GPU yet.
+
+    Parameters
+    ----------
+    scale : float or Symbol
+        The scale parameter, \beta = 1/\lambda.
+    shape : int or tuple of ints
+        The number of samples to draw. If shape is, e.g., `(m, n)` and `scale` is
+        a scalar, output shape will be `(m, n)`. If `scale`
+        is a Symbol with shape, e.g., `(x, y)`, then output will have shape
+        `(x, y, m, n)`, where `m*n` samples are drawn for each entry in `scale`.
+    dtype : {'float16', 'float32', 'float64'}
+        Data type of output samples. Default is 'float32'.
+    """
+    return _random_helper(_internal._random_exponential, _internal._sample_exponential,
+                          [1.0/scale], shape, dtype, kwargs)
+
+
+def gamma(alpha=1, beta=1, shape=_Null, dtype=_Null, **kwargs):
+    """Draw random samples from a gamma distribution.
+
+    Samples are distributed according to a gamma distribution parametrized
+    by *alpha* (shape) and *beta* (scale).
+
+    .. note:: gamma is not implemented for GPU yet.
+
+    Parameters
+    ----------
+    alpha : float or Symbol
+        The shape of the gamma distribution. Should be greater than zero.
+    beta : float or Symbol
+        The scale of the gamma distribution. Should be greater than zero.
+        Default is equal to 1.
+    shape : int or tuple of ints
+        The number of samples to draw. If shape is, e.g., `(m, n)` and `alpha` and
+        `beta` are scalars, output shape will be `(m, n)`. If `alpha` and `beta`
+        are Symbols with shape, e.g., `(x, y)`, then output will have shape
+        `(x, y, m, n)`, where `m*n` samples are drawn for each `(alpha, beta)` pair.
+    dtype : {'float16', 'float32', 'float64'}
+        Data type of output samples. Default is 'float32'.
+    """
+    return _random_helper(_internal._random_gamma, _internal._sample_gamma,
+                          [alpha, beta], shape, dtype, kwargs)
+
+
+def negative_binomial(k=1, p=1, shape=_Null, dtype=_Null, **kwargs):
+    """Draw random samples from a negative binomial distribution.
+
+    Samples are distributed according to a negative binomial distribution
+    parametrized by *k* (limit of unsuccessful experiments) and *p* (failure
+    probability in each experiment). Samples will always be returned as a
+    floating point data type.
+
+    .. note:: negative_binomial is not implemented for GPU yet.
+
+    Parameters
+    ----------
+    k : float or Symbol
+        Limit of unsuccessful experiments, > 0.
+    p : float or Symbol
+        Failure probability in each experiment, >= 0 and <= 1.
+    shape : int or tuple of ints
+        The number of samples to draw. If shape is, e.g., `(m, n)` and `k` and
+        `p` are scalars, output shape will be `(m, n)`. If `k` and `p`
+        are Symbols with shape, e.g., `(x, y)`, then output will have shape
+        `(x, y, m, n)`, where `m*n` samples are drawn for each `(k, p)` pair.
+    dtype : {'float16', 'float32', 'float64'}
+        Data type of output samples. Default is 'float32'.
+    """
+    return _random_helper(_internal._random_negative_binomial,
+                          _internal._sample_negative_binomial,
+                          [k, p], shape, dtype, kwargs)
+
+
+def generalized_negative_binomial(mu=1, alpha=1, shape=_Null, dtype=_Null, **kwargs):
+    """Draw random samples from a generalized negative binomial distribution.
+
+    Samples are distributed according to a generalized negative binomial
+    distribution parametrized by *mu* (mean) and *alpha* (dispersion).
+    *alpha* is defined as *1/k* where *k* is the failure limit of the
+    number of unsuccessful experiments (generalized to real numbers).
+    Samples will always be returned as a floating point data type.
+
+    .. note:: generalized_negative_binomial is not implemented for GPU yet.
+
+    Parameters
+    ----------
+    mu : float or Symbol
+        Mean of the negative binomial distribution.
+    alpha : float or Symbol
+        Alpha (dispersion) parameter of the negative binomial distribution.
+    shape : int or tuple of ints
+        The number of samples to draw. If shape is, e.g., `(m, n)` and `mu` and
+        `alpha` are scalars, output shape will be `(m, n)`. If `mu` and `alpha`
+        are Symbols with shape, e.g., `(x, y)`, then output will have shape
+        `(x, y, m, n)`, where `m*n` samples are drawn for each `(mu, alpha)` pair.
+    dtype : {'float16', 'float32', 'float64'}
+        Data type of output samples. Default is 'float32'.
+    """
+    return _random_helper(_internal._random_generalized_negative_binomial,
+                          _internal._sample_generalized_negative_binomial,
+                          [mu, alpha], shape, dtype, kwargs)
+
+
+def multinomial(data, shape=_Null, get_prob=True, **kwargs):
+    """Concurrent sampling from multiple multinomial distributions.
+
+    .. note:: The input distribution must be normalized, i.e. `data` must sum to
+              1 along its last dimension.
+
+    Parameters
+    ----------
+    data : Symbol
+        An *n* dimensional array whose last dimension has length `k`, where
+        `k` is the number of possible outcomes of each multinomial distribution.
+        For example, data with shape `(m, n, k)` specifies `m*n` multinomial
+        distributions each with `k` possible outcomes.
+    shape : int or tuple of ints
+        The number of samples to draw from each distribution. If shape is empty
+        one sample will be drawn from each distribution.
+    get_prob : bool
+        If true, a second array containing log likelihood of the drawn
+        samples will also be returned.
+        This is usually used for reinforcement learning, where you can provide
+        reward as head gradient w.r.t. this array to estimate gradient.
+    """
+    return _internal._sample_multinomial(data, shape, get_prob, **kwargs)
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Random.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Random.scala
index 71586d3d..210e61d8 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Random.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Random.scala
@@ -45,7 +45,7 @@ object Random {
       require(shape != null, "shape is required when out is not specified")
       outCopy = NDArray.empty(shape, ctx)
     }
-    NDArray.genericNDArrayFunctionInvoke("_sample_uniform", Seq(low, high),
+    NDArray.genericNDArrayFunctionInvoke("_random_uniform", Seq(low, high),
       Map("shape" -> outCopy.shape, "out" -> outCopy))
   }
 
@@ -72,7 +72,7 @@ object Random {
       require(shape != null, "shape is required when out is not specified")
       outCopy = NDArray.empty(shape, ctx)
     }
-    NDArray.genericNDArrayFunctionInvoke("_sample_normal", Seq.empty[NDArray],
+    NDArray.genericNDArrayFunctionInvoke("_random_normal", Seq.empty[NDArray],
       Map("loc" -> loc, "scale" -> scale, "shape" -> outCopy.shape, "out" -> outCopy))
   }
diff --git a/src/operator/random/multisample_op.cc b/src/operator/random/multisample_op.cc
index f1264e5d..64a41bd1 100644
--- a/src/operator/random/multisample_op.cc
+++ b/src/operator/random/multisample_op.cc
@@ -140,7 +140,8 @@ DMLC_REGISTER_PARAMETER(MultiSampleParam);
                                    input_name_1, input_name_2, \
                                    input_desc_1, input_desc_2, \
                                    description) \
-  NNVM_REGISTER_OP(sample_##distr) \
+  NNVM_REGISTER_OP(_sample_##distr) \
+  .add_alias("sample_" #distr) \
   .describe(description()+std::string(ADD_FILELINE)) \
   .set_num_inputs(num_inputs) \
   .set_num_outputs(1) \
diff --git a/src/operator/random/sample_multinomial_op.cc b/src/operator/random/sample_multinomial_op.cc
index b358b3b2..7032a6ec 100644
--- a/src/operator/random/sample_multinomial_op.cc
+++ b/src/operator/random/sample_multinomial_op.cc
@@ -29,7 +29,8 @@ namespace op {
 
 DMLC_REGISTER_PARAMETER(SampleMultinomialParam);
 
-NNVM_REGISTER_OP(sample_multinomial)
+NNVM_REGISTER_OP(_sample_multinomial)
+.add_alias("sample_multinomial")
 .describe(R"code(Concurrent sampling from multiple multinomial distributions.
 
 *data* is an *n* dimensional array whose last dimension has length *k*, where
diff --git a/src/operator/random/sample_multinomial_op.cu b/src/operator/random/sample_multinomial_op.cu
index c2bc99b7..5b59b2af 100644
--- a/src/operator/random/sample_multinomial_op.cu
+++ b/src/operator/random/sample_multinomial_op.cu
@@ -26,7 +26,7 @@
 namespace mxnet {
 namespace op {
 
-NNVM_REGISTER_OP(sample_multinomial)
+NNVM_REGISTER_OP(_sample_multinomial)
 .set_attr<FCompute>("FCompute<gpu>", SampleMultinomialForward<gpu>);
 
diff --git a/src/operator/random/sample_op.cc b/src/operator/random/sample_op.cc
index ea6fdd54..4225cd65 100644
--- a/src/operator/random/sample_op.cc
+++ b/src/operator/random/sample_op.cc
@@ -47,7 +47,6 @@ DMLC_REGISTER_PARAMETER(SampleGenNegBinomialParam);
 
 // Add "uniform" alias for backward compatibility
 MXNET_OPERATOR_REGISTER_SAMPLE(_random_uniform, SampleUniformParam)
 .add_alias("uniform")
-.add_alias("_sample_uniform")
 .add_alias("random_uniform")
 .describe(R"code(Draw random samples from a uniform distribution.
 
@@ -68,7 +67,6 @@ Example::
 
 // Add "normal" alias for backward compatibility
 MXNET_OPERATOR_REGISTER_SAMPLE(_random_normal, SampleNormalParam)
 .add_alias("normal")
-.add_alias("_sample_normal")
 .add_alias("random_normal")
 .describe(R"code(Draw random samples from a normal (Gaussian) distribution.
 
@@ -85,7 +83,6 @@ Example::
 .set_attr<FComputeEx>("FComputeEx<cpu>", SampleNormalEx_<cpu>);
 
 MXNET_OPERATOR_REGISTER_SAMPLE(_random_gamma, SampleGammaParam)
-.add_alias("_sample_gamma")
 .add_alias("random_gamma")
 .describe(R"code(Draw random samples from a gamma distribution.
 
@@ -100,7 +97,6 @@ Example::
 .set_attr<FComputeEx>("FComputeEx<cpu>", SampleGammaEx_<cpu>);
 
 MXNET_OPERATOR_REGISTER_SAMPLE(_random_exponential, SampleExponentialParam)
-.add_alias("_sample_exponential")
 .add_alias("random_exponential")
 .describe(R"code(Draw random samples from an exponential distribution.
 
@@ -114,7 +110,6 @@ Example::
 .set_attr<FCompute>("FCompute<cpu>", SampleExponential_<cpu>);
 
 MXNET_OPERATOR_REGISTER_SAMPLE(_random_poisson, SamplePoissonParam)
-.add_alias("_sample_poisson")
 .add_alias("random_poisson")
 .describe(R"code(Draw random samples from a Poisson distribution.
 
@@ -129,7 +124,6 @@ Example::
 .set_attr<FCompute>("FCompute<cpu>", SamplePoisson_<cpu>);
 
 MXNET_OPERATOR_REGISTER_SAMPLE(_random_negative_binomial, SampleNegBinomialParam)
-.add_alias("_sample_negbinomial")
 .add_alias("random_negative_binomial")
 .describe(R"code(Draw random samples from a negative binomial distribution.
 
@@ -145,7 +139,6 @@ Example::
 .set_attr<FCompute>("FCompute<cpu>", SampleNegBinomial_<cpu>);
 
 MXNET_OPERATOR_REGISTER_SAMPLE(_random_generalized_negative_binomial, SampleGenNegBinomialParam)
-.add_alias("_sample_gennegbinomial")
 .add_alias("random_generalized_negative_binomial")
 .describe(R"code(Draw random samples from a generalized negative binomial distribution.
 
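The operator registrations above carry the whole compatibility story: the scalar samplers are renamed to `_random_*` (keeping the public `uniform`/`normal`/`random_*` aliases while dropping the old `_sample_*` ones), and the multi-distribution samplers take over the `_sample_*` prefix with the old public `sample_*` names kept as aliases. A minimal sketch of what this means at the Python level (a hypothetical session, not part of the patch; it assumes a build with these changes applied):

import mxnet as mx

mx.random.seed(128)
a = mx.nd.random.uniform(-1, 1, shape=(2,))  # new front-end added by this patch
mx.random.seed(128)
# With '_random_' removed from _OP_NAME_PREFIX_LIST, the renamed op is exposed
# under the _internal namespace (the public `uniform`/`random_uniform` aliases remain).
b = mx.nd._internal._random_uniform(low=-1, high=1, shape=(2,))
assert (a.asnumpy() == b.asnumpy()).all()  # same operator, same RNG stream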
diff --git a/tests/python/unittest/test_random.py b/tests/python/unittest/test_random.py
index 01c8b0aa..decb9dfd 100644
--- a/tests/python/unittest/test_random.py
+++ b/tests/python/unittest/test_random.py
@@ -28,8 +28,7 @@ def check_with_device(device, dtype):
         {
             'name': 'normal',
             'symbol': mx.sym.random.normal,
-            'multisymbol': mx.sym.sample_normal,
-            'ndop': mx.random.normal,
+            'ndop': mx.nd.random.normal,
             'params': { 'loc': 10.0, 'scale': 0.5 },
             'inputs': [ ('loc',[ [ 0.0, 2.5 ], [ -9.75, -7.0 ] ]) , ('scale',[ [ 1.0, 3.7 ], [ 4.2, 1.5 ] ]) ],
             'checks': [
@@ -40,8 +39,7 @@ def check_with_device(device, dtype):
         {
             'name': 'uniform',
             'symbol': mx.sym.random.uniform,
-            'multisymbol': mx.sym.sample_uniform,
-            'ndop': mx.random.uniform,
+            'ndop': mx.nd.random.uniform,
             'params': { 'low': -1.5, 'high': 3.0 },
             'inputs': [ ('low', [ [ 0.0, 2.5 ], [ -9.75, -1.0 ] ]) , ('high', [ [ 1.0, 3.7 ], [ 4.2, 10.5 ] ]) ],
             'checks': [
@@ -55,8 +53,7 @@ def check_with_device(device, dtype):
         {
             'name': 'gamma',
             'symbol': mx.sym.random.gamma,
-            'multisymbol': mx.sym.sample_gamma,
-            'ndop': mx.random.gamma,
+            'ndop': mx.nd.random.gamma,
             'params': { 'alpha': 9.0, 'beta': 0.5 },
             'inputs': [ ('alpha', [ [ 0.0, 2.5 ], [ 9.75, 11.0 ] ]) , ('beta', [ [ 1.0, 0.7 ], [ 0.5, 0.3 ] ]) ],
             'checks': [
@@ -67,20 +64,18 @@ def check_with_device(device, dtype):
         {
             'name': 'exponential',
             'symbol': mx.sym.random.exponential,
-            'multisymbol': mx.sym.sample_exponential,
-            'ndop': mx.random.exponential,
-            'params': { 'lam': 4.0 },
-            'inputs': [ ('lam', [ [ 1.0, 8.5 ], [ 2.7 , 0.5 ] ]) ],
+            'ndop': mx.nd.random.exponential,
+            'params': { 'scale': 1.0/4.0 },
+            'inputs': [ ('scale', [ [ 1.0/1.0, 1.0/8.5 ], [ 1.0/2.7 , 1.0/0.5 ] ]) ],
             'checks': [
-                ('mean', lambda x, params: np.mean(x.astype(np.float64)) - 1.0 / params['lam'], tol),
-                ('std', lambda x, params: np.std(x.astype(np.float64)) - 1.0 / params['lam'], tol)
+                ('mean', lambda x, params: np.mean(x.astype(np.float64)) - params['scale'], tol),
+                ('std', lambda x, params: np.std(x.astype(np.float64)) - params['scale'], tol)
             ]
         },
         {
             'name': 'poisson',
             'symbol': mx.sym.random.poisson,
-            'ndop': mx.random.poisson,
-            'multisymbol': mx.sym.sample_poisson,
+            'ndop': mx.nd.random.poisson,
             'params': { 'lam': 4.0 },
             'inputs': [ ('lam', [ [ 1.0, 8.5 ], [ 2.7 , 0.5 ] ]) ],
             'checks': [
@@ -91,10 +86,9 @@ def check_with_device(device, dtype):
         {
             'name': 'neg-binomial',
             'symbol': mx.sym.random.negative_binomial,
-            'multisymbol': mx.sym.sample_negative_binomial,
-            'ndop': mx.random.negative_binomial,
+            'ndop': mx.nd.random.negative_binomial,
             'params': { 'k': 3, 'p': 0.4 },
-            'inputs': [ ('k', [ [ 20, 49 ], [ 15 , 16 ] ]) , ('p', [ [ 0.4 , 0.77 ], [ 0.5, 0.84 ] ]) ],
+            'inputs': [ ('k', [ [ 3, 4 ], [ 5 , 6 ] ]) , ('p', [ [ 0.4 , 0.77 ], [ 0.5, 0.84 ] ]) ],
             'checks': [
                 ('mean', lambda x, params: np.mean(x.astype(np.float64)) - params['k'] * (1.0 - params['p']) / params['p'], tol),
                 ('std', lambda x, params: np.std(x.astype(np.float64)) - np.sqrt(params['k'] * (1.0 - params['p']))/params['p'], tol)
@@ -103,8 +97,7 @@ def check_with_device(device, dtype):
         {
             'name': 'gen-neg-binomial',
             'symbol': mx.sym.random.generalized_negative_binomial,
-            'multisymbol': mx.sym.sample_generalized_negative_binomial,
-            'ndop': mx.random.generalized_negative_binomial,
+            'ndop': mx.nd.random.generalized_negative_binomial,
             'params': { 'mu': 2.0, 'alpha': 0.3 },
             'inputs': [ ('mu', [ [ 2.0, 2.5 ], [ 1.3, 1.9 ] ]) , ('alpha', [ [ 1.0, 0.1 ], [ 0.2, 0.5 ] ]) ],
             'checks': [
@@ -133,6 +126,24 @@ def check_with_device(device, dtype):
             for check_name, check_func, tol in symbdic['checks']:
                 assert np.abs(check_func(ret1, params)) < tol, "ndarray test: %s check for `%s` did not pass" % (check_name, name)
 
+        # check multi-distribution sampling, only supports cpu for now
+        if device.device_type == 'cpu':
+            params = {'shape': shape, 'dtype': dtype, 'ctx': device}
+            params.update({k : mx.nd.array(v, ctx=device, dtype=dtype) for k, v in symbdic['inputs']})
+            mx.random.seed(128)
+            ret1 = ndop(**params).asnumpy()
+            mx.random.seed(128)
+            ret2 = ndop(**params).asnumpy()
+            assert device.device_type == 'gpu' or same(ret1, ret2), \
+                    "ndarray test: `%s` should give the same result with the same seed" % name
+            for i in range(2):
+                for j in range(2):
+                    stats = {k : v[i][j] for k, v in symbdic['inputs']}
+                    for check_name, check_func, tol in symbdic['checks']:
+                        err = np.abs(check_func(ret2[i,j], stats))
+                        assert err < tol, "%f vs %f: ndarray test: %s check for `%s` did not pass" % (err, tol, check_name, name)
+
+
         # check symbolic
         symbol = symbdic['symbol']
         X = mx.sym.Variable("X")
@@ -159,7 +170,7 @@ def check_with_device(device, dtype):
 
         # check multi-distribution sampling, only supports cpu for now
         if device.device_type == 'cpu':
-            symbol = symbdic['multisymbol']
+            symbol = symbdic['symbol']
             params = { 'shape' : shape, 'dtype' : dtype }
             single_param = len(symbdic['inputs']) == 1;
             v1 = mx.sym.Variable('v1')
@@ -192,7 +203,7 @@ def test_sample_multinomial():
     dx = mx.nd.ones_like(x)
     mx.contrib.autograd.mark_variables([x], [dx])
     with mx.autograd.record():
-        y, prob = mx.nd.sample_multinomial(x, shape=1000, get_prob=True)
+        y, prob = mx.nd.random.multinomial(x, shape=1000, get_prob=True)
         r = prob * 5
         r.backward()
@@ -212,5 +223,5 @@ def test_sample_multinomial():
 
 if __name__ == '__main__':
-    test_random()
-    test_sample_multinomial()
+    import nose
+    nose.runmodule()
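Taken together, the new front-end gives one public entry point per distribution and, as `_random_helper` shows, dispatches on the type of the distribution parameters: scalars go to the renamed `_random_*` ops, NDArrays (or Symbols) to the `_sample_*` multi-distribution ops. A short usage sketch (hypothetical values; the shapes follow from the docstrings above):

import mxnet as mx

# Scalar parameters: one distribution, `shape` samples (dispatches to _random_normal).
s = mx.nd.random.normal(loc=0, scale=1, shape=(2, 3))
print(s.shape)  # (2, 3)

# NDArray parameters: one distribution per entry (dispatches to _sample_normal).
loc = mx.nd.array([0., 10.])
scale = mx.nd.array([1., 2.])
m = mx.nd.random.normal(loc, scale, shape=3)
print(m.shape)  # (2, 3): 3 samples drawn per (loc, scale) pair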