Significant speedup for large datasets:

    In [2]: %timeit current_sample(1529*8192)
    12.3 s ± 721 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

    In [3]: %timeit np_sample(1529*8192)
    641 ms ± 6.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
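For context, a hedged sketch of the two functions being timed (the names come from the benchmark above; the bodies below are a reconstruction, not the PR's actual code):

    import random
    import numpy as np

    # Plausible old behavior: shuffle a Python list, paying per-element
    # interpreter overhead on every swap.
    def current_sample(n):
        indices = list(range(n))
        random.shuffle(indices)
        return iter(indices)

    # The numpy-based replacement: vectorized in-place shuffle in C.
    def np_sample(n):
        indices = np.arange(n)
        np.random.shuffle(indices)
        return iter(indices)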
Have you tried mx ndarray shuffle?
It doesn't perform well. At least not with the naive approach of:
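Presumably something along these lines (the original snippet is not shown, so this is a hedged reconstruction):

    import mxnet as mx

    # Hypothetical naive approach: shuffle on the mxnet side, then pull each
    # element back to Python one at a time.
    def __iter__(self):
        indices = mx.nd.random.shuffle(mx.nd.arange(self._length))
        return (int(i.asscalar()) for i in indices)  # per-element copies dominate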
What did you have in mind?
Would you try this and provide the timings?

    def __iter__(self):
        indices = mx.nd.arange(self._length, dtype='int32').reshape((1, self._length))  # may look weird but anyway
        mx.nd.random.shuffle(indices, out=indices)
        return iter(indices[0].asnumpy())

-------- EDIT --------

Please ignore the above. The reshaping was my mistake. It should be this. Note that the performance of shuffle is affected by the OMP_NUM_THREADS environment variable: leave OMP_NUM_THREADS unset, or set it to the number of physical cores.

    def __iter__(self):
        indices = mx.nd.arange(self._length, dtype='int32')
        mx.nd.random.shuffle(indices, out=indices)
        return iter(indices.asnumpy())
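For reference, a minimal self-contained sampler built around that method (the class scaffolding is an assumption; Gluon's actual RandomSampler may differ):

    import mxnet as mx

    class ShuffleSampler:
        def __init__(self, length):
            self._length = length

        def __iter__(self):
            indices = mx.nd.arange(self._length, dtype='int32')
            mx.nd.random.shuffle(indices, out=indices)  # in-place shuffle on the mxnet side
            return iter(indices.asnumpy())              # one bulk copy, then cheap numpy iteration

        def __len__(self):
            return self._length

    # Usage: yields each index in [0, length) exactly once, in random order.
    for idx in ShuffleSampler(10):
        print(idx)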
I guess this is good enough, since we want scalars in the end anyway.
If the performance penalty is not too great, I think using mxnet's shuffle would be preferable. Depending on an external global RNG is not a good idea.
Thanks @asitstands. Comparing the numpy code to the mx.nd code you provided gives the following performance on my machine:

So relying on mx.nd.random.shuffle + asnumpy seems to add an extra second. Regarding RNG, our test cases set both the numpy and mxnet seeds. I believe other parts of mxnet also use numpy's random, so it may be good to document that both seeds must be set to get deterministic behavior. If this is the only place numpy.random is used, it may be worth the extra second to stay consistent?
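Documenting that determinism would boil down to seeding both global RNGs, e.g. (seed value illustrative):

    import mxnet as mx
    import numpy as np

    # mxnet operators draw from mxnet's RNG, while this sampler (and a few
    # other spots) draw from numpy's global RNG, so both seeds must be set.
    mx.random.seed(42)
    np.random.seed(42)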
Numpy's and python's global RNGs are used here and there in mxnet. In my opinion they should all be removed someday :) It annoys people who work with subtle probabilistic reasoning. Anyway, it looks like using numpy is currently the best option. Thanks for the quick test.
@leezu If possible, could I ask for one more test? In my experiments, mxnet's shuffle outperforms numpy's when the array size is large. Would you please test with somewhat larger arrays?
@asitstands the above timings were taken with an array size of 12525568. I have tried again with size 40000 and get the following:
Are you taking the overhead of converting to scalars into account?
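To make that overhead concrete, a hedged micro-benchmark (a sketch, not from the thread): pulling elements out of an NDArray one at a time triggers a synchronizing device-to-host copy per element, whereas asnumpy() does a single bulk copy.

    import time
    import mxnet as mx

    x = mx.nd.arange(40000)

    # Per-element conversion: one synchronizing copy per scalar.
    start = time.time()
    scalars = [x[i].asscalar() for i in range(len(x))]
    print("per-element asscalar:", time.time() - start)

    # Bulk conversion: a single copy, then cheap numpy iteration.
    start = time.time()
    scalars = list(x.asnumpy())
    print("bulk asnumpy:", time.time() - start)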
Thanks @leezu. I hope this discussion is not bothering you too much. Here is my test code.

    import time
    import mxnet as mx
    import numpy as np

    n = 40000

    # Time mxnet's in-place shuffle followed by conversion to a numpy iterator.
    start = time.time()
    for i in range(10000):
        x = mx.nd.arange(n)
        mx.random.shuffle(x, out=x)
        y = iter(x.asnumpy())
    end = time.time()
    print("mx elapsed time: ", end - start)

    # Time numpy's in-place shuffle for comparison.
    start = time.time()
    for i in range(10000):
        x = np.arange(n)
        np.random.shuffle(x)
        y = iter(x)
    end = time.time()
    print("np elapsed time: ", end - start)

On an i7-3770K 3.50GHz, the result is
On two Xeon(R) E5-2680 v4 2.40GHz CPUs, the result is
As I increase
@asitstands I guess the difference between our experiments is that I used an optimized numpy from conda and the standard mxnet pypi build. Using both an optimized numpy and an optimized mxnet build on an AWS p3 instance, I do observe, like you, that mxnet is faster for small sizes (40000): ~500μs vs numpy's ~800μs. For large sizes (12525568) the
I think that conda has no special optimization for numpy's shuffle. Numpy's shuffle uses
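For reference, numpy's shuffle is an in-place Fisher-Yates pass implemented in C, with no BLAS involvement; a pure-Python sketch of the algorithm (illustrative only, not numpy's actual code):

    import numpy as np

    def fisher_yates(arr, rng=np.random):
        # Walk backwards, swapping each element with a uniformly chosen
        # position at or before it. O(n) swaps, one random draw per element.
        for i in range(len(arr) - 1, 0, -1):
            j = rng.randint(i + 1)  # uniform over [0, i]
            arr[i], arr[j] = arr[j], arr[i]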
On my personal computer I do indeed see the same speed-up of mxnet over numpy. On the other machines, the results I quoted above still stand. I guess in the end this depends a lot on the particular system and the build options of the libraries, though that is strange given your explanation of the implementation. As this code is only run once per epoch to shuffle the dataset, I believe it is not that important whether it takes 200ms or 500ms for large datasets; it was just unbearable that it took 10s+ before. I don't have a strong feeling about changing it, though I won't propose such a change myself given that I had mixed results depending on the computer. If you open a PR and someone is willing to merge it, I won't mind.
I'll test in some other environments, including AWS, and make a PR if I can confirm that the performance hit is not typical.
Sounds great, thanks!