[MXNET-424] dtype option for multinomial #10970
Conversation
Kernel<SampleMultinomialKernel, xpu>::Launch(
    s, N, K, M, inputs[0].dptr<DType>(), uniform.dptr_, outputs[0].dptr<int>(),
    param.get_prob ? outputs[1].dptr<DType>() : nullptr);
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, IType, {
This kind of two-layer type switch is very slow to compile. Why do we need type support for the output?
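The compile-time cost comes from combinatorics: each `MSHADOW_TYPE_SWITCH` expands its body once per supported dtype, so nesting two switches instantiates the kernel once for every (DType, IType) pair. A back-of-the-envelope sketch (the dtype lists here are illustrative, not mshadow's exact supported set):

```python
# Each MSHADOW_TYPE_SWITCH expands its body once per supported dtype,
# so two nested switches multiply the number of kernel instantiations.
outer_dtypes = ["float16", "float32", "float64"]        # input DType
inner_dtypes = ["float16", "float32", "float64",
                "uint8", "int32", "int64"]              # output IType

single_switch = len(outer_dtypes)                        # 3 instantiations
nested_switches = len(outer_dtypes) * len(inner_dtypes)  # 18 instantiations
print(single_switch, nested_switches)
```

Every extra instantiation is another compiled copy of the kernel, which is why nested switches grow both build time and binary size.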
Sometimes the multinomial samples need further processing in floating point arithmetic, so they must be copied into a new array of floating point type, and that copy slows down training. For example, in an RBM the samples are fed to linalg.gemm, which supports only floating point arrays.
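The overhead can be sketched with NumPy standing in for the MXNet arrays (names and shapes here are illustrative): integer samples must be copied into a new floating point buffer before any gemm-style call can accept them.

```python
import numpy as np

rng = np.random.default_rng(0)

# Categorical samples come back with an integer dtype (fixed to int32 before this PR).
probs = np.full(6, 1.0 / 6.0)
samples = rng.choice(6, size=100_000, p=probs).astype(np.int32)

# gemm-style routines accept only floating point inputs, so every training
# step pays for a full O(n) copy just to change the dtype:
samples_f = samples.astype(np.float32)

weights = rng.standard_normal(100_000).astype(np.float32)
dot = weights @ samples_f  # the actual floating point computation
```

With a `dtype` option on the sampler, `samples` could be produced as `float32` directly and the `astype` copy disappears from the training loop.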
A simple cast shouldn't cost that much?
Nested switches like this are really slow to compile and make the binary much bigger. We need to make sure it really justifies the cost.
The binary size increases by about 0.1% for both the shared and the static library (CUDA, cuDNN, MKL). Compiling MXNet already takes quite a long time, so the relative increase in compile time is also tiny.

I'm working with some variants of RBM, and the use of .astype('float32') in several places increases training time by over 20%. For a basic RBM it adds about 10% to training time in my MNIST test. Of course this depends on the hyperparameters and data, but I think that in general the cost cannot be ignored for applications using heavy Monte Carlo sampling of discrete states.
This change would also make things a bit more consistent with the other samplers. For all the rest (uniform, gamma, etc.) we consistently use a floating point return type (32-bit by default), even though some of them (Poisson, negative binomial) are distributions on integer values. And in the cases I have seen where these samplers are used in practice, users in fact needed floating point data for further processing.
OK, we can add this, but using float to represent int is only accurate within certain ranges. Please add checks on the input dimensions for the various types.
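The range limitation is easy to demonstrate: a float32 has a 24-bit significand, so integers above 2**24 are no longer all exactly representable, and for float16 the limit is already 2**11:

```python
import numpy as np

# float32: every integer in [0, 2**24] is exact; 2**24 + 1 is not.
assert np.float32(2**24 - 1) + np.float32(1) == np.float32(2**24)
assert np.float32(2**24 + 1) == np.float32(2**24)  # 16777217 rounds to 16777216

# float16: an 11-bit significand, so the limit is 2**11 = 2048.
assert np.float16(2**11 + 1) == np.float16(2**11)  # 2049 rounds to 2048
```

This is why the sampler has to verify that the number of categories (the last input dimension) fits inside the requested dtype's exact-integer range.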
I added the check:
@@ -67,6 +70,10 @@ inline bool SampleMultinomialOpShape(const nnvm::NodeAttrs& attrs,
   const TShape& ishape = (*in_attrs)[0];
   if (!ishape.ndim()) return false;

   MSHADOW_TYPE_SWITCH(param.dtype, DType, {
     CHECK_LE(ishape[ishape.ndim() - 1], mxnet::common::MaxIntegerValue<DType>());
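For reference, the bound supplied by `MaxIntegerValue<DType>()` is the largest value below which every integer is exactly representable in `DType`. A Python sketch of the same table (illustrative, not MXNet's actual implementation):

```python
import numpy as np

def max_integer_value(dtype) -> int:
    """Largest n such that every integer in [0, n] is exact in `dtype` (illustrative)."""
    dtype = np.dtype(dtype)
    if dtype.kind in "iu":  # integer dtypes represent their whole range exactly
        return int(np.iinfo(dtype).max)
    # Floating point: exact integers run up to 2**(significand bits).
    return 1 << (np.finfo(dtype).nmant + 1)

assert max_integer_value(np.float16) == 2**11   # 2048
assert max_integer_value(np.float32) == 2**24   # 16777216
assert max_integer_value(np.float64) == 2**53
```

The shape check above then amounts to: the last input dimension (the number of categories) must not exceed `max_integer_value(dtype)`.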
Need to output a message saying why it failed.
I added an error message.
otherwise LGTM
* dtype option for multinomial
* Add missing test for uint8
* Add check to ensure dtype has a sufficient precision.
* Fix lint
* Error message for the dtype precision check
* Retrigger CI
Description

This PR adds a `dtype` option to set the data type of the sample output array of `random.multinomial`, which is fixed as 'int32' in the current implementation. The default value is `int32`.

Checklist
Essentials
Please feel free to remove inapplicable items for your PR.