diff --git a/CMakeLists.txt b/CMakeLists.txt index cf37d5bb242ac..b1d6afbbcfa8d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -416,6 +416,8 @@ if (LLAMA_CUDA) list(APPEND GGML_SOURCES_CUDA "ggml-cuda.cu") file(GLOB SRCS "ggml-cuda/template-instances/fattn-wmma*.cu") list(APPEND GGML_SOURCES_CUDA ${SRCS}) + file(GLOB SRCS "ggml-cuda/template-instances/mmq*.cu") + list(APPEND GGML_SOURCES_CUDA ${SRCS}) add_compile_definitions(GGML_USE_CUDA) add_compile_definitions(GGML_CUDA_USE_GRAPHS) @@ -588,6 +590,8 @@ if (LLAMA_HIPBLAS) list(APPEND GGML_SOURCES_ROCM "ggml-cuda.cu") file(GLOB SRCS "ggml-cuda/template-instances/fattn-wmma*.cu") list(APPEND GGML_SOURCES_ROCM ${SRCS}) + file(GLOB SRCS "ggml-cuda/template-instances/mmq*.cu") + list(APPEND GGML_SOURCES_ROCM ${SRCS}) add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUDA) diff --git a/Makefile b/Makefile index 802ee6a47654c..895c62f84def0 100644 --- a/Makefile +++ b/Makefile @@ -444,6 +444,7 @@ ifdef LLAMA_CUBLAS endif OBJS_CUDA_TEMP_INST = $(patsubst %.cu,%.o,$(wildcard ggml-cuda/template-instances/fattn-wmma*.cu)) +OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/template-instances/mmq*.cu)) ifdef LLAMA_CUDA_FA_ALL_QUANTS OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/template-instances/fattn-vec*.cu)) else diff --git a/ggml-common.h b/ggml-common.h index 77e6bfba4b11b..e8efceb760d40 100644 --- a/ggml-common.h +++ b/ggml-common.h @@ -123,12 +123,18 @@ typedef sycl::half2 ggml_half2; #define QI1_S (QK_K / (4*QR1_S)) #define QR1_S 8 +#define QI1_M (QK_K / (4*QR1_M)) +#define QR1_M 8 + #define QI4_NL (QK4_NL / (4*QR4_NL)) #define QR4_NL 2 #define QI4_XS (QK_K / (4*QR4_XS)) #define QR4_XS 8 +#define QI3_S (QK_K / (4*QR3_S)) +#define QR3_S 8 + #endif // GGML_COMMON_DECL_CUDA || GGML_COMMON_DECL_HIP #define QK4_0 32 diff --git a/ggml-cuda.cu b/ggml-cuda.cu index c81c6a0d783be..dad8a9e2dafe7 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -633,88 +633,22 @@ GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) { // cuda split buffer -static int64_t get_row_rounding(ggml_type type, const std::array & tensor_split) { - int64_t min_compute_capability = INT_MAX; - int64_t max_compute_capability = INT_MIN; +static int64_t get_row_rounding(const std::array & tensor_split) { + int64_t row_rounding = 0; for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) { - if (tensor_split[id] < (id + 1 < ggml_backend_cuda_get_device_count() ? tensor_split[id + 1] : 1.0f)) { - if (min_compute_capability > ggml_cuda_info().devices[id].cc) { - min_compute_capability = ggml_cuda_info().devices[id].cc; - } - if (max_compute_capability < ggml_cuda_info().devices[id].cc) { - max_compute_capability = ggml_cuda_info().devices[id].cc; - } + if (tensor_split[id] >= (id + 1 < ggml_backend_cuda_get_device_count() ? tensor_split[id + 1] : 1.0f)) { + continue; } - } -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) - switch(type) { - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q5_1: - case GGML_TYPE_Q8_0: - return max_compute_capability >= CC_RDNA2 ? 128 : 64; - case GGML_TYPE_F16: - case GGML_TYPE_F32: - return 1; - case GGML_TYPE_Q2_K: - return max_compute_capability >= CC_RDNA2 ? 128 : 32; - case GGML_TYPE_Q3_K: - return min_compute_capability < CC_RDNA2 ? 128 : 64; - case GGML_TYPE_Q4_K: - case GGML_TYPE_Q5_K: - case GGML_TYPE_Q6_K: - case GGML_TYPE_IQ2_XXS: - case GGML_TYPE_IQ2_XS: - case GGML_TYPE_IQ2_S: - case GGML_TYPE_IQ3_XXS: - case GGML_TYPE_IQ1_S: - case GGML_TYPE_IQ1_M: - case GGML_TYPE_IQ4_NL: - case GGML_TYPE_IQ4_XS: - case GGML_TYPE_IQ3_S: - return max_compute_capability >= CC_RDNA2 ? 128 : 64; - default: - GGML_ASSERT(false); - } -#else - switch(type) { - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - return max_compute_capability >= CC_VOLTA ? 128 : 64; - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q5_1: - case GGML_TYPE_Q8_0: - return 64; - case GGML_TYPE_F16: - case GGML_TYPE_F32: - return 1; - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q3_K: - case GGML_TYPE_Q4_K: - case GGML_TYPE_Q5_K: - case GGML_TYPE_IQ2_XXS: - case GGML_TYPE_IQ2_XS: - case GGML_TYPE_IQ2_S: - case GGML_TYPE_IQ3_XXS: - case GGML_TYPE_IQ1_S: - case GGML_TYPE_IQ1_M: - case GGML_TYPE_IQ4_NL: - case GGML_TYPE_IQ4_XS: - case GGML_TYPE_IQ3_S: - return max_compute_capability >= CC_VOLTA ? 128 : 64; - case GGML_TYPE_Q6_K: - return 64; - default: - GGML_ASSERT(false); + const int cc = ggml_cuda_info().devices[id].cc; + row_rounding = std::max(row_rounding, (int64_t)get_mmq_y_host(cc, get_mmq_x_max_host(cc))); } -#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) + return row_rounding; } static void get_row_split(int64_t * row_low, int64_t * row_high, const ggml_tensor * tensor, const std::array & tensor_split, int id) { const int64_t nrows = ggml_nrows(tensor); - const int64_t rounding = get_row_rounding(tensor->type, tensor_split); + const int64_t rounding = get_row_rounding(tensor_split); *row_low = id == 0 ? 0 : nrows*tensor_split[id]; *row_low -= *row_low % rounding; @@ -1499,7 +1433,7 @@ static void ggml_cuda_op_mul_mat( // for multi GPU, get the row boundaries from tensor split // and round to mul_mat_q tile sizes if (split) { - const int64_t rounding = get_row_rounding(src0->type, tensor_split); + const int64_t rounding = get_row_rounding(tensor_split); if (id != 0) { dev[id].row_low = ne01*tensor_split[id]; diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index 22872ca5c1d81..90a0a81ead789 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -160,7 +160,7 @@ #endif #define MMVQ_MAX_BATCH_SIZE 8 // max batch size to use MMVQ kernels -#define MMQ_MAX_BATCH_SIZE 32 // max batch size to use MMQ kernels when tensor cores are available +#define MMQ_MAX_BATCH_SIZE 64 // max batch size to use MMQ kernels when tensor cores are available #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses @@ -484,6 +484,161 @@ static __device__ __forceinline__ float get_alibi_slope( return powf(base, exph); } +template +struct ggml_cuda_type_traits; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = 1; + static constexpr int qr = 1; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK4_0; + static constexpr int qr = QR4_0; + static constexpr int qi = QI4_0; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK4_1; + static constexpr int qr = QR4_1; + static constexpr int qi = QI4_1; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK5_0; + static constexpr int qr = QR5_0; + static constexpr int qi = QI5_0; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK5_1; + static constexpr int qr = QR5_1; + static constexpr int qi = QI5_1; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK8_0; + static constexpr int qr = QR8_0; + static constexpr int qi = QI8_0; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR2_K; + static constexpr int qi = QI2_K; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR3_K; + static constexpr int qi = QI3_K; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR4_K; + static constexpr int qi = QI4_K; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR5_K; + static constexpr int qi = QI5_K; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR6_K; + static constexpr int qi = QI6_K; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR2_XXS; + static constexpr int qi = QI2_XXS; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR2_XS; + static constexpr int qi = QI2_XS; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR2_S; + static constexpr int qi = QI2_S; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR3_XXS; + static constexpr int qi = QI3_XXS; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR1_S; + static constexpr int qi = QI1_S; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR1_M; + static constexpr int qi = QI1_M; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK4_NL; + static constexpr int qr = QR4_NL; + static constexpr int qi = QI4_NL; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR4_XS; + static constexpr int qi = QI4_XS; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR3_S; + static constexpr int qi = QI3_S; +}; + +static int get_mmq_x_max_host(const int cc) { +#ifdef CUDA_USE_TENSOR_CORES + return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_MAX_BATCH_SIZE : 64; +#else + return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? 128 : 64; +#endif // CUDA_USE_TENSOR_CORES +} + +// Round rows to this value for --split-mode row: +static int get_mmq_y_host(const int cc, const int mmq_x) { + return cc >= CC_VOLTA && mmq_x >= 32 ? 128 : 64; +} + ////////////////////// struct ggml_cuda_device_info { diff --git a/ggml-cuda/dmmv.cu b/ggml-cuda/dmmv.cu index 47d4d5d9e91da..174489e0665d3 100644 --- a/ggml-cuda/dmmv.cu +++ b/ggml-cuda/dmmv.cu @@ -422,10 +422,22 @@ static __device__ void convert_f16(const void * vx, const int64_t ib, const int v.y = x[ib + iqs + 1]; } -template +static constexpr __device__ dequantize_kernel_t get_dequantize_kernel(ggml_type type) { + return type == GGML_TYPE_Q4_0 ? dequantize_q4_0 : + type == GGML_TYPE_Q4_1 ? dequantize_q4_1 : + type == GGML_TYPE_Q5_0 ? dequantize_q5_0 : + type == GGML_TYPE_Q5_1 ? dequantize_q5_1 : + type == GGML_TYPE_Q8_0 ? dequantize_q8_0 : + type == GGML_TYPE_F16 ? convert_f16 : + nullptr; +} + +template static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) { - // qk = quantized weights per x block - // qr = number of quantized weights per data value in x block + constexpr int qk = ggml_cuda_type_traits::qk; // quantized weights per x block + constexpr int qr = ggml_cuda_type_traits::qr; // number of quantized weights per data value in x block + constexpr dequantize_kernel_t dequantize_kernel = get_dequantize_kernel(type); + const int64_t row = (int64_t)blockIdx.x*blockDim.y + threadIdx.y; if (row >= nrows) { @@ -493,7 +505,7 @@ static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - dequantize_mul_mat_vec + dequantize_mul_mat_vec <<>>(vx, y, dst, ncols, nrows); } @@ -502,7 +514,7 @@ static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - dequantize_mul_mat_vec + dequantize_mul_mat_vec <<>>(vx, y, dst, ncols, nrows); } @@ -511,7 +523,7 @@ static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - dequantize_mul_mat_vec + dequantize_mul_mat_vec <<>>(vx, y, dst, ncols, nrows); } @@ -520,7 +532,7 @@ static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - dequantize_mul_mat_vec + dequantize_mul_mat_vec <<>>(vx, y, dst, ncols, nrows); } @@ -529,7 +541,7 @@ static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - dequantize_mul_mat_vec + dequantize_mul_mat_vec <<>>(vx, y, dst, ncols, nrows); } @@ -580,7 +592,7 @@ static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, floa const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - dequantize_mul_mat_vec<1, 1, convert_f16> + dequantize_mul_mat_vec <<>>(vx, y, dst, ncols, nrows); } diff --git a/ggml-cuda/mmq.cu b/ggml-cuda/mmq.cu index ebe1dc5c8bb22..58799e4caf6f8 100644 --- a/ggml-cuda/mmq.cu +++ b/ggml-cuda/mmq.cu @@ -1,1450 +1,4 @@ #include "mmq.cuh" -#include "vecdotq.cuh" - -typedef void (*allocate_tiles_cuda_t)(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc); -typedef void (*load_tiles_cuda_t)( - const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row); -typedef float (*vec_dot_q_mul_mat_cuda_t)( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ms, const int & i, const int & j, const int & k); -typedef void (*dot_kernel_k_t)(const void * __restrict__ vx, const int ib, const int iqs, const float * __restrict__ y, float & v); -typedef void (mul_mat_q_t)( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst); - -struct mmq_arch_config_t { - int x; - int y; - int nwarps; -}; - -struct mmq_config_t { - mmq_arch_config_t rdna2; - mmq_arch_config_t rdna1; - mmq_arch_config_t ampere; - mmq_arch_config_t pascal; -}; - -constexpr mmq_config_t MMQ_CONFIG_Q4_0 = { -// x y nwarps - { 64, 128, 8}, - { 64, 64, 8}, -#ifdef CUDA_USE_TENSOR_CORES - { 4, 32, 4}, -#else - { 64, 128, 4}, -#endif // CUDA_USE_TENSOR_CORES - { 64, 64, 8}, -}; -constexpr mmq_config_t MMQ_CONFIG_Q4_1 = { -// x y nwarps - { 64, 128, 8}, - { 64, 64, 8}, -#ifdef CUDA_USE_TENSOR_CORES - { 4, 32, 4}, -#else - { 64, 128, 4}, -#endif // CUDA_USE_TENSOR_CORES - { 64, 64, 8}, -}; -constexpr mmq_config_t MMQ_CONFIG_Q5_0 = { -// x y nwarps - { 64, 128, 8}, - { 64, 64, 8}, -#ifdef CUDA_USE_TENSOR_CORES - { 4, 32, 4}, -#else - {128, 64, 4}, -#endif // CUDA_USE_TENSOR_CORES - { 64, 64, 8}, -}; -constexpr mmq_config_t MMQ_CONFIG_Q5_1 = { -// x y nwarps - { 64, 128, 8}, - { 64, 64, 8}, -#ifdef CUDA_USE_TENSOR_CORES - { 4, 32, 4}, -#else - {128, 64, 4}, -#endif // CUDA_USE_TENSOR_CORES - { 64, 64, 8}, -}; -constexpr mmq_config_t MMQ_CONFIG_Q8_0 = { -// x y nwarps - { 64, 128, 8}, - { 64, 64, 8}, -#ifdef CUDA_USE_TENSOR_CORES - { 4, 32, 4}, -#else - {128, 64, 4}, -#endif // CUDA_USE_TENSOR_CORES - { 64, 64, 8}, -}; -constexpr mmq_config_t MMQ_CONFIG_Q2_K = { -// x y nwarps - { 64, 128, 8}, - {128, 32, 8}, -#ifdef CUDA_USE_TENSOR_CORES - { 4, 32, 4}, -#else - { 64, 128, 4}, -#endif // CUDA_USE_TENSOR_CORES - { 64, 64, 8}, -}; -constexpr mmq_config_t MMQ_CONFIG_Q3_K = { -// x y nwarps - {128, 64, 8}, - { 32, 128, 8}, -#ifdef CUDA_USE_TENSOR_CORES - { 4, 32, 4}, -#else - {128, 128, 4}, -#endif // CUDA_USE_TENSOR_CORES - { 64, 64, 8}, -}; -constexpr mmq_config_t MMQ_CONFIG_Q4_K = { -// x y nwarps - { 64, 128, 8}, - { 32, 64, 8}, -#ifdef CUDA_USE_TENSOR_CORES - { 4, 32, 4}, -#else - { 64, 128, 4}, -#endif // CUDA_USE_TENSOR_CORES - { 64, 64, 8}, -}; -constexpr mmq_config_t MMQ_CONFIG_Q5_K = { -// x y nwarps - { 64, 128, 8}, - { 32, 64, 8}, -#ifdef CUDA_USE_TENSOR_CORES - { 4, 32, 4}, -#else - { 64, 128, 4}, -#endif // CUDA_USE_TENSOR_CORES - { 64, 64, 8}, -}; -constexpr mmq_config_t MMQ_CONFIG_Q6_K = { -// x y nwarps - { 64, 128, 8}, - { 32, 64, 8}, -#ifdef CUDA_USE_TENSOR_CORES - { 4, 32, 4}, -#else - { 64, 64, 4}, -#endif // CUDA_USE_TENSOR_CORES - { 64, 64, 8}, -}; - -// ------------------------------------------------------------ - -template static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - GGML_UNUSED(x_qh); - GGML_UNUSED(x_sc); - - __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + mmq_y]; - __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI4_0) + mmq_y/QI4_0]; - - *x_ql = tile_x_qs; - *x_dm = (half2 *) tile_x_d; -} - -template static __device__ __forceinline__ void load_tiles_q4_0( - const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); - GGML_CUDA_ASSUME(i_offset >= 0); - GGML_CUDA_ASSUME(i_offset < nwarps); - GGML_CUDA_ASSUME(k >= 0); - GGML_CUDA_ASSUME(k < WARP_SIZE); - - const int kbx = k / QI4_0; - const int kqsx = k % QI4_0; - - const block_q4_0 * bx0 = (const block_q4_0 *) vx; - - float * x_dmf = (float *) x_dm; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + i_offset; - - if (need_check) { - i = min(i, i_max); - } - - const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbx; - - x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx); - // x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbx] = bxi->d; - } - - const int blocks_per_tile_x_row = WARP_SIZE / QI4_0; - const int kbxd = k % blocks_per_tile_x_row; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) { - int i = i0 + i_offset * QI4_0 + k / blocks_per_tile_x_row; - - if (need_check) { - i = min(i, i_max); - } - - const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbxd; - - x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = bxi->d; - } -} - -static __device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); - - const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); - const float * x_dmf = (const float *) x_dm; - - int u[2*VDR_Q4_0_Q8_1_MMQ]; - -#pragma unroll - for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_0) % WARP_SIZE]; - } - - return vec_dot_q4_0_q8_1_impl - (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dmf[i * (WARP_SIZE/QI4_0) + i/QI4_0 + k/QI4_0], - y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); -} - -template static __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); - - __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + + mmq_y]; - __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI4_1) + mmq_y/QI4_1]; - - *x_ql = tile_x_qs; - *x_dm = tile_x_dm; -} - -template static __device__ __forceinline__ void load_tiles_q4_1( - const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); - - GGML_CUDA_ASSUME(i_offset >= 0); - GGML_CUDA_ASSUME(i_offset < nwarps); - GGML_CUDA_ASSUME(k >= 0); - GGML_CUDA_ASSUME(k < WARP_SIZE); - - const int kbx = k / QI4_1; - const int kqsx = k % QI4_1; - - const block_q4_1 * bx0 = (const block_q4_1 *) vx; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + i_offset; - - if (need_check) { - i = min(i, i_max); - } - - const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbx; - - x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); - } - - const int blocks_per_tile_x_row = WARP_SIZE / QI4_1; - const int kbxd = k % blocks_per_tile_x_row; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) { - int i = i0 + i_offset * QI4_1 + k / blocks_per_tile_x_row; - - if (need_check) { - i = min(i, i_max); - } - - const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbxd; - - x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm; - } -} - -static __device__ __forceinline__ float vec_dot_q4_1_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); - - const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); - - int u[2*VDR_Q4_1_Q8_1_MMQ]; - -#pragma unroll - for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_1) % WARP_SIZE]; - } - - return vec_dot_q4_1_q8_1_impl - (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dm[i * (WARP_SIZE/QI4_1) + i/QI4_1 + k/QI4_1], - y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); -} - -template static __device__ __forceinline__ void allocate_tiles_q5_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); - - __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; - __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI5_0) + mmq_y/QI5_0]; - - *x_ql = tile_x_ql; - *x_dm = (half2 *) tile_x_d; -} - -template static __device__ __forceinline__ void load_tiles_q5_0( - const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); - - GGML_CUDA_ASSUME(i_offset >= 0); - GGML_CUDA_ASSUME(i_offset < nwarps); - GGML_CUDA_ASSUME(k >= 0); - GGML_CUDA_ASSUME(k < WARP_SIZE); - - const int kbx = k / QI5_0; - const int kqsx = k % QI5_0; - - const block_q5_0 * bx0 = (const block_q5_0 *) vx; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + i_offset; - - if (need_check) { - i = min(i, i_max); - } - - const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbx; - - const int ql = get_int_from_uint8(bxi->qs, kqsx); - const int qh = get_int_from_uint8(bxi->qh, 0) >> (4 * (k % QI5_0)); - - int qs0 = (ql >> 0) & 0x0F0F0F0F; - qs0 |= (qh << 4) & 0x00000010; // 0 -> 4 - qs0 |= (qh << 11) & 0x00001000; // 1 -> 12 - qs0 |= (qh << 18) & 0x00100000; // 2 -> 20 - qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 - qs0 = __vsubss4(qs0, 0x10101010); // subtract 16 - - x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0; - - int qs1 = (ql >> 4) & 0x0F0F0F0F; - qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 - qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12 - qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 - qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 - qs1 = __vsubss4(qs1, 0x10101010); // subtract 16 - - x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1; - } - - const int blocks_per_tile_x_row = WARP_SIZE / QI5_0; - const int kbxd = k % blocks_per_tile_x_row; - float * x_dmf = (float *) x_dm; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) { - int i = i0 + i_offset * QI5_0 + k / blocks_per_tile_x_row; - - if (need_check) { - i = min(i, i_max); - } - - const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbxd; - - x_dmf[i * (WARP_SIZE/QI5_0) + i / QI5_0 + kbxd] = bxi->d; - } -} - -static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); - - const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); - const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k/QI5_0; - const float * x_dmf = (const float *) x_dm; - const float * y_df = (const float *) y_ds; - - int u[2*VDR_Q5_0_Q8_1_MMQ]; - -#pragma unroll - for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_0) % WARP_SIZE]; - } - - return vec_dot_q8_0_q8_1_impl - (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dmf[index_bx], y_df[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); -} - - -template static __device__ __forceinline__ void allocate_tiles_q5_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); - - __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; - __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI5_1) + mmq_y/QI5_1]; - - *x_ql = tile_x_ql; - *x_dm = tile_x_dm; -} - -template static __device__ __forceinline__ void load_tiles_q5_1( - const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); - - GGML_CUDA_ASSUME(i_offset >= 0); - GGML_CUDA_ASSUME(i_offset < nwarps); - GGML_CUDA_ASSUME(k >= 0); - GGML_CUDA_ASSUME(k < WARP_SIZE); - - const int kbx = k / QI5_1; - const int kqsx = k % QI5_1; - - const block_q5_1 * bx0 = (const block_q5_1 *) vx; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + i_offset; - - if (need_check) { - i = min(i, i_max); - } - - const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbx; - - const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx); - const int qh = get_int_from_uint8_aligned(bxi->qh, 0) >> (4 * (k % QI5_1)); - - int qs0 = (ql >> 0) & 0x0F0F0F0F; - qs0 |= (qh << 4) & 0x00000010; // 0 -> 4 - qs0 |= (qh << 11) & 0x00001000; // 1 -> 12 - qs0 |= (qh << 18) & 0x00100000; // 2 -> 20 - qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 - - x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0; - - int qs1 = (ql >> 4) & 0x0F0F0F0F; - qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 - qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12 - qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 - qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 - - x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1; - } - - const int blocks_per_tile_x_row = WARP_SIZE / QI5_1; - const int kbxd = k % blocks_per_tile_x_row; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) { - int i = i0 + i_offset * QI5_1 + k / blocks_per_tile_x_row; - - if (need_check) { - i = min(i, i_max); - } - - const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbxd; - - x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm; - } -} - -static __device__ __forceinline__ float vec_dot_q5_1_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); - - const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); - const int index_bx = i * (WARP_SIZE/QI5_1) + + i/QI5_1 + k/QI5_1; - - int u[2*VDR_Q5_1_Q8_1_MMQ]; - -#pragma unroll - for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_1) % WARP_SIZE]; - } - - return vec_dot_q8_1_q8_1_impl - (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dm[index_bx], y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); -} - -template static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); - - __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + mmq_y]; - __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI8_0) + mmq_y/QI8_0]; - - *x_ql = tile_x_qs; - *x_dm = (half2 *) tile_x_d; -} - -template static __device__ __forceinline__ void load_tiles_q8_0( - const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); - - GGML_CUDA_ASSUME(i_offset >= 0); - GGML_CUDA_ASSUME(i_offset < nwarps); - GGML_CUDA_ASSUME(k >= 0); - GGML_CUDA_ASSUME(k < WARP_SIZE); - - const int kbx = k / QI8_0; - const int kqsx = k % QI8_0; - float * x_dmf = (float *) x_dm; - - const block_q8_0 * bx0 = (const block_q8_0 *) vx; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + i_offset; - - if (need_check) { - i = min(i, i_max); - } - - const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbx; - - x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_int8(bxi->qs, kqsx); - } - - const int blocks_per_tile_x_row = WARP_SIZE / QI8_0; - const int kbxd = k % blocks_per_tile_x_row; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0) { - int i = i0 + i_offset * QI8_0 + k / blocks_per_tile_x_row; - - if (need_check) { - i = min(i, i_max); - } - - const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbxd; - - x_dmf[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d; - } -} - -static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); - - const float * x_dmf = (const float *) x_dm; - const float * y_df = (const float *) y_ds; - - return vec_dot_q8_0_q8_1_impl - (&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[j * WARP_SIZE + k], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k/QI8_0], - y_df[j * (WARP_SIZE/QI8_1) + k/QI8_1]); -} - -template static __device__ __forceinline__ void allocate_tiles_q2_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - GGML_UNUSED(x_qh); - - __shared__ int tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y]; - __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI2_K) + mmq_y/QI2_K]; - __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/4) + mmq_y/4]; - - *x_ql = tile_x_ql; - *x_dm = tile_x_dm; - *x_sc = tile_x_sc; -} - -template static __device__ __forceinline__ void load_tiles_q2_K( - const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { - GGML_UNUSED(x_qh); - - GGML_CUDA_ASSUME(i_offset >= 0); - GGML_CUDA_ASSUME(i_offset < nwarps); - GGML_CUDA_ASSUME(k >= 0); - GGML_CUDA_ASSUME(k < WARP_SIZE); - - const int kbx = k / QI2_K; - const int kqsx = k % QI2_K; - - const block_q2_K * bx0 = (const block_q2_K *) vx; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + i_offset; - - if (need_check) { - i = min(i, i_max); - } - - const block_q2_K * bxi = bx0 + i*blocks_per_row + kbx; - - x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); - } - - const int blocks_per_tile_x_row = WARP_SIZE / QI2_K; - const int kbxd = k % blocks_per_tile_x_row; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI2_K) { - int i = (i0 + i_offset * QI2_K + k / blocks_per_tile_x_row) % mmq_y; - - if (need_check) { - i = min(i, i_max); - } - - const block_q2_K * bxi = bx0 + i*blocks_per_row + kbxd; - - x_dm[i * (WARP_SIZE/QI2_K) + i / QI2_K + kbxd] = bxi->dm; - } - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { - int i = i0 + i_offset * 4 + k / (WARP_SIZE/4); - - if (need_check) { - i = min(i, i_max); - } - - const block_q2_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI2_K/4); - - x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = get_int_from_uint8_aligned(bxi->scales, k % (QI2_K/4)); - } -} - -static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { - GGML_UNUSED(x_qh); - - const int kbx = k / QI2_K; - const int ky = (k % QI2_K) * QR2_K; - const float * y_df = (const float *) y_ds; - - int v[QR2_K*VDR_Q2_K_Q8_1_MMQ]; - - const int kqsx = i * (WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2); - const int shift = 2 * ((ky % (2*QI2_K)) / (QI2_K/2)); - -#pragma unroll - for (int l = 0; l < QR2_K*VDR_Q2_K_Q8_1_MMQ; ++l) { - v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303; - } - - const uint8_t * scales = ((const uint8_t *) &x_sc[i * (WARP_SIZE/4) + i/4 + kbx*4]) + ky/4; - - const int index_y = j * WARP_SIZE + (QR2_K*k) % WARP_SIZE; - return vec_dot_q2_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[index_y/QI8_1]); -} - -template static __device__ __forceinline__ void allocate_tiles_q3_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - - __shared__ int tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y]; - __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI3_K) + mmq_y/QI3_K]; - __shared__ int tile_x_qh[mmq_y * (WARP_SIZE/2) + mmq_y/2]; - __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/4) + mmq_y/4]; - - *x_ql = tile_x_ql; - *x_dm = tile_x_dm; - *x_qh = tile_x_qh; - *x_sc = tile_x_sc; -} - -template static __device__ __forceinline__ void load_tiles_q3_K( - const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { - - GGML_CUDA_ASSUME(i_offset >= 0); - GGML_CUDA_ASSUME(i_offset < nwarps); - GGML_CUDA_ASSUME(k >= 0); - GGML_CUDA_ASSUME(k < WARP_SIZE); - - const int kbx = k / QI3_K; - const int kqsx = k % QI3_K; - - const block_q3_K * bx0 = (const block_q3_K *) vx; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + i_offset; - - if (need_check) { - i = min(i, i_max); - } - - const block_q3_K * bxi = bx0 + i*blocks_per_row + kbx; - - x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx); - } - - const int blocks_per_tile_x_row = WARP_SIZE / QI3_K; - const int kbxd = k % blocks_per_tile_x_row; - float * x_dmf = (float *) x_dm; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) { - int i = (i0 + i_offset * QI3_K + k / blocks_per_tile_x_row) % mmq_y; - - if (need_check) { - i = min(i, i_max); - } - - const block_q3_K * bxi = bx0 + i*blocks_per_row + kbxd; - - x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = bxi->d; - } - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 2) { - int i = i0 + i_offset * 2 + k / (WARP_SIZE/2); - - if (need_check) { - i = min(i, i_max); - } - - const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/2)) / (QI3_K/2); - - // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted - x_qh[i * (WARP_SIZE/2) + i / 2 + k % (WARP_SIZE/2)] = ~get_int_from_uint8(bxi->hmask, k % (QI3_K/2)); - } - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { - int i = i0 + i_offset * 4 + k / (WARP_SIZE/4); - - if (need_check) { - i = min(i, i_max); - } - - const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI3_K/4); - - const int ksc = k % (QI3_K/4); - - const int ksc_low = ksc % (QI3_K/8); - const int shift_low = 4 * (ksc / (QI3_K/8)); - const int sc_low = (get_int_from_uint8(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F; - - const int ksc_high = QI3_K/8; - const int shift_high = 2 * ksc; - const int sc_high = ((get_int_from_uint8(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030; - - const int sc = __vsubss4(sc_low | sc_high, 0x20202020); - - x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = sc; - } -} - -static __device__ __forceinline__ float vec_dot_q3_K_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { - - const int kbx = k / QI3_K; - const int ky = (k % QI3_K) * QR3_K; - const float * x_dmf = (const float *) x_dm; - const float * y_df = (const float *) y_ds; - - const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4; - - int v[QR3_K*VDR_Q3_K_Q8_1_MMQ]; - -#pragma unroll - for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) { - const int kqsx = i * (WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2); - const int shift = 2 * ((ky % 32) / 8); - const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303; - - const int vh = x_qh[i * (WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8); - const int vlh = (vh << 2) & 0x04040404; - - v[l] = __vsubss4(vll, vlh); - } - - const int index_y = j * WARP_SIZE + (k*QR3_K) % WARP_SIZE; - return vec_dot_q3_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[index_y/QI8_1]); -} - -template static __device__ __forceinline__ void allocate_tiles_q4_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - GGML_UNUSED(x_qh); - - __shared__ int tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y]; - __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI4_K) + mmq_y/QI4_K]; - __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/8) + mmq_y/8]; - - *x_ql = tile_x_ql; - *x_dm = tile_x_dm; - *x_sc = tile_x_sc; -} - -template static __device__ __forceinline__ void load_tiles_q4_K( - const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { - GGML_UNUSED(x_qh); - - GGML_CUDA_ASSUME(i_offset >= 0); - GGML_CUDA_ASSUME(i_offset < nwarps); - GGML_CUDA_ASSUME(k >= 0); - GGML_CUDA_ASSUME(k < WARP_SIZE); - - const int kbx = k / QI4_K; // == 0 if QK_K == 256 - const int kqsx = k % QI4_K; // == k if QK_K == 256 - - const block_q4_K * bx0 = (const block_q4_K *) vx; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + i_offset; - - if (need_check) { - i = min(i, i_max); - } - - const block_q4_K * bxi = bx0 + i*blocks_per_row + kbx; - - x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); - } - - const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256 - const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256 - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_K) { - int i = (i0 + i_offset * QI4_K + k / blocks_per_tile_x_row) % mmq_y; - - if (need_check) { - i = min(i, i_max); - } - - const block_q4_K * bxi = bx0 + i*blocks_per_row + kbxd; - - x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm; - } - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { - int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y; - - if (need_check) { - i = min(i, i_max); - } - - const block_q4_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI4_K/8); - - const int * scales = (const int *) bxi->scales; - - const int ksc = k % (WARP_SIZE/8); - - // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8 - int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits - scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits - - x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8; - } -} - -static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { - GGML_UNUSED(x_qh); - - const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2*((k % 16) / 8); - - const int index_y = j * WARP_SIZE + (QR4_K*k) % WARP_SIZE; - return vec_dot_q4_K_q8_1_impl_mmq(&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[index_y], sc, sc+8, - x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]); -} - -template static __device__ __forceinline__ void allocate_tiles_q5_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - GGML_UNUSED(x_qh); - - __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; - __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI5_K) + mmq_y/QI5_K]; - __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/8) + mmq_y/8]; - - *x_ql = tile_x_ql; - *x_dm = tile_x_dm; - *x_sc = tile_x_sc; -} - -template static __device__ __forceinline__ void load_tiles_q5_K( - const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { - GGML_UNUSED(x_qh); - - GGML_CUDA_ASSUME(i_offset >= 0); - GGML_CUDA_ASSUME(i_offset < nwarps); - GGML_CUDA_ASSUME(k >= 0); - GGML_CUDA_ASSUME(k < WARP_SIZE); - - const int kbx = k / QI5_K; // == 0 if QK_K == 256 - const int kqsx = k % QI5_K; // == k if QK_K == 256 - - const block_q5_K * bx0 = (const block_q5_K *) vx; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + i_offset; - - if (need_check) { - i = min(i, i_max); - } - - const block_q5_K * bxi = bx0 + i*blocks_per_row + kbx; - const int ky = QR5_K*kqsx; - - const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx); - const int ql0 = (ql >> 0) & 0x0F0F0F0F; - const int ql1 = (ql >> 4) & 0x0F0F0F0F; - - const int qh = get_int_from_uint8_aligned(bxi->qh, kqsx % (QI5_K/4)); - const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010; - const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010; - - const int kq0 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + 0; - const int kq1 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + (QI5_K/4); - - x_ql[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0; - x_ql[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1; - } - - const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256 - const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256 - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_K) { - int i = (i0 + i_offset * QI5_K + k / blocks_per_tile_x_row) % mmq_y; - - if (need_check) { - i = min(i, i_max); - } - - const block_q5_K * bxi = bx0 + i*blocks_per_row + kbxd; - - x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm; - } - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { - int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y; - - if (need_check) { - i = min(i, i_max); - } - - const block_q5_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI5_K/8); - - const int * scales = (const int *) bxi->scales; - - const int ksc = k % (WARP_SIZE/8); - - // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8 - int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits - scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits - - x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8; - } -} - -static __device__ __forceinline__ float vec_dot_q5_K_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { - GGML_UNUSED(x_qh); - - const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2 * ((k % 16) / 8); - - const int index_x = i * (QR5_K*WARP_SIZE + 1) + QR5_K*k; - const int index_y = j * WARP_SIZE + (QR5_K*k) % WARP_SIZE; - return vec_dot_q5_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, sc+8, - x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]); -} - -template static __device__ __forceinline__ void allocate_tiles_q6_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - GGML_UNUSED(x_qh); - - __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; - __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI6_K) + mmq_y/QI6_K]; - __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/8) + mmq_y/8]; - - *x_ql = tile_x_ql; - *x_dm = tile_x_dm; - *x_sc = tile_x_sc; -} - -template static __device__ __forceinline__ void load_tiles_q6_K( - const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { - GGML_UNUSED(x_qh); - - GGML_CUDA_ASSUME(i_offset >= 0); - GGML_CUDA_ASSUME(i_offset < nwarps); - GGML_CUDA_ASSUME(k >= 0); - GGML_CUDA_ASSUME(k < WARP_SIZE); - - const int kbx = k / QI6_K; // == 0 if QK_K == 256 - const int kqsx = k % QI6_K; // == k if QK_K == 256 - - const block_q6_K * bx0 = (const block_q6_K *) vx; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + i_offset; - - if (need_check) { - i = min(i, i_max); - } - - const block_q6_K * bxi = bx0 + i*blocks_per_row + kbx; - const int ky = QR6_K*kqsx; - - const int ql = get_int_from_uint8(bxi->ql, kqsx); - const int ql0 = (ql >> 0) & 0x0F0F0F0F; - const int ql1 = (ql >> 4) & 0x0F0F0F0F; - - const int qh = get_int_from_uint8(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4)); - const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030; - const int qh1 = (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) & 0x30303030; - - const int kq0 = ky - ky % QI6_K + k % (QI6_K/2) + 0; - const int kq1 = ky - ky % QI6_K + k % (QI6_K/2) + (QI6_K/2); - - x_ql[i * (2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020); - x_ql[i * (2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020); - } - - const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256 - const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256 - float * x_dmf = (float *) x_dm; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) { - int i = (i0 + i_offset * QI6_K + k / blocks_per_tile_x_row) % mmq_y; - - if (need_check) { - i = min(i, i_max); - } - - const block_q6_K * bxi = bx0 + i*blocks_per_row + kbxd; - - x_dmf[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd] = bxi->d; - } - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { - int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y; - - if (need_check) { - i = min(i, i_max); - } - - const block_q6_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / 4; - - x_sc[i * (WARP_SIZE/8) + i / 8 + k % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, k % (QI6_K/8)); - } -} - -static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { - GGML_UNUSED(x_qh); - - const float * x_dmf = (const float *) x_dm; - const float * y_df = (const float *) y_ds; - - const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/8]); - - const int index_x = i * (QR6_K*WARP_SIZE + 1) + QR6_K*k; - const int index_y = j * WARP_SIZE + (QR6_K*k) % WARP_SIZE; - return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]); -} - -template -static __device__ __forceinline__ void mul_mat_q( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { - - const block_q_t * x = (const block_q_t *) vx; - const block_q8_1 * y = (const block_q8_1 *) vy; - - const int blocks_per_row_x = ncols_x / qk; - const int blocks_per_col_y = nrows_y / QK8_1; - const int blocks_per_warp = WARP_SIZE / qi; - - const int & ncols_dst = ncols_y; - - const int row_dst_0 = blockIdx.x*mmq_y; - const int & row_x_0 = row_dst_0; - - const int col_dst_0 = blockIdx.y*mmq_x; - const int & col_y_0 = col_dst_0; - - int * tile_x_ql = nullptr; - half2 * tile_x_dm = nullptr; - int * tile_x_qh = nullptr; - int * tile_x_sc = nullptr; - - allocate_tiles(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc); - - __shared__ int tile_y_qs[mmq_x * WARP_SIZE]; - __shared__ half2 tile_y_ds[mmq_x * WARP_SIZE/QI8_1]; - - float sum[mmq_y/WARP_SIZE][mmq_x/nwarps] = {{0.0f}}; - - for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) { - - load_tiles(x + row_x_0*blocks_per_row_x + ib0, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, - threadIdx.y, nrows_x-row_x_0-1, threadIdx.x, blocks_per_row_x); - -#pragma unroll - for (int ir = 0; ir < qr; ++ir) { - const int kqs = ir*WARP_SIZE + threadIdx.x; - const int kbxd = kqs / QI8_1; - -#pragma unroll - for (int i = 0; i < mmq_x; i += nwarps) { - const int col_y_eff = min(col_y_0 + threadIdx.y + i, ncols_y-1); // to prevent out-of-bounds memory accesses - - const block_q8_1 * by0 = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kbxd]; - - const int index_y = (threadIdx.y + i) * WARP_SIZE + kqs % WARP_SIZE; - tile_y_qs[index_y] = get_int_from_int8_aligned(by0->qs, threadIdx.x % QI8_1); - } - -#pragma unroll - for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) { - const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE/QI8_1)) % mmq_x; - const int kby = threadIdx.x % (WARP_SIZE/QI8_1); - const int col_y_eff = min(col_y_0 + ids, ncols_y-1); - - // if the sum is not needed it's faster to transform the scale to f32 ahead of time - const half2 * dsi_src = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + ir*(WARP_SIZE/QI8_1) + kby].ds; - half2 * dsi_dst = &tile_y_ds[ids * (WARP_SIZE/QI8_1) + kby]; - if (need_sum) { - *dsi_dst = *dsi_src; - } else { - float * dfi_dst = (float *) dsi_dst; - *dfi_dst = __low2float(*dsi_src); - } - } - - __syncthreads(); - -// #pragma unroll // unrolling this loop causes too much register pressure - for (int k = ir*WARP_SIZE/qr; k < (ir+1)*WARP_SIZE/qr; k += vdr) { -#pragma unroll - for (int j = 0; j < mmq_x; j += nwarps) { -#pragma unroll - for (int i = 0; i < mmq_y; i += WARP_SIZE) { - sum[i/WARP_SIZE][j/nwarps] += vec_dot( - tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds, - threadIdx.x + i, threadIdx.y + j, k); - } - } - } - - __syncthreads(); - } - } - -#pragma unroll - for (int j = 0; j < mmq_x; j += nwarps) { - const int col_dst = col_dst_0 + j + threadIdx.y; - - if (col_dst >= ncols_dst) { - return; - } - -#pragma unroll - for (int i = 0; i < mmq_y; i += WARP_SIZE) { - const int row_dst = row_dst_0 + threadIdx.x + i; - - if (row_dst >= nrows_dst) { - continue; - } - - dst[col_dst*nrows_dst + row_dst] = sum[i/WARP_SIZE][j/nwarps]; - } - } -} - -static constexpr __device__ mmq_arch_config_t get_arch_config_device(mmq_config_t mmq_config) { - -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) - -#if defined(RDNA3) || defined(RDNA2) - return mmq_config.rdna2; -#else - return mmq_config.rdna1; -#endif // defined(RDNA3) || defined(RDNA2) - -#else - -#if __CUDA_ARCH__ >= CC_VOLTA - return mmq_config.ampere; -#else - return mmq_config.pascal; -#endif // __CUDA_ARCH__ >= CC_VOLTA - -#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -} - -template static __global__ void -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q4_0.rdna2.nwarps, 2) -#endif // defined(RDNA3) || defined(RDNA2) -#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) - mul_mat_q4_0( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { - -#if __CUDA_ARCH__ >= MIN_CC_DP4A - constexpr mmq_arch_config_t arch_config = get_arch_config_device(MMQ_CONFIG_Q4_0); - - mul_mat_q, - load_tiles_q4_0, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#else - GGML_UNUSED(get_arch_config_device); - GGML_UNUSED(vec_dot_q4_0_q8_1_mul_mat); - NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A -} - -template static __global__ void -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q4_1.rdna2.nwarps, 2) -#endif // defined(RDNA3) || defined(RDNA2) -#elif __CUDA_ARCH__ < CC_VOLTA - __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q4_1.pascal.nwarps, 2) -#endif // __CUDA_ARCH__ < CC_VOLTA - mul_mat_q4_1( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { - -#if __CUDA_ARCH__ >= MIN_CC_DP4A - constexpr mmq_arch_config_t arch_config = get_arch_config_device(MMQ_CONFIG_Q4_1); - - mul_mat_q, - load_tiles_q4_1, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#else - GGML_UNUSED(get_arch_config_device); - GGML_UNUSED(vec_dot_q4_1_q8_1_mul_mat); - NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A -} - -template static __global__ void -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q5_0.rdna2.nwarps, 2) -#endif // defined(RDNA3) || defined(RDNA2) -#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) - mul_mat_q5_0( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { - -#if __CUDA_ARCH__ >= MIN_CC_DP4A - constexpr mmq_arch_config_t arch_config = get_arch_config_device(MMQ_CONFIG_Q5_0); - - mul_mat_q, - load_tiles_q5_0, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#else - GGML_UNUSED(get_arch_config_device); - GGML_UNUSED(vec_dot_q5_0_q8_1_mul_mat); - NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A -} - -template static __global__ void -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q5_1.rdna2.nwarps, 2) -#endif // defined(RDNA3) || defined(RDNA2) -#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -mul_mat_q5_1( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { - -#if __CUDA_ARCH__ >= MIN_CC_DP4A - constexpr mmq_arch_config_t arch_config = get_arch_config_device(MMQ_CONFIG_Q5_1); - - mul_mat_q, - load_tiles_q5_1, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#else - GGML_UNUSED(get_arch_config_device); - GGML_UNUSED(vec_dot_q5_1_q8_1_mul_mat); - NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A -} - -template static __global__ void -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q8_0.rdna2.nwarps, 2) -#endif // defined(RDNA3) || defined(RDNA2) -#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) - mul_mat_q8_0( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { - -#if __CUDA_ARCH__ >= MIN_CC_DP4A - constexpr mmq_arch_config_t arch_config = get_arch_config_device(MMQ_CONFIG_Q8_0); - - mul_mat_q, - load_tiles_q8_0, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#else - GGML_UNUSED(get_arch_config_device); - GGML_UNUSED(vec_dot_q8_0_q8_1_mul_mat); - NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A -} - -template static __global__ void -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q2_K.rdna2.nwarps, 2) -#endif // defined(RDNA3) || defined(RDNA2) -#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -mul_mat_q2_K( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { - -#if __CUDA_ARCH__ >= MIN_CC_DP4A - constexpr mmq_arch_config_t arch_config = get_arch_config_device(MMQ_CONFIG_Q2_K); - - mul_mat_q, - load_tiles_q2_K, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#else - GGML_UNUSED(get_arch_config_device); - GGML_UNUSED(vec_dot_q2_K_q8_1_mul_mat); - NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A -} - -template static __global__ void -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q3_K.rdna2.nwarps, 2) -#endif // defined(RDNA3) || defined(RDNA2) -#elif __CUDA_ARCH__ < CC_VOLTA - __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q3_K.pascal.nwarps, 2) -#endif // __CUDA_ARCH__ < CC_VOLTA - mul_mat_q3_K( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { - -#if __CUDA_ARCH__ >= MIN_CC_DP4A - constexpr mmq_arch_config_t arch_config = get_arch_config_device(MMQ_CONFIG_Q3_K); - - mul_mat_q, - load_tiles_q3_K, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#else - GGML_UNUSED(get_arch_config_device); - GGML_UNUSED(vec_dot_q3_K_q8_1_mul_mat); - NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A -} - -template static __global__ void -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q4_K.rdna2.nwarps, 2) -#endif // defined(RDNA3) || defined(RDNA2) -#elif __CUDA_ARCH__ < CC_VOLTA - __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q4_K.pascal.nwarps, 2) -#endif // __CUDA_ARCH__ < CC_VOLTA - mul_mat_q4_K( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { - -#if __CUDA_ARCH__ >= MIN_CC_DP4A - constexpr mmq_arch_config_t arch_config = get_arch_config_device(MMQ_CONFIG_Q4_K); - - mul_mat_q, - load_tiles_q4_K, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#else - GGML_UNUSED(get_arch_config_device); - GGML_UNUSED(vec_dot_q4_K_q8_1_mul_mat); - NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A -} - -template static __global__ void -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q5_K.rdna2.nwarps, 2) -#endif // defined(RDNA3) || defined(RDNA2) -#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -mul_mat_q5_K( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { - -#if __CUDA_ARCH__ >= MIN_CC_DP4A - constexpr mmq_arch_config_t arch_config = get_arch_config_device(MMQ_CONFIG_Q5_K); - - mul_mat_q, - load_tiles_q5_K, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#else - GGML_UNUSED(get_arch_config_device); - GGML_UNUSED(vec_dot_q5_K_q8_1_mul_mat); - NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A -} - -template static __global__ void -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q6_K.rdna2.nwarps, 2) -#endif // defined(RDNA3) || defined(RDNA2) -#elif __CUDA_ARCH__ < CC_VOLTA - __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q4_K.pascal.nwarps, 2) -#endif // __CUDA_ARCH__ < CC_VOLTA - mul_mat_q6_K( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { - -#if __CUDA_ARCH__ >= MIN_CC_DP4A - constexpr mmq_arch_config_t arch_config = get_arch_config_device(MMQ_CONFIG_Q6_K); - - mul_mat_q, - load_tiles_q6_K, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#else - GGML_UNUSED(get_arch_config_device); - GGML_UNUSED(vec_dot_q6_K_q8_1_mul_mat); - NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A -} - -#define MMQ_SWITCH_CASE(type_suffix) \ - case GGML_TYPE_Q##type_suffix: if (row_diff % arch_config.y == 0) { \ - const bool need_check = false; \ - mul_mat_q##type_suffix<<>> \ - (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst); \ - } else { \ - const bool need_check = true; \ - mul_mat_q##type_suffix<<>> \ - (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst); \ - } break; \ void ggml_cuda_op_mul_mat_q( ggml_backend_cuda_context & ctx, @@ -1454,12 +8,15 @@ void ggml_cuda_op_mul_mat_q( const int64_t ne00 = src0->ne[0]; + const int64_t nb01 = src0->nb[1]; + const int64_t ne10 = src1->ne[0]; GGML_ASSERT(ne10 % QK8_1 == 0); const int64_t ne0 = dst->ne[0]; const int64_t row_diff = row_high - row_low; + const int64_t stride00 = nb01 / ggml_type_size(src0->type); int id = ggml_cuda_get_device(); const int compute_capability = ggml_cuda_info().devices[id].cc; @@ -1468,73 +25,39 @@ void ggml_cuda_op_mul_mat_q( // nrows_dst == nrows of the matrix that the kernel writes into const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff; - mmq_config_t mmq_config; + const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, nrows_dst}; switch (src0->type) { case GGML_TYPE_Q4_0: - mmq_config = MMQ_CONFIG_Q4_0; + mul_mat_q_case(args, stream); break; case GGML_TYPE_Q4_1: - mmq_config = MMQ_CONFIG_Q4_1; + mul_mat_q_case(args, stream); break; case GGML_TYPE_Q5_0: - mmq_config = MMQ_CONFIG_Q5_0; + mul_mat_q_case(args, stream); break; case GGML_TYPE_Q5_1: - mmq_config = MMQ_CONFIG_Q5_1; + mul_mat_q_case(args, stream); break; case GGML_TYPE_Q8_0: - mmq_config = MMQ_CONFIG_Q8_0; + mul_mat_q_case(args, stream); break; case GGML_TYPE_Q2_K: - mmq_config = MMQ_CONFIG_Q2_K; + mul_mat_q_case(args, stream); break; case GGML_TYPE_Q3_K: - mmq_config = MMQ_CONFIG_Q3_K; + mul_mat_q_case(args, stream); break; case GGML_TYPE_Q4_K: - mmq_config = MMQ_CONFIG_Q4_K; + mul_mat_q_case(args, stream); break; case GGML_TYPE_Q5_K: - mmq_config = MMQ_CONFIG_Q5_K; + mul_mat_q_case(args, stream); break; case GGML_TYPE_Q6_K: - mmq_config = MMQ_CONFIG_Q6_K; - break; - default: - GGML_ASSERT(false); + mul_mat_q_case(args, stream); break; - } - - mmq_arch_config_t arch_config; - if (compute_capability >= CC_RDNA2) { - arch_config = mmq_config.rdna2; - } else if (compute_capability >= CC_OFFSET_AMD) { - arch_config = mmq_config.rdna1; - } else if (compute_capability >= CC_VOLTA) { - arch_config = mmq_config.ampere; - } else if (compute_capability >= MIN_CC_DP4A) { - arch_config = mmq_config.pascal; - } else { - GGML_ASSERT(false); - } - - const int block_num_x = (row_diff + arch_config.y - 1) / arch_config.y; - const int block_num_y = (src1_ncols + arch_config.x - 1) / arch_config.x; - const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, arch_config.nwarps, 1); - - switch (src0->type) { - MMQ_SWITCH_CASE(4_0) - MMQ_SWITCH_CASE(4_1) - MMQ_SWITCH_CASE(5_0) - MMQ_SWITCH_CASE(5_1) - MMQ_SWITCH_CASE(8_0) - MMQ_SWITCH_CASE(2_K) - MMQ_SWITCH_CASE(3_K) - MMQ_SWITCH_CASE(4_K) - MMQ_SWITCH_CASE(5_K) - MMQ_SWITCH_CASE(6_K) default: GGML_ASSERT(false); break; diff --git a/ggml-cuda/mmq.cuh b/ggml-cuda/mmq.cuh index 807817c4a715f..6744cce6d785f 100644 --- a/ggml-cuda/mmq.cuh +++ b/ggml-cuda/mmq.cuh @@ -1,4 +1,1304 @@ #include "common.cuh" +#include "vecdotq.cuh" + +#include +#include + +typedef void (*load_tiles_mmq_t)( + const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride); +typedef void (*vec_dot_mmq_t)( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ms, float * __restrict__ sum, const int & k0); + +struct tile_x_sizes { + int ql; + int dm; + int qh; + int sc; +}; + +// get_mmq_x_max_host is in common.cuh so that it can be used to determine the correct way to round for --split-mode row + +static constexpr __device__ int get_mmq_x_max_device() { +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) + return 64; +#else +#if __CUDA_ARCH__ >= CC_VOLTA +#ifdef CUDA_USE_TENSOR_CORES + return MMQ_MAX_BATCH_SIZE; +#else + return 128; +#endif // CUDA_USE_TENSOR_CORES +#else + return 64; +#endif // __CUDA_ARCH__ >= CC_VOLTA +#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +} + +// get_mmq_y_host is in common.cuh so that it can be used to determine the correct way to round for --split-mode row + +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +static constexpr __device__ int get_mmq_y_device(int mmq_x) { + return mmq_x >= 32 ? 128 : 64; +} +#else +#if __CUDA_ARCH__ >= CC_VOLTA +static constexpr __device__ int get_mmq_y_device(int mmq_x) { + return mmq_x >= 32 ? 128 : 64; +} +#else +static constexpr __device__ int get_mmq_y_device(int /*mmq_x*/) { + return 64; +} +#endif // __CUDA_ARCH__ >= CC_VOLTA +#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) + +#define TILE_X_SIZES_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0, 0} +#define TILE_X_SIZES_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0, 0} +#define TILE_X_SIZES_Q5_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_0 + mmq_y/QI5_0, 0, 0} +#define TILE_X_SIZES_Q5_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_1 + mmq_y/QI5_1, 0, 0} +#define TILE_X_SIZES_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI8_0 + mmq_y/QI8_0, 0, 0} +#define TILE_X_SIZES_Q2_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI2_K + mmq_y/QI2_K, 0, mmq_y*WARP_SIZE/4 + mmq_y/4} +#define TILE_X_SIZES_Q3_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI3_K + mmq_y/QI3_K, mmq_y*WARP_SIZE/2 + mmq_y/2, mmq_y*WARP_SIZE/4 + mmq_y/4} +#define TILE_X_SIZES_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K + mmq_y/QI4_K, 0, mmq_y*WARP_SIZE/8 + mmq_y/8} +#define TILE_X_SIZES_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, 0, mmq_y*WARP_SIZE/8 + mmq_y/8} +#define TILE_X_SIZES_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, 0, mmq_y*WARP_SIZE/8 + mmq_y/8} + +#define GET_TILE_X_SIZES_BODY \ + return type == GGML_TYPE_Q4_0 ? TILE_X_SIZES_Q4_0 : \ + type == GGML_TYPE_Q4_1 ? TILE_X_SIZES_Q4_1 : \ + type == GGML_TYPE_Q5_0 ? TILE_X_SIZES_Q5_0 : \ + type == GGML_TYPE_Q5_1 ? TILE_X_SIZES_Q5_1 : \ + type == GGML_TYPE_Q8_0 ? TILE_X_SIZES_Q8_0 : \ + type == GGML_TYPE_Q2_K ? TILE_X_SIZES_Q2_K : \ + type == GGML_TYPE_Q3_K ? TILE_X_SIZES_Q3_K : \ + type == GGML_TYPE_Q4_K ? TILE_X_SIZES_Q4_K : \ + type == GGML_TYPE_Q5_K ? TILE_X_SIZES_Q5_K : \ + type == GGML_TYPE_Q6_K ? TILE_X_SIZES_Q6_K : \ + tile_x_sizes{0, 0, 0, 0} + +static tile_x_sizes get_tile_x_sizes_host(const ggml_type type, const int mmq_y) { + GET_TILE_X_SIZES_BODY; +} + +template +static constexpr __device__ tile_x_sizes get_tile_x_sizes_device(ggml_type type) { + GET_TILE_X_SIZES_BODY; +} + +// ------------------------------------------------------------ + +template static __device__ __forceinline__ void load_tiles_q4_0( + const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { + GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + + const int kbx = threadIdx.x / QI4_0; + const int kqsx = threadIdx.x % QI4_0; + + float * x_dmf = (float *) x_dm; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx; + + x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx); + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI4_0; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) { + int i = i0 + threadIdx.y * QI4_0 + threadIdx.x / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd; + + x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = bxi->d; + } +} + +template +static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { + + GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2)); + const float * x_dmf = (const float *) x_dm; + + int u[2*VDR_Q4_0_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) { + u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; + u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_0) % WARP_SIZE]; + } + + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl + (&x_ql[i * (WARP_SIZE + 1) + k0], u, x_dmf[i * (WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0], + y_ds[j * (WARP_SIZE/QI8_1) + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); + } + } +} + +template static __device__ __forceinline__ void load_tiles_q4_1( + const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { + GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + + const int kbx = threadIdx.x / QI4_1; + const int kqsx = threadIdx.x % QI4_1; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx; + + x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx); + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI4_1; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) { + int i = i0 + threadIdx.y * QI4_1 + threadIdx.x / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd; + + x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm; + } +} + +template +static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { + + GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2)); + + int u[2*VDR_Q4_1_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) { + u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; + u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_1) % WARP_SIZE]; + } + + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl + (&x_ql[i * (WARP_SIZE + 1) + k0], u, x_dm[i * (WARP_SIZE/QI4_1) + i/QI4_1 + k0/QI4_1], + y_ds[j * (WARP_SIZE/QI8_1) + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); + } + } +} + +template static __device__ __forceinline__ void load_tiles_q5_0( + const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { + GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + + const int kbx = threadIdx.x / QI5_0; + const int kqsx = threadIdx.x % QI5_0; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbx; + + const int ql = get_int_from_uint8(bxi->qs, kqsx); + const int qh = get_int_from_uint8(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_0)); + + int qs0 = (ql >> 0) & 0x0F0F0F0F; + qs0 |= (qh << 4) & 0x00000010; // 0 -> 4 + qs0 |= (qh << 11) & 0x00001000; // 1 -> 12 + qs0 |= (qh << 18) & 0x00100000; // 2 -> 20 + qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 + qs0 = __vsubss4(qs0, 0x10101010); // subtract 16 + + x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0; + + int qs1 = (ql >> 4) & 0x0F0F0F0F; + qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 + qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12 + qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 + qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 + qs1 = __vsubss4(qs1, 0x10101010); // subtract 16 + + x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1; + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI5_0; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + float * x_dmf = (float *) x_dm; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) { + int i = i0 + threadIdx.y * QI5_0 + threadIdx.x / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd; + + x_dmf[i * (WARP_SIZE/QI5_0) + i / QI5_0 + kbxd] = bxi->d; + } +} + +template +static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { + + GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2)); + const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0; + const float * x_dmf = (const float *) x_dm; + const float * y_df = (const float *) y_ds; + + int u[2*VDR_Q5_0_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) { + u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; + u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_0) % WARP_SIZE]; + } + + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl + (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k0], u, x_dmf[index_bx], y_df[j * (WARP_SIZE/QI8_1) + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); + } + } +} + + +template static __device__ __forceinline__ void load_tiles_q5_1( + const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { + GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + + const int kbx = threadIdx.x / QI5_1; + const int kqsx = threadIdx.x % QI5_1; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbx; + + const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx); + const int qh = get_int_from_uint8_aligned(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_1)); + + int qs0 = (ql >> 0) & 0x0F0F0F0F; + qs0 |= (qh << 4) & 0x00000010; // 0 -> 4 + qs0 |= (qh << 11) & 0x00001000; // 1 -> 12 + qs0 |= (qh << 18) & 0x00100000; // 2 -> 20 + qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 + + x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0; + + int qs1 = (ql >> 4) & 0x0F0F0F0F; + qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 + qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12 + qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 + qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 + + x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1; + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI5_1; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) { + int i = i0 + threadIdx.y * QI5_1 + threadIdx.x / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd; + + x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm; + } +} + +template +static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { + + GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2)); + const int index_bx = i * (WARP_SIZE/QI5_1) + + i/QI5_1 + k0/QI5_1; + + int u[2*VDR_Q5_1_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) { + u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; + u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_1) % WARP_SIZE]; + } + + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl + (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k0], u, x_dm[index_bx], y_ds[j * (WARP_SIZE/QI8_1) + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); + } + } +} + +template static __device__ __forceinline__ void load_tiles_q8_0( + const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { + + GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + + const int kbx = threadIdx.x / QI8_0; + const int kqsx = threadIdx.x % QI8_0; + float * x_dmf = (float *) x_dm; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx; + + x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_int8(bxi->qs, kqsx); + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI8_0; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0) { + int i = i0 + threadIdx.y * QI8_0 + threadIdx.x / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd; + + x_dmf[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d; + } +} + +template +static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { + + GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + const float * x_dmf = (const float *) x_dm; + const float * y_df = (const float *) y_ds; + + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl + (&x_ql[i * (WARP_SIZE + 1) + k0], &y_qs[j * WARP_SIZE + k0], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0], + y_df[j * (WARP_SIZE/QI8_1) + k0/QI8_1]); + } + } +} + +template static __device__ __forceinline__ void load_tiles_q2_K( + const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { + GGML_UNUSED(x_qh); + + const int kbx = threadIdx.x / QI2_K; + const int kqsx = threadIdx.x % QI2_K; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + kbx; + + x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx); + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI2_K; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI2_K) { + int i = (i0 + threadIdx.y * QI2_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + kbxd; + + x_dm[i * (WARP_SIZE/QI2_K) + i / QI2_K + kbxd] = bxi->dm; + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { + int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4); + + if (need_check) { + i = min(i, i_max); + } + + const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/4)) / (QI2_K/4); + + x_sc[i * (WARP_SIZE/4) + i / 4 + threadIdx.x % (WARP_SIZE/4)] = get_int_from_uint8_aligned(bxi->scales, threadIdx.x % (QI2_K/4)); + } +} + +template +static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { + + GGML_UNUSED(x_qh); + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + const int kbx = k0 / QI2_K; + const int ky = (k0 % QI2_K) * QR2_K; + const float * y_df = (const float *) y_ds; + + int v[QR2_K*VDR_Q2_K_Q8_1_MMQ]; + + const int kqsx = i * (WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2); + const int shift = 2 * ((ky % (2*QI2_K)) / (QI2_K/2)); + +#pragma unroll + for (int l = 0; l < QR2_K*VDR_Q2_K_Q8_1_MMQ; ++l) { + v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303; + } + + const uint8_t * scales = ((const uint8_t *) &x_sc[i * (WARP_SIZE/4) + i/4 + kbx*4]) + ky/4; + + const int index_y = j * WARP_SIZE + (QR2_K*k0) % WARP_SIZE; + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq( + v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[index_y/QI8_1]); + } + } +} + +template static __device__ __forceinline__ void load_tiles_q3_K( + const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { + + const int kbx = threadIdx.x / QI3_K; + const int kqsx = threadIdx.x % QI3_K; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbx; + + x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx); + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI3_K; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + float * x_dmf = (float *) x_dm; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) { + int i = (i0 + threadIdx.y * QI3_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbxd; + + x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = bxi->d; + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 2) { + int i = i0 + threadIdx.y * 2 + threadIdx.x / (WARP_SIZE/2); + + if (need_check) { + i = min(i, i_max); + } + + const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/2)) / (QI3_K/2); + + // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted + x_qh[i * (WARP_SIZE/2) + i / 2 + threadIdx.x % (WARP_SIZE/2)] = ~get_int_from_uint8(bxi->hmask, threadIdx.x % (QI3_K/2)); + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { + int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4); + + if (need_check) { + i = min(i, i_max); + } + + const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/4)) / (QI3_K/4); + + const int ksc = threadIdx.x % (QI3_K/4); + + const int ksc_low = ksc % (QI3_K/8); + const int shift_low = 4 * (ksc / (QI3_K/8)); + const int sc_low = (get_int_from_uint8(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F; + + const int ksc_high = QI3_K/8; + const int shift_high = 2 * ksc; + const int sc_high = ((get_int_from_uint8(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030; + + const int sc = __vsubss4(sc_low | sc_high, 0x20202020); + + x_sc[i * (WARP_SIZE/4) + i / 4 + threadIdx.x % (WARP_SIZE/4)] = sc; + } +} + +template +static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + const int kbx = k0 / QI3_K; + const int ky = (k0 % QI3_K) * QR3_K; + const float * x_dmf = (const float *) x_dm; + const float * y_df = (const float *) y_ds; + + const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4; + + int v[QR3_K*VDR_Q3_K_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) { + const int kqsx = i * (WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2); + const int shift = 2 * ((ky % 32) / 8); + const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303; + + const int vh = x_qh[i * (WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8); + const int vlh = (vh << 2) & 0x04040404; + + v[l] = __vsubss4(vll, vlh); + } + + const int index_y = j * WARP_SIZE + (k0*QR3_K) % WARP_SIZE; + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq( + v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[index_y/QI8_1]); + } + } +} + +template static __device__ __forceinline__ void load_tiles_q4_K( + const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { + GGML_UNUSED(x_qh); + + const int kbx = 0; // threadIdx.x / QI4_K + const int kqsx = threadIdx.x; // threadIdx.x % QI4_K + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbx; + + x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx); + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256 + const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256 + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_K) { + int i = (i0 + threadIdx.y * QI4_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbxd; + + x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm; + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { + int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / (QI4_K/8); + + const int * scales = (const int *) bxi->scales; + + const int ksc = threadIdx.x % (WARP_SIZE/8); + + // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8 + int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits + scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits + + x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8; + } +} + +template +static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { + + GGML_UNUSED(x_qh); + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2*((k0 % 16) / 8); + + const int index_y = j * WARP_SIZE + (QR4_K*k0) % WARP_SIZE; + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq( + &x_ql[i * (WARP_SIZE + 1) + k0], &y_qs[index_y], sc, sc+8, x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]); + } + } +} + +template static __device__ __forceinline__ void load_tiles_q5_K( + const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { + GGML_UNUSED(x_qh); + + const int kbx = 0; // threadIdx.x / QI5_K + const int kqsx = threadIdx.x; // threadIdx.x % QI5_K + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + kbx; + const int ky = QR5_K*kqsx; + + const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx); + const int ql0 = (ql >> 0) & 0x0F0F0F0F; + const int ql1 = (ql >> 4) & 0x0F0F0F0F; + + const int qh = get_int_from_uint8_aligned(bxi->qh, kqsx % (QI5_K/4)); + const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010; + const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010; + + const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0; + const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + (QI5_K/4); + + x_ql[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0; + x_ql[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1; + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256 + const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256 + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_K) { + int i = (i0 + threadIdx.y * QI5_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + kbxd; + + x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm; + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { + int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / (QI5_K/8); + + const int * scales = (const int *) bxi->scales; + + const int ksc = threadIdx.x % (WARP_SIZE/8); + + // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8 + int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits + scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits + + x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8; + } +} + +template +static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { + + GGML_UNUSED(x_qh); + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8); + + const int index_x = i * (QR5_K*WARP_SIZE + 1) + QR5_K*k0; + const int index_y = j * WARP_SIZE + (QR5_K*k0) % WARP_SIZE; + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq( + &x_ql[index_x], &y_qs[index_y], sc, sc+8, x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]); + } + } +} + +template static __device__ __forceinline__ void load_tiles_q6_K( + const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { + GGML_UNUSED(x_qh); + + const int kbx = 0; // threadIdx.x / QI6_K + const int kqsx = threadIdx.x; // threadIdx.x % QI6_K + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbx; + const int ky = QR6_K*kqsx; + + const int ql = get_int_from_uint8(bxi->ql, kqsx); + const int ql0 = (ql >> 0) & 0x0F0F0F0F; + const int ql1 = (ql >> 4) & 0x0F0F0F0F; + + const int qh = get_int_from_uint8(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4)); + const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030; + const int qh1 = (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) & 0x30303030; + + const int kq0 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + 0; + const int kq1 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + (QI6_K/2); + + x_ql[i * (2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020); + x_ql[i * (2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020); + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256 + const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256 + float * x_dmf = (float *) x_dm; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) { + int i = (i0 + threadIdx.y * QI6_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbxd; + + x_dmf[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd] = bxi->d; + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { + int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / 4; + + x_sc[i * (WARP_SIZE/8) + i / 8 + threadIdx.x % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, threadIdx.x % (QI6_K/8)); + } +} + +template +static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { + + GGML_UNUSED(x_qh); + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + const float * x_dmf = (const float *) x_dm; + const float * y_df = (const float *) y_ds; + + const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/8]); + + const int index_x = i * (QR6_K*WARP_SIZE + 1) + QR6_K*k0; + const int index_y = j * WARP_SIZE + (QR6_K*k0) % WARP_SIZE; + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq( + &x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]); + } + } +} + +// ------------------------------------------------------------------------------------------------------------------------------------- + +template +struct mmq_type_traits; + +template +struct mmq_type_traits { + static constexpr bool need_sum = true; + static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0; + static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_0_q8_1_mul_mat; +}; + +template +struct mmq_type_traits { + static constexpr bool need_sum = true; + static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1; + static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_1_q8_1_mul_mat; +}; + +template +struct mmq_type_traits { + static constexpr bool need_sum = false; + static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0; + static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_0_q8_1_mul_mat; +}; + +template +struct mmq_type_traits { + static constexpr bool need_sum = true; + static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1; + static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_1_q8_1_mul_mat; +}; + +template +struct mmq_type_traits { + static constexpr bool need_sum = false; + static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0; + static constexpr vec_dot_mmq_t vec_dot = vec_dot_q8_0_q8_1_mul_mat; +}; + +template +struct mmq_type_traits { + static constexpr bool need_sum = false; + static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K; + static constexpr vec_dot_mmq_t vec_dot = vec_dot_q2_K_q8_1_mul_mat; +}; + +template +struct mmq_type_traits { + static constexpr bool need_sum = false; + static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K; + static constexpr vec_dot_mmq_t vec_dot = vec_dot_q3_K_q8_1_mul_mat; +}; + +template +struct mmq_type_traits { + static constexpr bool need_sum = true; + static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K; + static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_K_q8_1_mul_mat; +}; + +template +struct mmq_type_traits { + static constexpr bool need_sum = true; + static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K; + static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_K_q8_1_mul_mat; +}; + +template +struct mmq_type_traits { + static constexpr bool need_sum = false; + static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K; + static constexpr vec_dot_mmq_t vec_dot = vec_dot_q6_K_q8_1_mul_mat; +}; + +template +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(RDNA3) || defined(RDNA2) + __launch_bounds__(WARP_SIZE*nwarps, 2) +#endif // defined(RDNA3) || defined(RDNA2) +#else +#if __CUDA_ARCH__ >= CC_VOLTA + __launch_bounds__(WARP_SIZE*nwarps, 1) +#else + __launch_bounds__(WARP_SIZE*nwarps, type == GGML_TYPE_Q2_K ? 1 : 2) +#endif // __CUDA_ARCH__ >= CC_VOLTA +#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +static __global__ void mul_mat_q( + const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, + const int ne00, const int ne01, const int stride00, const int ne10, const int ne11, const int ne0) { + + // Skip unused template specializations for faster compilation: + if (mmq_x > get_mmq_x_max_device()) { + NO_DEVICE_CODE; + return; + } + + constexpr int qk = ggml_cuda_type_traits::qk; + constexpr int qr = ggml_cuda_type_traits::qr; + constexpr int qi = ggml_cuda_type_traits::qi; + constexpr int mmq_y = get_mmq_y_device(mmq_x); + constexpr bool need_sum = mmq_type_traits::need_sum; + constexpr int vdr = mmq_type_traits::vdr; + constexpr load_tiles_mmq_t load_tiles = mmq_type_traits::load_tiles; + constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot; + + constexpr tile_x_sizes txs = get_tile_x_sizes_device(type); + + extern __shared__ char data_mul_mat_q[]; + int * tile_x_ql = (int *) data_mul_mat_q; + half2 * tile_x_dm = (half2 *) (tile_x_ql + txs.ql); + int * tile_x_qh = (int *) (tile_x_dm + txs.dm); + int * tile_x_sc = (int *) (tile_x_qh + txs.qh); + int * tile_y_qs = (int *) (tile_x_sc + txs.sc); // [mmq_x * WARP_SIZE] + half2 * tile_y_ds = (half2 *) (tile_y_qs + mmq_x*WARP_SIZE); // [mmq_x * WARP_SIZE/QI8_1]; + + const block_q8_1 * y = (const block_q8_1 *) yc; + + const int blocks_per_row_x = ne00 / qk; + const int blocks_per_col_y = ne10 / QK8_1; + const int blocks_per_warp = WARP_SIZE / qi; + + const int & ne1 = ne11; + + const int tile_x_max_i = ne01 - blockIdx.x*mmq_y - 1; + + float sum[(mmq_x/nwarps) * (mmq_y/WARP_SIZE)] = {0.0f}; + + for (int kb0 = 0; kb0 < blocks_per_row_x; kb0 += blocks_per_warp) { + + load_tiles(x, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, stride00*blockIdx.x*mmq_y + kb0, tile_x_max_i, stride00); + +#pragma unroll + for (int kr = 0; kr < qr; ++kr) { + const int kqs = kr*WARP_SIZE + threadIdx.x; + const int kbxd = kqs / QI8_1; + +#pragma unroll + for (int i0 = 0; i0 < mmq_x; i0 += nwarps) { + const int i = min(blockIdx.y*mmq_x + threadIdx.y + i0, ne11-1); // to prevent out-of-bounds memory accesses + + const block_q8_1 * by0 = &y[i*blocks_per_col_y + kb0 * (qk/QK8_1) + kbxd]; + + const int index_y = (i0 + threadIdx.y) * WARP_SIZE + kqs % WARP_SIZE; + tile_y_qs[index_y] = get_int_from_int8_aligned(by0->qs, threadIdx.x % QI8_1); + } + +#pragma unroll + for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) { + const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE/QI8_1)) % mmq_x; + const int kby = threadIdx.x % (WARP_SIZE/QI8_1); + const int i_y_eff = min(blockIdx.y*mmq_x + ids, ne11-1); + + // if the sum is not needed it's faster to transform the scale to f32 ahead of time + const half2 * dsi_src = &y[i_y_eff*blocks_per_col_y + kb0 * (qk/QK8_1) + kr*(WARP_SIZE/QI8_1) + kby].ds; + half2 * dsi_dst = &tile_y_ds[ids * (WARP_SIZE/QI8_1) + kby]; + if (need_sum) { + *dsi_dst = *dsi_src; + } else { + float * dfi_dst = (float *) dsi_dst; + *dfi_dst = __low2float(*dsi_src); + } + } + + __syncthreads(); + +// #pragma unroll // unrolling this loop causes too much register pressure + for (int k0 = kr*WARP_SIZE/qr; k0 < (kr+1)*WARP_SIZE/qr; k0 += vdr) { + vec_dot(tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds, sum, k0); + } + + __syncthreads(); + } + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = blockIdx.y*mmq_x + j0 + threadIdx.y; + + if (j >= ne1) { + return; + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = blockIdx.x*mmq_y + i0 + threadIdx.x; + + if (need_check && i >= ne0) { + continue; + } + + dst[j*ne0 + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE]; + } + } +} + +struct mmq_args { + const char * x; const char * y; float * dst; + int64_t ne00; int64_t ne01; int64_t stride00; + int64_t ne10; int64_t ne11; + int64_t ne0; +}; + +template +static void launch_mul_mat_q(const mmq_args & args, cudaStream_t stream) { + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + const int mmq_y = get_mmq_y_host(cc, mmq_x); + + const int block_num_x = (args.ne01 + mmq_y - 1) / mmq_y; + const int block_num_y = (args.ne11 + mmq_x - 1) / mmq_x; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, nwarps, 1); + + const tile_x_sizes txs = get_tile_x_sizes_host(type, mmq_y); + const int shmem_x = txs.ql*sizeof(int) + txs.dm*sizeof(half2) + txs.qh*sizeof(int) + txs.sc*sizeof(int); + const int shmem_y = mmq_x*WARP_SIZE*sizeof(int) + mmq_x*(WARP_SIZE/QI8_1)*sizeof(half2); + const int shmem = shmem_x + shmem_y; + +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) + static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; + if (!shmem_limit_raised[id]) { + CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem)); + CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem)); + shmem_limit_raised[id] = true; + } +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) + + if (args.ne01 % mmq_y == 0) { + const bool need_check = false; + mul_mat_q<<>> + (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride00, args.ne10, args.ne11, args.ne0); + } else { + const bool need_check = true; + mul_mat_q<<>> + (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride00, args.ne10, args.ne11, args.ne0); + } +} + +template +void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) { + const int id = ggml_cuda_get_device(); + const int nsm = ggml_cuda_info().devices[id].nsm; + const int cc = ggml_cuda_info().devices[id].cc; + + const int mmq_x_max = get_mmq_x_max_host(cc); + const int mmq_y = get_mmq_y_host(cc, mmq_x_max); + const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y; + + int mmq_x_best = 0; + int nwaves_best = INT_MAX; + + for (int mmq_x = 8; mmq_x <= mmq_x_max && nwaves_best > 1; mmq_x += 8) { + const int block_num_x = (args.ne11 + mmq_x - 1) / mmq_x; + const int nwaves = (block_num_x*block_num_y + nsm - 1) / nsm; + + if (nwaves < nwaves_best) { + mmq_x_best = mmq_x; + nwaves_best = nwaves; + } + } + + switch (mmq_x_best) { + case 8: + launch_mul_mat_q(args, stream); + break; + case 16: + launch_mul_mat_q(args, stream); + break; + case 24: + launch_mul_mat_q(args, stream); + break; + case 32: + launch_mul_mat_q(args, stream); + break; + case 40: + launch_mul_mat_q(args, stream); + break; + case 48: + launch_mul_mat_q(args, stream); + break; + case 56: + launch_mul_mat_q(args, stream); + break; + case 64: + launch_mul_mat_q(args, stream); + break; + case 72: + launch_mul_mat_q(args, stream); + break; + case 80: + launch_mul_mat_q(args, stream); + break; + case 88: + launch_mul_mat_q(args, stream); + break; + case 96: + launch_mul_mat_q(args, stream); + break; + case 104: + launch_mul_mat_q(args, stream); + break; + case 112: + launch_mul_mat_q(args, stream); + break; + case 120: + launch_mul_mat_q(args, stream); + break; + case 128: + launch_mul_mat_q(args, stream); + break; + default: + GGML_ASSERT(false); + break; + } +} + +#define DECL_MMQ_CASE(type) \ + template void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) \ + +extern DECL_MMQ_CASE(GGML_TYPE_Q4_0); +extern DECL_MMQ_CASE(GGML_TYPE_Q4_1); +extern DECL_MMQ_CASE(GGML_TYPE_Q5_0); +extern DECL_MMQ_CASE(GGML_TYPE_Q5_1); +extern DECL_MMQ_CASE(GGML_TYPE_Q8_0); +extern DECL_MMQ_CASE(GGML_TYPE_Q2_K); +extern DECL_MMQ_CASE(GGML_TYPE_Q3_K); +extern DECL_MMQ_CASE(GGML_TYPE_Q4_K); +extern DECL_MMQ_CASE(GGML_TYPE_Q5_K); +extern DECL_MMQ_CASE(GGML_TYPE_Q6_K); + +// ------------------------------------------------------------------------------------------------------------------------- void ggml_cuda_op_mul_mat_q( ggml_backend_cuda_context & ctx, diff --git a/ggml-cuda/mmvq.cu b/ggml-cuda/mmvq.cu index 65cc1bcaad697..5f056e91e5460 100644 --- a/ggml-cuda/mmvq.cu +++ b/ggml-cuda/mmvq.cu @@ -1,9 +1,47 @@ #include "mmvq.cuh" #include "vecdotq.cuh" -typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs); +typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs); + +static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) { + return type == GGML_TYPE_Q4_0 ? vec_dot_q4_0_q8_1 : + type == GGML_TYPE_Q4_1 ? vec_dot_q4_1_q8_1 : + type == GGML_TYPE_Q5_0 ? vec_dot_q5_0_q8_1 : + type == GGML_TYPE_Q5_1 ? vec_dot_q5_1_q8_1 : + type == GGML_TYPE_Q8_0 ? vec_dot_q8_0_q8_1 : + type == GGML_TYPE_Q2_K ? vec_dot_q2_K_q8_1 : + type == GGML_TYPE_Q3_K ? vec_dot_q3_K_q8_1 : + type == GGML_TYPE_Q4_K ? vec_dot_q4_K_q8_1 : + type == GGML_TYPE_Q5_K ? vec_dot_q5_K_q8_1 : + type == GGML_TYPE_Q6_K ? vec_dot_q6_K_q8_1 : + type == GGML_TYPE_IQ2_XXS ? vec_dot_iq2_xxs_q8_1 : + type == GGML_TYPE_IQ2_XS ? vec_dot_iq2_xs_q8_1 : + type == GGML_TYPE_IQ2_S ? vec_dot_iq2_s_q8_1 : + type == GGML_TYPE_IQ3_XXS ? vec_dot_iq3_xxs_q8_1 : + type == GGML_TYPE_IQ1_S ? vec_dot_iq1_s_q8_1 : + type == GGML_TYPE_IQ1_M ? vec_dot_iq1_m_q8_1 : + type == GGML_TYPE_IQ4_NL ? vec_dot_iq4_nl_q8_1 : + type == GGML_TYPE_IQ4_XS ? vec_dot_iq4_xs_q8_1 : + type == GGML_TYPE_IQ3_S ? vec_dot_iq3_s_q8_1 : + nullptr; +} + +static constexpr __device__ int get_vdr_mmvq(ggml_type type) { + return type == GGML_TYPE_Q4_0 ? VDR_Q4_0_Q8_1_MMVQ : + type == GGML_TYPE_Q4_1 ? VDR_Q4_1_Q8_1_MMVQ : + type == GGML_TYPE_Q5_0 ? VDR_Q5_0_Q8_1_MMVQ : + type == GGML_TYPE_Q5_1 ? VDR_Q5_1_Q8_1_MMVQ : + type == GGML_TYPE_Q8_0 ? VDR_Q8_0_Q8_1_MMVQ : + type == GGML_TYPE_Q2_K ? VDR_Q2_K_Q8_1_MMVQ : + type == GGML_TYPE_Q3_K ? VDR_Q3_K_Q8_1_MMVQ : + type == GGML_TYPE_Q4_K ? VDR_Q4_K_Q8_1_MMVQ : + type == GGML_TYPE_Q5_K ? VDR_Q5_K_Q8_1_MMVQ : + type == GGML_TYPE_Q6_K ? VDR_Q6_K_Q8_1_MMVQ : + type == GGML_TYPE_IQ4_NL ? VDR_Q4_K_Q8_1_MMVQ : + 1; +} -template +template #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) // tell the compiler to use as many registers as it wants, see nwarps definition below __launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1) @@ -12,6 +50,12 @@ static __global__ void mul_mat_vec_q( const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + constexpr int qk = ggml_cuda_type_traits::qk; + constexpr int qi = ggml_cuda_type_traits::qi; + constexpr int vdr = get_vdr_mmvq(type); + + constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type); + #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3)) constexpr int nwarps = 1; constexpr int rows_per_cuda_block = 1; @@ -29,7 +73,6 @@ static __global__ void mul_mat_vec_q( // partial sum for each thread float tmp[ncols_y][rows_per_cuda_block] = {0.0f}; - const block_q_t * x = (const block_q_t *) vx; const block_q8_1 * y = (const block_q8_1 *) vy; for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) { @@ -42,8 +85,7 @@ static __global__ void mul_mat_vec_q( for (int j = 0; j < ncols_y; ++j) { #pragma unroll for (int i = 0; i < rows_per_cuda_block; ++i) { - tmp[j][i] += vec_dot_q_cuda( - &x[kbx + (row0 + i)*blocks_per_row_x], &y[j*blocks_per_col_y + kby], kqs); + tmp[j][i] += vec_dot_q_cuda(vx, &y[j*blocks_per_col_y + kby], (row0 + i)*blocks_per_row_x + kbx, kqs); } } } @@ -81,12 +123,12 @@ static __global__ void mul_mat_vec_q( } } -template +template static void mul_mat_vec_q_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - GGML_ASSERT(ncols_x % qk == 0); + GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0); GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE); int id = ggml_cuda_get_device(); @@ -124,36 +166,28 @@ static void mul_mat_vec_q_cuda( switch (ncols_y) { case 1: - mul_mat_vec_q<1, qk, qi, block_q_t, vdr, vec_dot> - <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); break; case 2: - mul_mat_vec_q<2, qk, qi, block_q_t, vdr, vec_dot> - <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); break; case 3: - mul_mat_vec_q<3, qk, qi, block_q_t, vdr, vec_dot> - <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); break; case 4: - mul_mat_vec_q<4, qk, qi, block_q_t, vdr, vec_dot> - <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); break; case 5: - mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot> - <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); break; case 6: - mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot> - <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); break; case 7: - mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot> - <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); break; case 8: - mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot> - <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); break; default: GGML_ASSERT(false); @@ -165,152 +199,133 @@ static void mul_mat_vec_q4_0_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - mul_mat_vec_q_cuda - (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } static void mul_mat_vec_q4_1_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - mul_mat_vec_q_cuda - (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } static void mul_mat_vec_q5_0_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - mul_mat_vec_q_cuda - (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } static void mul_mat_vec_q5_1_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - mul_mat_vec_q_cuda - (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } static void mul_mat_vec_q8_0_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - mul_mat_vec_q_cuda - (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } static void mul_mat_vec_q2_K_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - mul_mat_vec_q_cuda - (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } static void mul_mat_vec_q3_K_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - mul_mat_vec_q_cuda - (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } static void mul_mat_vec_q4_K_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - mul_mat_vec_q_cuda - (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } static void mul_mat_vec_q5_K_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - mul_mat_vec_q_cuda - (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } static void mul_mat_vec_q6_K_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - mul_mat_vec_q_cuda - (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } static void mul_mat_vec_iq2_xxs_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - mul_mat_vec_q_cuda - (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } static void mul_mat_vec_iq2_xs_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - mul_mat_vec_q_cuda - (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } static void mul_mat_vec_iq2_s_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - mul_mat_vec_q_cuda - (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } static void mul_mat_vec_iq3_xxs_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - mul_mat_vec_q_cuda - (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } static void mul_mat_vec_iq1_s_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - mul_mat_vec_q_cuda - (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } static void mul_mat_vec_iq1_m_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - mul_mat_vec_q_cuda - (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } static void mul_mat_vec_iq4_nl_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - mul_mat_vec_q_cuda - (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } static void mul_mat_vec_iq4_xs_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - mul_mat_vec_q_cuda - (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } static void mul_mat_vec_iq3_s_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - mul_mat_vec_q_cuda - (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } void ggml_cuda_op_mul_mat_vec_q( diff --git a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu index d7f1034751cc1..6696a238476d8 100644 --- a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +++ b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu index f3d8d2eda4991..dd070db2853f5 100644 --- a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +++ b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu index 9beb05ca23597..54dcde6f52324 100644 --- a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +++ b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu index 0c163dcba0613..4ec22f791912d 100644 --- a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +++ b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu index 3980167b3db91..3c15bf7f0ef16 100644 --- a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +++ b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu index fe099921d08db..7e61b5fdcdbca 100644 --- a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +++ b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu index d4d5e7999e393..fdb15b580cff8 100644 --- a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +++ b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu index f08b10c4d53c1..0f7c417d2c0c8 100644 --- a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +++ b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu index e8c3f8adc2a1e..851f33c43f040 100644 --- a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +++ b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu index c01416a13f6ac..763809cbeb44c 100644 --- a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +++ b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu index 46615f281cb18..f2a276e50e5fa 100644 --- a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +++ b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu index 72dcc1a2fe563..cb227f6f5ce1f 100644 --- a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +++ b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu index 9fa8a377d0c6f..97ac0520c71d1 100644 --- a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +++ b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu index 20ea86c6dcb89..c772b42634fe6 100644 --- a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +++ b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu index ed815957cb212..5cb7430819e4e 100644 --- a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +++ b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu index bbe9e6a1c8e61..98a709d171446 100644 --- a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +++ b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu index d12a616996a9b..4f2f947ae81e6 100644 --- a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +++ b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu index 1e901afcbbc99..11f96b6f65cee 100644 --- a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +++ b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu index a3f98ce37456c..b39bdc0611c0d 100644 --- a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +++ b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu index 1bae97243a062..bbd6a2c7f491c 100644 --- a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +++ b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu index 7258e9775da09..9d84ff2b19175 100644 --- a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +++ b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu index 08435c005e81c..bc8a5bff684ff 100644 --- a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +++ b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu index 17864e8e9cdf3..a679100c83807 100644 --- a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +++ b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu index 9239138c962ae..8f21bccf7f8da 100644 --- a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +++ b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu index e387d9c1da3e2..858b00fd74191 100644 --- a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +++ b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu index d69d3bbd683aa..0fc8011fac5fc 100644 --- a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +++ b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu index 61a4788166ed7..261fdf623e098 100644 --- a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +++ b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu index 89995080ac5f7..0fb8247383063 100644 --- a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +++ b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu index 9e6a58dff6759..a9d9d089bd314 100644 --- a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +++ b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu index 153cbfd86c82d..7d7b27920aa3e 100644 --- a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +++ b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu index 09d5765582c8f..a092ee2d50957 100644 --- a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +++ b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu index 3e3c91e68c219..db55927a19457 100644 --- a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +++ b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu index 7b973058f4a89..c3c21cefae047 100644 --- a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +++ b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu index a43a475d44439..35dd9f520802c 100644 --- a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +++ b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu index 5b570c0a3df41..050c22ac7c6c7 100644 --- a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +++ b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu index bf2cc684ef459..de4866c5e65ce 100644 --- a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +++ b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu index 7428e45ea2621..57a10bc4be4a3 100644 --- a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +++ b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu index 4aee830de1228..e0f08b46a7e35 100644 --- a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +++ b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu index 36acb63191518..1c8e8a467a8aa 100644 --- a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +++ b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu index a4090c3905507..cefed83fb9562 100644 --- a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +++ b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu index 17b6b2d117b52..aede6e3588195 100644 --- a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +++ b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu index 549e1cea1f379..1a1a92c788fbd 100644 --- a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +++ b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu index 66bcd820f9a20..ad667473d110b 100644 --- a/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +++ b/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu index 15933a29977d0..c499f455da971 100644 --- a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +++ b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f32.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu index 8aa7855839516..8286ebf373627 100644 --- a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +++ b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f32.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu index bde3924fd7a28..4587868825d21 100644 --- a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +++ b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f32.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu index 1708181c15874..d89103ce0c68f 100644 --- a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +++ b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f32.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu index 30fa6fa4cebf0..bb75fd42ff17d 100644 --- a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +++ b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f32.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu index 69673d50f4b3c..b1629817e79e3 100644 --- a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +++ b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f32.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu index d8b2b2e18496e..d8657604dab80 100644 --- a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +++ b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f32.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu index 01cce7ab50a7a..2e5bd2f1a3acc 100644 --- a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +++ b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f32.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu index fd5563b395a5f..be5f302d9f1d4 100644 --- a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +++ b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f32.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu index b13cc4a0c7e8d..8dd91cd72eb60 100644 --- a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +++ b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f32.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu index 86f1fc63701f3..4cb791502a157 100644 --- a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +++ b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f32.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu index 26e7df4be61f1..09dea426736e9 100644 --- a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +++ b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f32.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu index e4fda8952c75f..0fbb607694f25 100644 --- a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +++ b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f32.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu index bd15117b40db7..2aeab83b20d21 100644 --- a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +++ b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f32.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu index cb6c6a7606aac..599415b494741 100644 --- a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +++ b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f32.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu index 201b6641de102..e4f8e3083bb6b 100644 --- a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +++ b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f32.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu index 6da57a44aea08..34d166527e93a 100644 --- a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +++ b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f32.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu index 47623c9bff42c..4bebef45a37cb 100644 --- a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +++ b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f32.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu index 82c6861d24f7a..326468da2fb24 100644 --- a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +++ b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f32.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu index 24a80c2b042f5..511b58f4ecc72 100644 --- a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +++ b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f32.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu index b95eaf7e18ccb..d9906d142e159 100644 --- a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +++ b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f32.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu index 275f2efccd18a..f61c183abbaf7 100644 --- a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +++ b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f32.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu index 3673f7fd55316..c10450fd29e76 100644 --- a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +++ b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f32.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu index 2c4d59947b1eb..2d5cb195c41dc 100644 --- a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +++ b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f32.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu index 2457cdf3fe77d..b384f34d7d921 100644 --- a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +++ b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f32.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu index b3b411ed368b8..446e293b16edc 100644 --- a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +++ b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f32.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu index b7f308a4d0b7e..6f430298899c7 100644 --- a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +++ b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f32.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu index 7396866975b97..1cd8ba88fd650 100644 --- a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +++ b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f32.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu index 708d03113e500..1ee2eab65a1c9 100644 --- a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +++ b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f32.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu index df891be6031ce..2bc77816a5d4e 100644 --- a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +++ b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f32.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu index f49b6d1f9ef8f..d55ced08bc940 100644 --- a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +++ b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f32.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu index 1de92148b714f..8361e99c4e4a4 100644 --- a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +++ b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f32.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu index 7a1ba7f8de4b7..7507a67c4c5e9 100644 --- a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +++ b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f32.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu index 25493e4bac9ee..61f050b235ff2 100644 --- a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +++ b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f32.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu index 3cd650c7bc3e4..d4a49d9c9912a 100644 --- a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +++ b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f32.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu index 88ffa43d6f161..d146278976211 100644 --- a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +++ b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f32.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu index 8c7bac6c2d7d9..e73f917a1f186 100644 --- a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +++ b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f32.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu index a28f62e7b8202..d40825dfc21f0 100644 --- a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +++ b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f32.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu index d39838b96fa9d..b5c6869f4ec42 100644 --- a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +++ b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f32.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu index 834d40f6c8392..4e21b0ccaef16 100644 --- a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +++ b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f32.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu index f7d54668b9d86..2eac321b370df 100644 --- a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +++ b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f32.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu index 59e00ad83414c..f7d2c3b4e0a12 100644 --- a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +++ b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f32.cuh" diff --git a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu index 6e63893debd6b..a013f400bd33b 100644 --- a/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +++ b/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec-f32.cuh" diff --git a/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu b/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu index ca356ad6c9ca0..2d94e65c28c29 100644 --- a/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu +++ b/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-wmma-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu b/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu index 430ee64eb29a5..c3d9df3c44313 100644 --- a/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu +++ b/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-wmma-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu b/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu index d421d17ccc5fd..bb680e401f7da 100644 --- a/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu +++ b/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-wmma-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu b/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu index deacd5f58eec9..073f71b1f3e26 100644 --- a/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu +++ b/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-wmma-f16.cuh" diff --git a/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu b/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu index 282896733473f..d30710c5fa21b 100644 --- a/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu +++ b/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu @@ -1,4 +1,4 @@ -// This file has been autogenerated by generate-variants.py, do not edit manually. +// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-wmma-f16.cuh" diff --git a/ggml-cuda/template-instances/generate_cu_files.py b/ggml-cuda/template-instances/generate_cu_files.py index ee5b460e07986..ea58d09680231 100755 --- a/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml-cuda/template-instances/generate_cu_files.py @@ -20,6 +20,18 @@ SOURCE_FATTN_WMMA_CASE = "DECL_FATTN_WMMA_F16_CASE({head_size}, {cols_per_block}, {kq_acc_t});\n" +TYPES_MMQ = [ + "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", + "GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K" +] + +SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE({type}); +""" + def get_short_name(long_quant_name): return long_quant_name.replace("GGML_TYPE_", "").lower() @@ -57,3 +69,7 @@ def get_head_sizes(type_k, type_v): if kq_acc_t == "float" and cols_per_block == 32 and head_size == 256: # register spilling, bad performance continue f.write(SOURCE_FATTN_WMMA_CASE.format(kq_acc_t=kq_acc_t, cols_per_block=cols_per_block, head_size=head_size)) + +for type in TYPES_MMQ: + with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f: + f.write(SOURCE_MMQ.format(type=type)) diff --git a/ggml-cuda/template-instances/mmq-instance-q2_k.cu b/ggml-cuda/template-instances/mmq-instance-q2_k.cu new file mode 100644 index 0000000000000..6415369dc1d95 --- /dev/null +++ b/ggml-cuda/template-instances/mmq-instance-q2_k.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q2_K); diff --git a/ggml-cuda/template-instances/mmq-instance-q3_k.cu b/ggml-cuda/template-instances/mmq-instance-q3_k.cu new file mode 100644 index 0000000000000..ffb6213af83ea --- /dev/null +++ b/ggml-cuda/template-instances/mmq-instance-q3_k.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q3_K); diff --git a/ggml-cuda/template-instances/mmq-instance-q4_0.cu b/ggml-cuda/template-instances/mmq-instance-q4_0.cu new file mode 100644 index 0000000000000..0c0b0c8a8ed22 --- /dev/null +++ b/ggml-cuda/template-instances/mmq-instance-q4_0.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q4_0); diff --git a/ggml-cuda/template-instances/mmq-instance-q4_1.cu b/ggml-cuda/template-instances/mmq-instance-q4_1.cu new file mode 100644 index 0000000000000..ee67f6942a8fc --- /dev/null +++ b/ggml-cuda/template-instances/mmq-instance-q4_1.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q4_1); diff --git a/ggml-cuda/template-instances/mmq-instance-q4_k.cu b/ggml-cuda/template-instances/mmq-instance-q4_k.cu new file mode 100644 index 0000000000000..9eeb3cd7f3cc1 --- /dev/null +++ b/ggml-cuda/template-instances/mmq-instance-q4_k.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q4_K); diff --git a/ggml-cuda/template-instances/mmq-instance-q5_0.cu b/ggml-cuda/template-instances/mmq-instance-q5_0.cu new file mode 100644 index 0000000000000..cc57fb9753c9e --- /dev/null +++ b/ggml-cuda/template-instances/mmq-instance-q5_0.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q5_0); diff --git a/ggml-cuda/template-instances/mmq-instance-q5_1.cu b/ggml-cuda/template-instances/mmq-instance-q5_1.cu new file mode 100644 index 0000000000000..721ac790c44f4 --- /dev/null +++ b/ggml-cuda/template-instances/mmq-instance-q5_1.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q5_1); diff --git a/ggml-cuda/template-instances/mmq-instance-q5_k.cu b/ggml-cuda/template-instances/mmq-instance-q5_k.cu new file mode 100644 index 0000000000000..a2e90ffd5d0aa --- /dev/null +++ b/ggml-cuda/template-instances/mmq-instance-q5_k.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q5_K); diff --git a/ggml-cuda/template-instances/mmq-instance-q6_k.cu b/ggml-cuda/template-instances/mmq-instance-q6_k.cu new file mode 100644 index 0000000000000..470938fef8a05 --- /dev/null +++ b/ggml-cuda/template-instances/mmq-instance-q6_k.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q6_K); diff --git a/ggml-cuda/template-instances/mmq-instance-q8_0.cu b/ggml-cuda/template-instances/mmq-instance-q8_0.cu new file mode 100644 index 0000000000000..974477bbb73a8 --- /dev/null +++ b/ggml-cuda/template-instances/mmq-instance-q8_0.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q8_0); diff --git a/ggml-cuda/vecdotq.cuh b/ggml-cuda/vecdotq.cuh index df9752390509d..b9573a7c7d053 100644 --- a/ggml-cuda/vecdotq.cuh +++ b/ggml-cuda/vecdotq.cuh @@ -566,9 +566,9 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq( } static __device__ __forceinline__ float vec_dot_q4_0_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { - const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq; + const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq + kbx; int v[VDR_Q4_0_Q8_1_MMVQ]; int u[2*VDR_Q4_0_Q8_1_MMVQ]; @@ -585,9 +585,9 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1( static __device__ __forceinline__ float vec_dot_q4_1_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { - const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq; + const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq + kbx; int v[VDR_Q4_1_Q8_1_MMVQ]; int u[2*VDR_Q4_1_Q8_1_MMVQ]; @@ -603,9 +603,9 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1( } static __device__ __forceinline__ float vec_dot_q5_0_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { - const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq; + const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq + kbx; int vl[VDR_Q5_0_Q8_1_MMVQ]; int vh[VDR_Q5_0_Q8_1_MMVQ]; @@ -623,9 +623,9 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1( } static __device__ __forceinline__ float vec_dot_q5_1_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { - const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq; + const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq + kbx; int vl[VDR_Q5_1_Q8_1_MMVQ]; int vh[VDR_Q5_1_Q8_1_MMVQ]; @@ -643,9 +643,9 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1( } static __device__ __forceinline__ float vec_dot_q8_0_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { - const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq; + const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq + kbx; int v[VDR_Q8_0_Q8_1_MMVQ]; int u[VDR_Q8_0_Q8_1_MMVQ]; @@ -660,9 +660,9 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1( } static __device__ __forceinline__ float vec_dot_q2_K_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { - const block_q2_K * bq2_K = (const block_q2_K *) vbq; + const block_q2_K * bq2_K = (const block_q2_K *) vbq + kbx; const int bq8_offset = QR2_K * (iqs / QI8_1); const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2); @@ -683,9 +683,9 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1( } static __device__ __forceinline__ float vec_dot_q3_K_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { - const block_q3_K * bq3_K = (const block_q3_K *) vbq; + const block_q3_K * bq3_K = (const block_q3_K *) vbq + kbx; const int bq8_offset = QR3_K * (iqs / (QI3_K/2)); const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2); @@ -710,9 +710,9 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1( } static __device__ __forceinline__ float vec_dot_q4_K_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { - const block_q4_K * bq4_K = (const block_q4_K *) vbq; + const block_q4_K * bq4_K = (const block_q4_K *) vbq + kbx; int v[2]; int u[2*QR4_K]; @@ -756,9 +756,9 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1( } static __device__ __forceinline__ float vec_dot_q5_K_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { - const block_q5_K * bq5_K = (const block_q5_K *) vbq; + const block_q5_K * bq5_K = (const block_q5_K *) vbq + kbx; int vl[2]; int vh[2]; @@ -802,9 +802,9 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1( } static __device__ __forceinline__ float vec_dot_q6_K_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { - const block_q6_K * bq6_K = (const block_q6_K *) vbq; + const block_q6_K * bq6_K = (const block_q6_K *) vbq + kbx; const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/4); const int scale_offset = (QI6_K/4) * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/8); @@ -828,8 +828,8 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1( } static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { - const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq; + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq + kbx; #if QR2_XXS == 8 const int ib32 = iqs; @@ -872,9 +872,9 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1( } static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - const block_iq2_xs * bq2 = (const block_iq2_xs *) vbq; + const block_iq2_xs * bq2 = (const block_iq2_xs *) vbq + kbx; const int ib32 = iqs; const uint16_t * q2 = bq2->qs + 4*ib32; @@ -911,9 +911,9 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1( // TODO static __device__ __forceinline__ float vec_dot_iq2_s_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - const block_iq2_s * bq2 = (const block_iq2_s *) vbq; + const block_iq2_s * bq2 = (const block_iq2_s *) vbq + kbx; const int ib32 = iqs; const int8_t * q8 = bq8_1[ib32].qs; @@ -951,9 +951,9 @@ static __device__ __forceinline__ float vec_dot_iq2_s_q8_1( } static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - const block_iq3_xxs * bq2 = (const block_iq3_xxs *) vbq; + const block_iq3_xxs * bq2 = (const block_iq3_xxs *) vbq + kbx; const int ib32 = iqs; const uint8_t * q3 = bq2->qs + 8*ib32; @@ -981,9 +981,9 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1( // TODO: don't use lookup table for signs static __device__ __forceinline__ float vec_dot_iq3_s_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - const block_iq3_s * bq2 = (const block_iq3_s *) vbq; + const block_iq3_s * bq2 = (const block_iq3_s *) vbq + kbx; const int ib32 = iqs; const uint8_t * qs = bq2->qs + 8*ib32; @@ -1008,8 +1008,8 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1( } static __device__ __forceinline__ float vec_dot_iq1_s_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { - const block_iq1_s * bq1 = (const block_iq1_s *) vbq; + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + const block_iq1_s * bq1 = (const block_iq1_s *) vbq + kbx; const int ib32 = iqs; int sumi = 0; @@ -1039,8 +1039,8 @@ static __device__ __forceinline__ float vec_dot_iq1_s_q8_1( } static __device__ __forceinline__ float vec_dot_iq1_m_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { - const block_iq1_m * bq1 = (const block_iq1_m *) vbq; + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + const block_iq1_m * bq1 = (const block_iq1_m *) vbq + kbx; const int ib32 = iqs; int sumi[2] = {0, 0}; @@ -1094,9 +1094,9 @@ static __device__ __forceinline__ void get_int_from_table_16(const uint32_t & q4 #endif static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { - const block_iq4_nl * bq = (const block_iq4_nl *) vbq; + const block_iq4_nl * bq = (const block_iq4_nl *) vbq + kbx; #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics const uint16_t * q4 = (const uint16_t *)bq->qs + 2*iqs; @@ -1128,10 +1128,10 @@ static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1( } static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - const block_iq4_xs * bq4 = (const block_iq4_xs *) vbq; + const block_iq4_xs * bq4 = (const block_iq4_xs *) vbq + kbx; const uint8_t * values = (const uint8_t *)kvalues_iq4nl; // iqs is 0...7 @@ -1149,6 +1149,6 @@ static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1( } return d * (sumi1 + sumi2); #else - return vec_dot_iq4_xs_q8_1(vbq, bq8_1, iqs); + return vec_dot_iq4_xs_q8_1(vbq, bq8_1, kbx, iqs); #endif }