CUDA: refactor mmq, dmmv, mmvq #7716
Conversation
The compilation time increases due to the additional template instances, but I think the increase is acceptable.
This PR seems to increase the throughput on the server, but not by much. Hardware is 1x RTX 4090.
```cpp
static constexpr __device__ dequantize_kernel_t get_dequantize_kernel(ggml_type type) {
    return type == GGML_TYPE_Q4_0 ? dequantize_q4_0 :
        type == GGML_TYPE_Q4_1 ? dequantize_q4_1 :
        type == GGML_TYPE_Q5_0 ? dequantize_q5_0 :
        type == GGML_TYPE_Q5_1 ? dequantize_q5_1 :
        type == GGML_TYPE_Q8_0 ? dequantize_q8_0 :
        type == GGML_TYPE_F16  ? convert_f16     :
        nullptr;
}
```
I think this could also be moved to `ggml_cuda_type_traits`.
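For illustration, a rough sketch of what that move could look like, assuming the declarations from `common.cuh` and `dequantize.cuh` are visible; the `dequantize_kernel` member is hypothetical and not part of the actual code:

```cpp
// Sketch only: select the dequantization function via the per-type traits
// struct instead of the constexpr switch in get_dequantize_kernel().
template <ggml_type type>
struct ggml_cuda_type_traits;

template <>
struct ggml_cuda_type_traits<GGML_TYPE_Q4_0> {
    static constexpr int qk = 32; // QK4_0
    static constexpr int qr = 2;  // QR4_0
    // hypothetical member: initializing it here makes common.cuh reference
    // dequantize_q4_0, which lives in dequantize.cuh
    static constexpr dequantize_kernel_t dequantize_kernel = dequantize_q4_0;
};

// usage inside a kernel template parameterized on ggml_type:
//     constexpr dequantize_kernel_t deq = ggml_cuda_type_traits<type>::dequantize_kernel;
```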
Are circular dependencies between `common.cuh` and `dequantize.cuh` okay?
No, but there is probably too much in `common.cuh`.
So should we for now just keep this as-is?
Up to you, it's just a suggestion.
ggml-cuda/mmq.cuh (Outdated)
```cpp
// -------------------------------------------------------------------------------------------------------------------

static constexpr __device__ int get_need_sum(ggml_type type) {
```
Some of these could potentially be moved too.
These things are MMQ-specific and I would prefer to keep them together. I would have made another template struct if I were aware of a simple way to do this that also includes the functions for loading tiles and doing the vector dot products. But I don't know how to do that in a way that still lets me pass the template arguments correctly without resorting to preprocessor macros.
It is possible to partially specialize a template, e.g.:

```cpp
template <int x, int y, ggml_type type>
struct mmq_type_traits;

template <int x, int y>
struct mmq_type_traits<x, y, GGML_TYPE_F16> {
    static constexpr int qk = 1;
    static constexpr int qr = 1;
    static constexpr int z  = x + 2;
};
```
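Building on that, a minimal sketch of how the tile-loading and vec-dot functions could be carried as static members of such a partially specialized struct; the `_sketch` function names, the typedef signatures, and the member names are placeholders rather than the actual MMQ symbols:

```cpp
// Placeholder function-pointer types; only their shape matters for the sketch.
typedef void (*load_tiles_mmq_t)(const char * x, int * x_tile, int kbx0);
typedef void (*vec_dot_mmq_t)(const int * x, const int * y, float * sum, int k0);

// Stand-ins for the real tile-loading and vec-dot device function templates.
template <int mmq_y, int nwarps, bool need_check>
static __device__ void load_tiles_q8_0_sketch(const char * x, int * x_tile, int kbx0);

template <int mmq_x, int mmq_y, int nwarps>
static __device__ void vec_dot_q8_0_q8_1_sketch(const int * x, const int * y, float * sum, int k0);

template <int mmq_x, int mmq_y, int nwarps, bool need_check, ggml_type type>
struct mmq_type_traits;

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
    static constexpr int qk = 32; // QK8_0
    static constexpr int qr = 1;  // QR8_0

    // The outer template arguments are forwarded into the function templates
    // here, so no preprocessor macros are needed to pick the right instance.
    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0_sketch<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q8_0_q8_1_sketch<mmq_x, mmq_y, nwarps>;
};
```

The kernel would then fetch everything type-specific from `mmq_type_traits<...>`, with the `ggml_type` as the only per-quantization template argument.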
This PR refactors the `mul_mat_q` and to a lesser extent the `dequantize_mul_mat_vec` and `mul_mat_vec_q` kernels. The intent is to simplify the code in preparation for #7676. List of changes:

- Pass the `ggml_type` of the buffer to the CUDA kernel and then use `constexpr __device__` functions to fetch the corresponding arguments at compile time. This simplifies the use of the template and ensures that arguments are passed consistently (without e.g. accidentally not changing one of them when copy-pasting). Passing only a `ggml_type` for a template argument also solves a lot of annoyances to do with device functions in host code; I think it's greatly preferable to separate the two as much as possible (see the sketch after this list).
- The MMQ tile sizes are now chosen based on the `ggml_type`. The shared memory is then allocated dynamically and the pointers to shared memory can be set generically.
- Moved parts of `mul_mat_q` into the functions for loading tiles and calculating dot products. This is so a given `__CUDA_ARCH__` can more easily use tensor cores without having to touch the rest of the code.
- Use `extern` templates for faster multi-threaded compilation.
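As a rough illustration of the first item in this list, a sketch of the pattern; `get_qk`, `get_qr`, and `example_dmmv_kernel` are placeholder names rather than the functions actually added by this PR:

```cpp
// Illustrative only: fetch per-type constants at compile time from a single
// ggml_type template argument instead of passing them individually.
static constexpr __device__ int get_qk(ggml_type type) {
    return type == GGML_TYPE_Q4_0 ? 32 :
        type == GGML_TYPE_Q8_0 ? 32 :
        type == GGML_TYPE_F16  ?  1 :
        -1;
}

static constexpr __device__ int get_qr(ggml_type type) {
    return type == GGML_TYPE_Q4_0 ? 2 :
        type == GGML_TYPE_Q8_0 ? 1 :
        type == GGML_TYPE_F16  ? 1 :
        -1;
}

// The kernel takes a single ggml_type template argument instead of qk, qr,
// and a dequantization function passed separately:
template <ggml_type type>
static __global__ void example_dmmv_kernel(const void * __restrict__ vx, float * __restrict__ dst, const int ncols) {
    constexpr int qk = get_qk(type); // resolved at compile time, no runtime branch
    constexpr int qr = get_qr(type);

    const int row = blockIdx.x*blockDim.y + threadIdx.y;
    // ... dequantization and dot product using qk/qr would go here ...
    (void) vx; (void) dst; (void) ncols; (void) row; (void) qk; (void) qr;
}

// Host code instantiates the kernel per type, e.g.
//     example_dmmv_kernel<GGML_TYPE_Q4_0><<<grid, block, 0, stream>>>(vx, dst, ncols);
// and extern template declarations on the per-type launch functions can split
// the instantiations across compilation units for faster parallel builds.
```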
The performance of MMQ changes due to the dynamic tile sizes. The tile sizes on master were tuned for the same GPUs on LLaMA 2 7b; the following numbers are for LLaMA 3 8b. For context, the optimal tile sizes for a matrix multiplication depend strongly on the matrix shapes and the number of streaming multiprocessors on a GPU. So the numbers on master are likely "overfit" to these specific model and GPU combinations, and I think the relative performance change on GPUs that I do not have access to will be better than what I report in this PR.
On my GPUs the performance for small batch sizes becomes significantly better. The performance for large batch sizes on NVIDIA GPUs stays essentially constant on average; the performance of my RX 6800 gets worse for large batch sizes for some quants. I don't understand why this is happening, and I cannot do a git bisect either because my WIP commits only work correctly for q8_0, where there is no performance regression. My fundamental stance is that I am fine with supporting AMD via HIP as long as it's not too much effort, and in this case trying to figure out the exact problem is too much effort, so I will not do it.
Specific numbers:

- RTX 3090, no LLAMA_CUDA_FORCE_MMQ
- RTX 4090, no LLAMA_CUDA_FORCE_MMQ
- RTX 3090, LLAMA_CUDA_FORCE_MMQ
- RTX 4090, LLAMA_CUDA_FORCE_MMQ
- P40
- RX 6800