Skip to content

Commit

Permalink
kompute: add mul_mat_q4_k shader (llama/10097)
Browse files Browse the repository at this point in the history
This is a more or less direct translation from the Metal implementation
to GLSL.

Signed-off-by: Sergio Lopez <slp@redhat.com>
  • Loading branch information
slp authored and ggerganov committed Nov 4, 2024
1 parent 9a92a87 commit 9d4d799
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -800,6 +800,7 @@ if (GGML_KOMPUTE)
kompute-shaders/op_mul_mat_q8_0.comp
kompute-shaders/op_mul_mat_q4_0.comp
kompute-shaders/op_mul_mat_q4_1.comp
kompute-shaders/op_mul_mat_q4_k.comp
kompute-shaders/op_mul_mat_q6_k.comp
kompute-shaders/op_getrows_f32.comp
kompute-shaders/op_getrows_f16.comp
Expand Down Expand Up @@ -833,6 +834,7 @@ if (GGML_KOMPUTE)
shaderop_mul_mat_q8_0.h
shaderop_mul_mat_q4_0.h
shaderop_mul_mat_q4_1.h
shaderop_mul_mat_q4_k.h
shaderop_mul_mat_q6_k.h
shaderop_getrows_f32.h
shaderop_getrows_f16.h
Expand Down
42 changes: 42 additions & 0 deletions src/ggml-kompute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "shaderop_mul_mat_q8_0.h"
#include "shaderop_mul_mat_q4_0.h"
#include "shaderop_mul_mat_q4_1.h"
#include "shaderop_mul_mat_q4_k.h"
#include "shaderop_mul_mat_q6_k.h"
#include "shaderop_mul_mat_mat_f32.h"
#include "shaderop_getrows_f32.h"
Expand Down Expand Up @@ -1067,6 +1068,40 @@ static void ggml_vk_mul_mat_q8_0(Args&&... args) {
ggml_vk_mul_mat_impl(spirv, "q8_0", 1/*We access blocks unaligned*/, std::forward<Args>(args)...);
}

static void ggml_vk_mul_mat_q4_k(
kp::Sequence& seq,
const std::shared_ptr<kp::Tensor>& inA,
const std::shared_ptr<kp::Tensor>& inB,
const std::shared_ptr<kp::Tensor>& out,
uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne10,
int32_t ne11, int32_t ne12, int32_t ne13, int32_t ne0,
int32_t ne1, int32_t r2, int32_t r3
) {
const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_k_comp_spv,
kp::shader_data::op_mul_mat_q4_k_comp_spv_len);

struct PushConstants {
uint32_t inAOff, inBOff, outOff;
int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12, r2, r3;
} pushConsts {
0, 0, 0,
ne00, ne10, ne0, ne1, ne01, ne02, ne12, r2, r3
};

std::shared_ptr<kp::Algorithm> s_algo = nullptr;
if (!komputeManager()->hasAlgorithm(__func__)) {
s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 3)/4), unsigned(ne11), unsigned(ne12) * unsigned(ne13)}, {}, {pushConsts});
} else {
s_algo = komputeManager()->getAlgorithm(__func__);
s_algo->setTensors({inA, inB, out});
s_algo->setWorkgroup({unsigned((ne01 + 3)/4), unsigned(ne11), unsigned(ne12) * unsigned(ne13)});
s_algo->setPushConstants<PushConstants>({pushConsts});
s_algo->updateDescriptors(s_kompute_context->pool.get());
}
seq.record<kp::OpAlgoDispatch>(s_algo);
}

static void ggml_vk_mul_mat_q6_k(
kp::Sequence& seq,
const std::shared_ptr<kp::Tensor>& inA,
Expand Down Expand Up @@ -1384,6 +1419,7 @@ static bool ggml_backend_kompute_device_supports_op(ggml_backend_dev_t dev, cons
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_K:
return true;
default:
;
Expand Down Expand Up @@ -1635,6 +1671,12 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
);
break;
case GGML_TYPE_Q4_K:
ggml_vk_mul_mat_q4_k(
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, ne12/ne02, ne13/ne03
);
break;
case GGML_TYPE_Q6_K:
ggml_vk_mul_mat_q6_k(
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
Expand Down

0 comments on commit 9d4d799

Please sign in to comment.