That's actually expected, though we should have a proper error message for it. The candle sort operator uses a bitonic sort, which requires all the data to fit in a single thread-group/CUDA block (the same approach is used by llama.cpp). The idea is to use this operator for things like mixture of experts, where the number of elements to sort is very small, but it cannot handle larger sets of elements.
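For context, a minimal sketch of what such a single-block bitonic sort looks like in CUDA (illustrative only, not candle's actual kernel; `n` is assumed to be a power of two launched with one block of `n` threads):

```cuda
#include <cuda_runtime.h>

// Single-block bitonic sort sketch (ascending). The whole row must fit in
// this block's shared memory, which is exactly the limitation discussed above.
__global__ void bitonic_sort_row(float* data, int n) {
    extern __shared__ float shm[];
    int tid = threadIdx.x;

    // Load the entire row into shared memory, one thread per element.
    if (tid < n) shm[tid] = data[tid];
    __syncthreads();

    // Classic bitonic network: log2(n) phases, each with log2(k) steps.
    for (int k = 2; k <= n; k <<= 1) {
        for (int j = k >> 1; j > 0; j >>= 1) {
            int partner = tid ^ j;
            // Only the lower-indexed thread of each pair does the exchange.
            if (tid < n && partner > tid) {
                bool ascending = (tid & k) == 0;
                if ((shm[tid] > shm[partner]) == ascending) {
                    float tmp = shm[tid];
                    shm[tid] = shm[partner];
                    shm[partner] = tmp;
                }
            }
            __syncthreads();
        }
    }

    if (tid < n) data[tid] = shm[tid];
}

// Launch: bitonic_sort_row<<<1, n, n * sizeof(float)>>>(d_data, n);
```

Every compare-exchange step synchronizes the whole block, which is why the row cannot span multiple blocks.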
Yes, I realized it's a bitonic sort once I went through the code; I just didn't realize the limitation was by design.
A generic implementation would be helpful (in my case, for speeding up token sampling in autoregressive language models), so I did some digging around this.
Torch delegates CUDA sort to Thrust; the current versions of Thrust and CUB live in NVIDIA/cccl. cccl is not supported by cudarc yet, and my low-key attempts at generating bindings with bindgen were a spectacular failure.
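For reference, the Thrust path amounts to something like this (a sketch, not Torch's actual code; an argsort expressed as `sort_by_key` over the original indices):

```cuda
#include <thrust/device_vector.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>

// Device-wide argsort via Thrust: sort the keys and carry the original
// indices along as values. No single-block size restriction.
thrust::device_vector<int> thrust_argsort(thrust::device_vector<float>& keys) {
    thrust::device_vector<int> indices(keys.size());
    thrust::sequence(indices.begin(), indices.end());  // 0, 1, 2, ...
    thrust::sort_by_key(keys.begin(), keys.end(), indices.begin());
    return indices;
}
```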
And on the Metal side, from what I could gather, Torch relies on MPSGraph.argsort() to do the sorting. Yet again, MPSGraph is not yet part of metal-rs.
According to this implementation, cub uses a radix sort under the hood.
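Concretely, the device-wide call follows CUB's documented two-phase pattern for `cub::DeviceRadixSort::SortPairs` (the wrapper function and buffer names here are mine):

```cuda
#include <cub/cub.cuh>
#include <cuda_runtime.h>

// Device-wide argsort via cub::DeviceRadixSort::SortPairs: sort float keys
// and carry the original indices along as the value payload.
void argsort_with_cub(const float* d_keys_in, float* d_keys_out,
                      const int* d_idx_in, int* d_idx_out, int num_items) {
    void* d_temp_storage = nullptr;
    size_t temp_storage_bytes = 0;
    // First call with null storage only computes temp_storage_bytes.
    cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes,
                                    d_keys_in, d_keys_out,
                                    d_idx_in, d_idx_out, num_items);
    cudaMalloc(&d_temp_storage, temp_storage_bytes);
    // Second call performs the actual sort.
    cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes,
                                    d_keys_in, d_keys_out,
                                    d_idx_in, d_idx_out, num_items);
    cudaFree(d_temp_storage);
}
```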
I'm working on an implementation, and if things go well and the Metal port works, I'll probably open a PR that calls the bitonic sort kernel when ncols_pad < MaxThreadsPerGroup and falls back to a DeviceRadixSort kernel otherwise, roughly like the sketch below.
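A rough outline of that dispatch, not the final PR; `bitonic_sort_row` stands in for the existing kernel and `device_radix_sort` for a hypothetical CUB-backed path:

```cuda
#include <cuda_runtime.h>

// Stand-ins for the existing bitonic kernel and a new device-wide path.
__global__ void bitonic_sort_row(float* data, int n);
void device_radix_sort(float* d_row, int n);

// Proposed dispatch: keep the single-block bitonic fast path for small rows,
// fall back to a device-wide radix sort for everything else.
void launch_sort_last_dim(float* d_row, int ncols_pad, const cudaDeviceProp& prop) {
    if (ncols_pad <= prop.maxThreadsPerBlock) {
        // Whole padded row fits in one block: existing bitonic kernel.
        bitonic_sort_row<<<1, ncols_pad, ncols_pad * sizeof(float)>>>(d_row, ncols_pad);
    } else {
        // Too large for one block: device-wide radix sort has no such limit.
        device_radix_sort(d_row, ncols_pad);
    }
}
```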
Reproduction:
Edit: removed incorrect diagnosis.