Skip to content

Commit

Permalink
CUDA: refactor mmq, dmmv, mmvq (ggerganov#7716)
Browse files Browse the repository at this point in the history
* CUDA: refactor mmq, dmmv, mmvq

* fix out-of-bounds write

* struct for qk, qr, qi

* fix cmake build

* mmq_type_traits
  • Loading branch information
JohannesGaessler committed Jun 5, 2024
1 parent 2b33896 commit 7d1a378
Show file tree
Hide file tree
Showing 112 changed files with 1,783 additions and 1,767 deletions.
4 changes: 4 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,8 @@ if (LLAMA_CUDA)
list(APPEND GGML_SOURCES_CUDA "ggml-cuda.cu")
file(GLOB SRCS "ggml-cuda/template-instances/fattn-wmma*.cu")
list(APPEND GGML_SOURCES_CUDA ${SRCS})
file(GLOB SRCS "ggml-cuda/template-instances/mmq*.cu")
list(APPEND GGML_SOURCES_CUDA ${SRCS})

add_compile_definitions(GGML_USE_CUDA)
add_compile_definitions(GGML_CUDA_USE_GRAPHS)
Expand Down Expand Up @@ -588,6 +590,8 @@ if (LLAMA_HIPBLAS)
list(APPEND GGML_SOURCES_ROCM "ggml-cuda.cu")
file(GLOB SRCS "ggml-cuda/template-instances/fattn-wmma*.cu")
list(APPEND GGML_SOURCES_ROCM ${SRCS})
file(GLOB SRCS "ggml-cuda/template-instances/mmq*.cu")
list(APPEND GGML_SOURCES_ROCM ${SRCS})

add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUDA)

Expand Down
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,7 @@ ifdef LLAMA_CUBLAS
endif

OBJS_CUDA_TEMP_INST = $(patsubst %.cu,%.o,$(wildcard ggml-cuda/template-instances/fattn-wmma*.cu))
OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/template-instances/mmq*.cu))
ifdef LLAMA_CUDA_FA_ALL_QUANTS
OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/template-instances/fattn-vec*.cu))
else
Expand Down
6 changes: 6 additions & 0 deletions ggml-common.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,18 @@ typedef sycl::half2 ggml_half2;
#define QI1_S (QK_K / (4*QR1_S))
#define QR1_S 8

#define QI1_M (QK_K / (4*QR1_M))
#define QR1_M 8

#define QI4_NL (QK4_NL / (4*QR4_NL))
#define QR4_NL 2

#define QI4_XS (QK_K / (4*QR4_XS))
#define QR4_XS 8

#define QI3_S (QK_K / (4*QR3_S))
#define QR3_S 8

#endif // GGML_COMMON_DECL_CUDA || GGML_COMMON_DECL_HIP

