
[MXNET-424] dtype option for multinomial #10970

Merged
merged 6 commits into apache:master on Jun 1, 2018

Conversation

asitstands
Contributor

Description

This PR adds a dtype option to set the data type of the sample output array of random.multinomial, which is fixed to 'int32' in the current implementation. The default value is 'int32'.
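
A minimal usage sketch of the new option (values are hypothetical; this assumes the Python entry point mx.nd.random.multinomial, whose exact module path may differ across MXNet versions):

import mxnet as mx

probs = mx.nd.array([0.1, 0.2, 0.3, 0.4])
samples_int = mx.nd.random.multinomial(probs, shape=5)                   # int32, the previously fixed dtype
samples_f32 = mx.nd.random.multinomial(probs, shape=5, dtype='float32')  # new: choose the output dtype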

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with [MXNET-$JIRA_ID], where $JIRA_ID refers to the relevant JIRA issue created (except PRs with tiny changes)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain what the example does, the source of the dataset, expected performance on the test set, and a reference to the original paper if applicable
  • Check the API doc at http://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • To the best of my knowledge, examples are either not affected by this change or have been fixed to be compatible with this change

@asitstands asitstands requested a review from szha as a code owner May 16, 2018 09:02
Kernel<SampleMultinomialKernel, xpu>::Launch(
  s, N, K, M, inputs[0].dptr<DType>(), uniform.dptr_, outputs[0].dptr<int>(),
  param.get_prob ? outputs[1].dptr<DType>() : nullptr);
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, IType, {
Contributor
This kind of two-layer switch is very slow to compile. Why do we need type support for the output?

Contributor Author
@asitstands asitstands May 18, 2018

Sometimes the multinomial samples need further processing in floating-point arithmetic, so the samples have to be copied into a new array of a floating-point type. The copy slows down training. For example, in an RBM, the samples are fed to linalg.gemm, which supports only floating-point arrays.
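
A hedged sketch of the cast this avoids (names and shapes are illustrative, not taken from the PR):

import mxnet as mx

p = mx.nd.array([0.5, 0.5])
# Without the dtype option: int32 samples must be cast, which allocates and copies a new array.
s = mx.nd.random.multinomial(p, shape=1000).astype('float32')
# With the dtype option: sample directly as float32, so the extra copy disappears.
s = mx.nd.random.multinomial(p, shape=1000, dtype='float32')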

Contributor

A simple cast shouldn't cost that much?

This kind of nested switch is really slow to compile and makes the binary much bigger. We need to make sure it really justifies the cost.

Contributor Author

The binary size increases by about 0.1% for both the shared and the static library (CUDA, cuDNN, MKL). Compiling MXNet already takes quite a long time, so the relative increase in compile time is also tiny.

I'm working with some variants of RBM, and the use of .astype('float32') in several places increases training time by over 20%. For a basic RBM, it adds about 10% to training time in my MNIST test. Of course, this depends on the hyperparameters and the data, but I think that, in general, the cost cannot be ignored for applications using heavy Monte Carlo sampling of discrete states.

@szha szha removed their request for review May 21, 2018 21:34
@asmushetzel
Contributor

This change would make things a bit more consistent with the other samplers. For all the rest (uniform, gamma, etc.), we consistently use a floating-point type as the return value (32-bit by default), even though some of them (Poisson, negative binomial) are distributions on integer values. And in the cases I have seen where these samplers get used in practice, the users did in fact need floating-point data for further processing.
We can hardly change the default type of the multinomial anymore, but I think we should add floating point as a result type.
The nested compile switches are used a lot in MXNet; I'm not sure whether they are an issue when used here as well.
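
For illustration, a quick check of those defaults (a sketch, assuming the mx.nd.random module of MXNet 1.x):

import mxnet as mx
print(mx.nd.random.uniform(shape=(2,)).dtype)          # numpy.float32 by default
print(mx.nd.random.poisson(lam=4, shape=(2,)).dtype)   # also float32, though Poisson is integer-valued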

@piiswrong
Contributor

OK, we can add this, but using float to represent int is only accurate within certain ranges. Please add checks on the input dimensions for the various types.

@asitstands
Contributor Author

asitstands commented May 30, 2018

I added the check to SampleMultinomialOpShape. It ensures that the size of the last dimension of the input array is less than or equal to 2 << (the number of mantissa bits of dtype - 1) for floating-point types, or std::numeric_limits<DType>::max() for integer types. A test for this is also added.
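
The precision rationale, as a small NumPy check: float32 has 24 significand bits, so consecutive integers are exactly representable only up to 2**24, and category indices beyond that would collide.

import numpy as np
assert np.float32(2**24) == 2**24          # 16777216 is exactly representable
assert np.float32(2**24 + 1) == 2**24      # 16777217 rounds back down to 16777216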

@@ -67,6 +70,10 @@ inline bool SampleMultinomialOpShape(const nnvm::NodeAttrs& attrs,
const TShape& ishape = (*in_attrs)[0];
if (!ishape.ndim()) return false;

MSHADOW_TYPE_SWITCH(param.dtype, DType, {
  CHECK_LE(ishape[ishape.ndim() - 1], mxnet::common::MaxIntegerValue<DType>());
Contributor

Need to output a message saying why it failed.

Contributor Author

I added an error message.

@piiswrong
Contributor

otherwise LGTM

@piiswrong piiswrong merged commit a27b52e into apache:master Jun 1, 2018
@asitstands asitstands mentioned this pull request Jun 27, 2018
5 tasks
zheng-da pushed a commit to zheng-da/incubator-mxnet that referenced this pull request Jun 28, 2018
* dtype option for multinomial

* Add missing test for uint8

* Add check to ensure dtype has a sufficient precision.

* Fix lint

* Error message for the dtype precision check

* Retrigger CI
XinYao1994 pushed a commit to XinYao1994/incubator-mxnet that referenced this pull request Aug 29, 2018 (same commit messages as above)