Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Metal implementation for all k_quants (ggerganov#1807)
* metal : improve q4_K 28.3 -> 26.0 ms/token by avoiding a branch in the calculation of the scales. * metal : small improvement for Q4_K * metal : still optimizing Q4_K This commit pushes it down to 25.3 ms / token. The crazy idea of using 6 bits for the scales is really costly on Metal: if I remove the bit fiddling necessary to make the block scales, time goes almost to the Q4_0 23 ms/token. Before pushing the k-quants upstream I had a Q4_K variant that had used 8-bit scales. It wasn't more accurate, used 0.125 bits more per weight, was running slightly slower on the CPU (due to the larger model size and being memory bound there), and the difference was entirely negligible under CUDA. So, I decided to publish the version with 6-bit scales. Perhaps I should re-consider and change to 8-bit scales? * metal : some more optimizations Q2_K: 25.4 ms/token Q6_K: 27.3 ms/token Q4_0: 22.8 ms/token Q4_1: 23.1 ms/token * metal : Q3_K support Something is not quite right yet. * metal : Q5_K support Initial version achieves 31.2 ms/token, 210 GB/s * metal : still not able to figure out why q3_K does not work * Minor * metal : yet another failed attempt to make q3_K work * metal : optimize Q5_K 31.2 ms -> 27.8 ms. 250 GB/s. * metal : q3_K still not working Adding a heavily commented q3_K metal kernel to explain my obviously faulty logic. Perhaps someone could spot the issue? * metal : q3_K finally working Not optimized at all. What was the issue? The scales are not 4-bytes aligned, and I was accessing them with a uint32_t pointer. When I tried that on CUDA, I got an error (illegal memory access) and added a memcpy to a local array of 3 uint32_t's. But on Metal it told me there is no memcpy, so I tried accessing directly. There is no error, just garbage results. At some point I did try accessing the scales with an uint16_t pointer (the scales are for sure 2-byte aligned), but was still getting garbage. I guess, there must have been another bug. No access to scales is via a uint16_t pointer and, after starting from scratch from the C dequantize function, it finally works. * metal : Q3_K 1st optimization pass * metal : Q3_K second optimization pass - 29.6 ms/token * metal : Q3_K cleanup * metal : fixed accidentally broken Q2_K --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
- Loading branch information