#define QK4_0 32
Expand Down
84 changes: 9 additions & 75 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -633,88 +633,22 @@ GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) {

// cuda split buffer

static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUDA_MAX_DEVICES> & tensor_split) {
int64_t min_compute_capability = INT_MAX;
int64_t max_compute_capability = INT_MIN;
static int64_t get_row_rounding(const std::array<float, GGML_CUDA_MAX_DEVICES> & tensor_split) {
int64_t row_rounding = 0;
for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
if (tensor_split[id] < (id + 1 < ggml_backend_cuda_get_device_count() ? tensor_split[id + 1] : 1.0f)) {
if (min_compute_capability > ggml_cuda_info().devices[id].cc) {
min_compute_capability = ggml_cuda_info().devices[id].cc;
}
if (max_compute_capability < ggml_cuda_info().devices[id].cc) {
max_compute_capability = ggml_cuda_info().devices[id].cc;
}
if (tensor_split[id] >= (id + 1 < ggml_backend_cuda_get_device_count() ? tensor_split[id + 1] : 1.0f)) {
continue;
}
}

#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
switch(type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
return max_compute_capability >= CC_RDNA2 ? 128 : 64;
case GGML_TYPE_F16:
case GGML_TYPE_F32:
return 1;
case GGML_TYPE_Q2_K:
return max_compute_capability >= CC_RDNA2 ? 128 : 32;
case GGML_TYPE_Q3_K:
return min_compute_capability < CC_RDNA2 ? 128 : 64;
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ1_M:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ3_S:
return max_compute_capability >= CC_RDNA2 ? 128 : 64;
default:
GGML_ASSERT(false);
}
#else
switch(type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
return max_compute_capability >= CC_VOLTA ? 128 : 64;
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
return 64;
case GGML_TYPE_F16:
case GGML_TYPE_F32:
return 1;
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ1_M:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ3_S:
return max_compute_capability >= CC_VOLTA ? 128 : 64;
case GGML_TYPE_Q6_K:
return 64;
default:
GGML_ASSERT(false);
const int cc = ggml_cuda_info().devices[id].cc;
row_rounding = std::max(row_rounding, (int64_t)get_mmq_y_host(cc, get_mmq_x_max_host(cc)));
}
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
return row_rounding;
}

static void get_row_split(int64_t * row_low, int64_t * row_high, const ggml_tensor * tensor, const std::array<float, GGML_CUDA_MAX_DEVICES> & tensor_split, int id) {
const int64_t nrows = ggml_nrows(tensor);
const int64_t rounding = get_row_rounding(tensor->type, tensor_split);
const int64_t rounding = get_row_rounding(tensor_split);

*row_low = id == 0 ? 0 : nrows*tensor_split[id];
*row_low -= *row_low % rounding;
Expand Down Expand Up @@ -1499,7 +1433,7 @@ static void ggml_cuda_op_mul_mat(
// for multi GPU, get the row boundaries from tensor split
// and round to mul_mat_q tile sizes
if (split) {
const int64_t rounding = get_row_rounding(src0->type, tensor_split);
const int64_t rounding = get_row_rounding(tensor_split);

if (id != 0) {
dev[id].row_low = ne01*tensor_split[id];
Expand Down
157 changes: 156 additions & 1 deletion ggml-cuda/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@
#endif

#define MMVQ_MAX_BATCH_SIZE 8 // max batch size to use MMVQ kernels
#define MMQ_MAX_BATCH_SIZE 32 // max batch size to use MMQ kernels when tensor cores are available
#define MMQ_MAX_BATCH_SIZE 64 // max batch size to use MMQ kernels when tensor cores are available

#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses

Expand Down Expand Up @@ -484,6 +484,161 @@ static __device__ __forceinline__ float get_alibi_slope(
return powf(base, exph);
}

template <ggml_type type>
struct ggml_cuda_type_traits;

template<>
struct ggml_cuda_type_traits<GGML_TYPE_F16> {
static constexpr int qk = 1;
static constexpr int qr = 1;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q4_0> {
static constexpr int qk = QK4_0;
static constexpr int qr = QR4_0;
static constexpr int qi = QI4_0;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q4_1> {
static constexpr int qk = QK4_1;
static constexpr int qr = QR4_1;
static constexpr int qi = QI4_1;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q5_0> {
static constexpr int qk = QK5_0;
static constexpr int qr = QR5_0;
static constexpr int qi = QI5_0;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q5_1> {
static constexpr int qk = QK5_1;
static constexpr int qr = QR5_1;
static constexpr int qi = QI5_1;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q8_0> {
static constexpr int qk = QK8_0;
static constexpr int qr = QR8_0;
static constexpr int qi = QI8_0;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q2_K> {
static constexpr int qk = QK_K;
static constexpr int qr = QR2_K;
static constexpr int qi = QI2_K;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q3_K> {
static constexpr int qk = QK_K;
static constexpr int qr = QR3_K;
static constexpr int qi = QI3_K;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q4_K> {
static constexpr int qk = QK_K;
static constexpr int qr = QR4_K;
static constexpr int qi = QI4_K;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q5_K> {
static constexpr int qk = QK_K;
static constexpr int qr = QR5_K;
static constexpr int qi = QI5_K;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q6_K> {
static constexpr int qk = QK_K;
static constexpr int qr = QR6_K;
static constexpr int qi = QI6_K;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ2_XXS> {
static constexpr int qk = QK_K;
static constexpr int qr = QR2_XXS;
static constexpr int qi = QI2_XXS;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ2_XS> {
static constexpr int qk = QK_K;
static constexpr int qr = QR2_XS;
static constexpr int qi = QI2_XS;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ2_S> {
static constexpr int qk = QK_K;
static constexpr int qr = QR2_S;
static constexpr int qi = QI2_S;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ3_XXS> {
static constexpr int qk = QK_K;
static constexpr int qr = QR3_XXS;
static constexpr int qi = QI3_XXS;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ1_S> {
static constexpr int qk = QK_K;
static constexpr int qr = QR1_S;
static constexpr int qi = QI1_S;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ1_M> {
static constexpr int qk = QK_K;
static constexpr int qr = QR1_M;
static constexpr int qi = QI1_M;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ4_NL> {
static constexpr int qk = QK4_NL;
static constexpr int qr = QR4_NL;
static constexpr int qi = QI4_NL;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ4_XS> {
static constexpr int qk = QK_K;
static constexpr int qr = QR4_XS;
static constexpr int qi = QI4_XS;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> {
static constexpr int qk = QK_K;
static constexpr int qr = QR3_S;
static constexpr int qi = QI3_S;
};

static int get_mmq_x_max_host(const int cc) {
#ifdef CUDA_USE_TENSOR_CORES
return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_MAX_BATCH_SIZE : 64;
#else
return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? 128 : 64;
#endif // CUDA_USE_TENSOR_CORES
}

// Round rows to this value for --split-mode row:
static int get_mmq_y_host(const int cc, const int mmq_x) {
return cc >= CC_VOLTA && mmq_x >= 32 ? 128 : 64;
}

//////////////////////

struct ggml_cuda_device_info {
Expand Down
Loading

0 comments on commit 7d1a378

Please sign in to comment.