4 changes: 4 additions & 0 deletions CMakeLists.txt
@@ -7,6 +7,10 @@ set(CMAKE_WARN_UNUSED_CLI YES)

set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

set(CMAKE_CUDA_USE_RESPONSE_FILE_FOR_INCLUDES 0)
set(CMAKE_CUDA_USE_RESPONSE_FILE_FOR_LIBRARIES 0)
set(CMAKE_CUDA_USE_RESPONSE_FILE_FOR_OBJECTS 0)

if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE)
    set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE)
    set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
3 changes: 3 additions & 0 deletions examples/quantize/quantize.cpp
@@ -40,7 +40,10 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
{ "Q3_K_L", LLAMA_FTYPE_MOSTLY_Q3_K_L, " 3.35G, +0.1764 ppl @ LLaMA-v1-7B", },
{ "IQ4_NL", LLAMA_FTYPE_MOSTLY_IQ4_NL, " 4.50 bpw non-linear quantization", },
{ "IQ4_XS", LLAMA_FTYPE_MOSTLY_IQ4_XS, " 4.25 bpw non-linear quantization", },
{ "IQ2_K", LLAMA_FTYPE_MOSTLY_IQ2_K, " 2.375 bpw non-linear quantization",},
{ "IQ3_K", LLAMA_FTYPE_MOSTLY_IQ3_K, " 3.44 bpw non-linear quantization", },
{ "IQ4_K", LLAMA_FTYPE_MOSTLY_IQ4_K, " 4.5 bpw non-linear quantization", },
{ "IQ5_K", LLAMA_FTYPE_MOSTLY_IQ5_K, " 5.5 bpw non-linear quantization", },
{ "Q4_K", LLAMA_FTYPE_MOSTLY_Q4_K_M, "alias for Q4_K_M", },
{ "Q4_K_S", LLAMA_FTYPE_MOSTLY_Q4_K_S, " 3.59G, +0.0992 ppl @ LLaMA-v1-7B", },
{ "Q4_K_M", LLAMA_FTYPE_MOSTLY_Q4_K_M, " 3.80G, +0.0532 ppl @ LLaMA-v1-7B", },
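The four new types plug into the quantize tool's option table above, so they can be selected by name on the command line. A hypothetical invocation (binary name and argument order follow the usual llama.cpp quantize convention; adjust to your build):

./llama-quantize model-f16.gguf model-iq4_k.gguf IQ4_K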
10 changes: 8 additions & 2 deletions ggml/include/ggml.h
@@ -389,7 +389,10 @@ extern "C" {
        GGML_TYPE_IQ1_BN = 34,
        GGML_TYPE_IQ2_BN = 35,
        GGML_TYPE_Q8_K64 = 36,
        GGML_TYPE_IQ4_K = 37,
        GGML_TYPE_IQ2_K = 37,
        GGML_TYPE_IQ3_K = 38,
        GGML_TYPE_IQ4_K = 39,
        GGML_TYPE_IQ5_K = 40,
        GGML_TYPE_COUNT,
    };

@@ -436,7 +439,10 @@ extern "C" {
        GGML_FTYPE_MOSTLY_Q4_0_8_8 = 27, // except 1d tensors
        GGML_FTYPE_MOSTLY_IQ1_BN = 28, // except 1d tensors
        GGML_FTYPE_MOSTLY_IQ2_BN = 29, // except 1d tensors
        GGML_FTYPE_MOSTLY_IQ4_K = 30, // except 1d tensors
        GGML_FTYPE_MOSTLY_IQ2_K = 30, // except 1d tensors
        GGML_FTYPE_MOSTLY_IQ3_K = 31, // except 1d tensors
        GGML_FTYPE_MOSTLY_IQ4_K = 32, // except 1d tensors
        GGML_FTYPE_MOSTLY_IQ5_K = 33, // except 1d tensors
    };

    // available tensor operations:
4 changes: 4 additions & 0 deletions ggml/src/CMakeLists.txt
@@ -259,6 +259,10 @@ if (GGML_CUDA)

    find_package(CUDAToolkit)

    set(CMAKE_CUDA_USE_RESPONSE_FILE_FOR_INCLUDES 0)
    set(CMAKE_CUDA_USE_RESPONSE_FILE_FOR_LIBRARIES 0)
    set(CMAKE_CUDA_USE_RESPONSE_FILE_FOR_OBJECTS 0)

    if (CUDAToolkit_FOUND)
        message(STATUS "CUDA found")

47 changes: 47 additions & 0 deletions ggml/src/ggml-common.h
@@ -139,6 +139,9 @@ typedef sycl::half2 ggml_half2;
#define QI4_XS (QK_K / (4*QR4_XS))
#define QR4_XS 2

#define QI5_XS (QK_K / (4*QR5_XS))
#define QR5_XS 2

#define QI3_S (QK_K / (4*QR3_S))
#define QR3_S 4

@@ -445,6 +448,24 @@ typedef struct {
} block_iq4_xs;
static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");

typedef struct {
    ggml_half d;
    uint16_t extra;
    uint8_t scales[QK_K/32];
    uint8_t qs[QK_K/4];
} block_iq2_k;
static_assert(sizeof(block_iq2_k) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/32 + QK_K/4, "wrong iq2_k block size/padding");

typedef struct {
    ggml_half d;
    uint16_t extra;
    uint16_t scales_h;
    uint8_t scales_l[QK_K/32];
    uint8_t qs[QK_K/4];
    uint8_t qh[QK_K/8];
} block_iq3_k;
static_assert(sizeof(block_iq3_k) == sizeof(ggml_half) + 2*sizeof(uint16_t) + QK_K/32 + QK_K/4 + QK_K/8, "wrong iq3_k block size/padding");

typedef struct {
    ggml_half d;
    uint16_t extra;
@@ -454,6 +475,17 @@ typedef struct {
} block_iq4_k;
static_assert(sizeof(block_iq4_k) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/2 + 3*QK_K/64, "wrong iq4_k block size/padding");

typedef struct {
    ggml_half d;
    uint16_t extra;
    uint8_t scales_h[QK_K/64];
    uint8_t scales_l[QK_K/32];
    uint8_t qs[QK_K/2];
    uint8_t qh[QK_K/8];
} block_iq5_k;
static_assert(sizeof(block_iq5_k) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/2 + QK_K/8 + 3*QK_K/64, "wrong iq5_k block size/padding");
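The advertised bits-per-weight figures follow directly from these layouts: each block covers QK_K = 256 weights, so bpw = 8*sizeof(block)/QK_K. A standalone sanity check (a sketch in plain C++, assuming QK_K = 256 and a 2-byte ggml_half, as the static_asserts above imply):

#include <cstdio>

int main() {
    constexpr int QK_K = 256;
    // bytes per super-block: d (2) + extra (2) + scale data + packed quants
    constexpr int iq2_k = 2 + 2 + QK_K/32 + QK_K/4;               //  76 bytes
    constexpr int iq3_k = 2 + 2 + 2 + QK_K/32 + QK_K/4 + QK_K/8; // 110 bytes
    constexpr int iq4_k = 2 + 2 + QK_K/2 + 3*QK_K/64;            // 144 bytes
    constexpr int iq5_k = 2 + 2 + 3*QK_K/64 + QK_K/2 + QK_K/8;   // 176 bytes
    std::printf("IQ2_K %.4f bpw\n", 8.0*iq2_k/QK_K); // 2.375
    std::printf("IQ3_K %.4f bpw\n", 8.0*iq3_k/QK_K); // 3.4375
    std::printf("IQ4_K %.4f bpw\n", 8.0*iq4_k/QK_K); // 4.5
    std::printf("IQ5_K %.4f bpw\n", 8.0*iq5_k/QK_K); // 5.5
    return 0;
}

These match the descriptions registered in quantize.cpp above (3.4375 is rounded to 3.44 there).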


#endif // GGML_COMMON_DECL
#endif // GGML_COMMON_DECL

@@ -1885,10 +1917,25 @@ GGML_TABLE_BEGIN(uint32_t, iq1s_grid_gpu, NGRID_IQ1S)
GGML_TABLE_END()
#endif

GGML_TABLE_BEGIN(int8_t, iq2nl_values, 8)
    -31, -13, 1, 17, -26, -8, 6, 22
GGML_TABLE_END()

GGML_TABLE_BEGIN(int8_t, iq3nl_values, 16)
    -63, -40, -23, -10, 1, 13, 28, 47,
    -59, -36, -19, -6, 5, 17, 32, 51,
GGML_TABLE_END()

GGML_TABLE_BEGIN(int8_t, iq4k_values, 32)
    -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
    -123, -100, -79, -61, -45, -31, -18, -6, 5, 17, 29, 42, 57, 73, 93, 117
GGML_TABLE_END()

GGML_TABLE_BEGIN(int8_t, iq5nl_values, 64)
    -126, -114, -103, -92, -83, -74, -65, -57, -50, -43, -36, -30, -24, -18, -12, -6, -1, 5, 11, 17, 23, 29, 36, 43, 51, 59, 68, 77, 87, 97, 109, 121,
    -124, -112, -101, -90, -81, -72, -63, -55, -48, -41, -34, -28, -22, -16, -10, -4, 1, 7, 13, 19, 25, 31, 38, 45, 53, 61, 70, 79, 89, 99, 111, 123,
GGML_TABLE_END()
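Each table packs two codebooks back to back: a base set of values and a slightly shifted variant. One bit of a block's extra field selects which half a given 32-weight sub-block reads from, which is what the (extra << n) & mask terms in the CUDA kernels below implement. A minimal host-side sketch for IQ2_K (the helper is hypothetical; the lookup logic mirrors dequantize_block_iq2_k):

#include <cstdint>

static const int8_t kIQ2NLValues[8] = { -31, -13, 1, 17, -26, -8, 6, 22 };

// q: the 2-bit quant (0..3); e: the sub-block's extra bit (0 or 1),
// choosing the base codebook (values 0..3) or the shifted one (values 4..7)
static inline float iq2k_decode(float scale, uint8_t q, int e) {
    return scale * kIQ2NLValues[(q & 0x03) + 4*e];
}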


#endif // GGML_COMMON_IMPL
#endif // GGML_COMMON_IMPL
3 changes: 3 additions & 0 deletions ggml/src/ggml-cuda.cu
@@ -2753,7 +2753,10 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
        case GGML_TYPE_IQ3_XXS:
        case GGML_TYPE_IQ4_NL:
        case GGML_TYPE_IQ4_XS:
        case GGML_TYPE_IQ2_K:
        case GGML_TYPE_IQ3_K:
        case GGML_TYPE_IQ4_K:
        case GGML_TYPE_IQ5_K:
        case GGML_TYPE_IQ1_BN:
        case GGML_TYPE_IQ2_BN:
            return true;
21 changes: 21 additions & 0 deletions ggml/src/ggml-cuda/common.cuh
@@ -669,13 +669,34 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ4_XS> {
    static constexpr int qi = QI4_XS;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ2_K> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR4_XS;
    static constexpr int qi = QI4_XS;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ3_K> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR4_XS;
    static constexpr int qi = QI4_XS;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ4_K> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR4_XS;
    static constexpr int qi = QI4_XS;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ5_K> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR5_XS;
    static constexpr int qi = QI5_XS;
};
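With QK_K = 256, these traits expand to qr = 2 and qi = QK_K/(4*QR4_XS) = 32, i.e. the dot-product kernels walk each super-block as 32 four-byte integers; IQ2_K through IQ4_K simply reuse the IQ4_XS constants, and the QR5_XS/QI5_XS macros added to ggml-common.h above yield the same values for IQ5_K. A compile-time restatement (illustrative, not part of the patch):

static_assert(256 / (4*2) == 32, "QI4_XS and QI5_XS both expand to 32 when QK_K is 256");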

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> {
    static constexpr int qk = QK_K;
110 changes: 110 additions & 0 deletions ggml/src/ggml-cuda/convert.cu
@@ -543,6 +543,86 @@ static __global__ void dequantize_block_iq4_k(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    }
}

template<typename dst_t>
static __global__ void dequantize_block_iq5_k(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    const int i = blockIdx.x;
    const block_iq5_k * x = (const block_iq5_k *) vx;

    const int tid = threadIdx.x;
    int ib64 = tid/8; // 0...3
    int il = tid%8; // 0...7
    dst_t * y = yy + i*QK_K + 64*ib64 + 2*il;
    const float d = (float)x[i].d;
    const float dl1 = d * (((x[i].scales_l[2*ib64+0] & 0xf) | ((x[i].scales_h[ib64] << 4) & 0x30)) - 32);
    const float dl2 = d * (((x[i].scales_l[2*ib64+0] >> 4) | ((x[i].scales_h[ib64] << 2) & 0x30)) - 32);
    const float dl3 = d * (((x[i].scales_l[2*ib64+1] & 0xf) | ((x[i].scales_h[ib64] >> 0) & 0x30)) - 32);
    const float dl4 = d * (((x[i].scales_l[2*ib64+1] >> 4) | ((x[i].scales_h[ib64] >> 2) & 0x30)) - 32);
    const uint8_t * qs = x[i].qs + 32*ib64 + 2*il;
    const uint8_t * qh = x[i].qh + 2*il;
    const uint8_t extra = x[i].extra >> 4*(ib64%4);
    for (int j = 0; j < 2; ++j) {
        const uint8_t h1 = qh[j] >> 2*(ib64%4), h2 = qh[j+16] >> 2*(ib64%4);
        y[j+ 0] = dl1 * iq5nl_values[(qs[j+ 0] & 0xf) | ((h1 & 1) << 4) | ((extra << 5) & 0x20)];
        y[j+16] = dl2 * iq5nl_values[(qs[j+16] & 0xf) | ((h2 & 1) << 4) | ((extra << 4) & 0x20)];
        y[j+32] = dl3 * iq5nl_values[(qs[j+ 0] >> 4) | ((h1 & 2) << 3) | ((extra << 3) & 0x20)];
        y[j+48] = dl4 * iq5nl_values[(qs[j+16] >> 4) | ((h2 & 2) << 3) | ((extra << 2) & 0x20)];
    }
}
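The kernel is launched with 32 threads per block (see the wrappers below): ib64 picks one of four 64-weight chunks, il one of eight 2-value columns, and the loop writes 2 values into each of 4 interleaved 16-value rows, so the 32 threads together cover exactly one super-block. A compile-time restatement of that accounting (illustrative only):

static_assert(4 /*ib64*/ * 8 /*il*/ * 2 /*j*/ * 4 /*rows*/ == 256, "one 32-thread block dequantizes one QK_K super-block");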

template<typename dst_t>
static __global__ void dequantize_block_iq2_k(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    const int i = blockIdx.x;
    const block_iq2_k * x = (const block_iq2_k *) vx;

    const int tid = threadIdx.x;
    int ib128 = tid/16; // 0 or 1
    int il = tid%16; // 0...15
    dst_t * y = yy + i*QK_K + 128*ib128 + 2*il;
    const float d = (float)x[i].d * 1.025f; //1.0325f;
    const float dl1 = d * (2*((x[i].scales[4*ib128+0] >> 4*(il/8)) & 0xf) - 15);
    const float dl2 = d * (2*((x[i].scales[4*ib128+1] >> 4*(il/8)) & 0xf) - 15);
    const float dl3 = d * (2*((x[i].scales[4*ib128+2] >> 4*(il/8)) & 0xf) - 15);
    const float dl4 = d * (2*((x[i].scales[4*ib128+3] >> 4*(il/8)) & 0xf) - 15);
    const uint8_t * qs = x[i].qs + 32*ib128 + 2*il;
    const int16_t extra = x[i].extra >> (8*ib128 + (il/8));
    for (int j = 0; j < 2; ++j) {
        y[j+ 0] = dl1 * iq2nl_values[((qs[j] >> 0) & 0x03) + ((extra << 2) & 4)];
        y[j+32] = dl2 * iq2nl_values[((qs[j] >> 2) & 0x03) + ((extra << 0) & 4)];
        y[j+64] = dl3 * iq2nl_values[((qs[j] >> 4) & 0x03) + ((extra >> 2) & 4)];
        y[j+96] = dl4 * iq2nl_values[((qs[j] >> 6) & 0x03) + ((extra >> 4) & 4)];
    }
}

template<typename dst_t>
static __global__ void dequantize_block_iq3_k(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    const int i = blockIdx.x;
    const block_iq3_k * x = (const block_iq3_k *) vx;

    const int tid = threadIdx.x;
    int ib128 = tid/16; // 0 or 1
    int il = tid%16; // 0...15
    dst_t * y = yy + i*QK_K + 128*ib128 + 2*il;
    const float d = (float)x[i].d * 1.01f; //1.0125f;
    const uint16_t sh = x[i].scales_h >> (8*ib128 + (il/8));
    const float dl1 = d * ((2*((x[i].scales_l[4*ib128+0] >> 4*(il/8)) & 0xf) + 1) * ((sh & 0x01) ? -1 : 1));
    const float dl2 = d * ((2*((x[i].scales_l[4*ib128+1] >> 4*(il/8)) & 0xf) + 1) * ((sh & 0x04) ? -1 : 1));
    const float dl3 = d * ((2*((x[i].scales_l[4*ib128+2] >> 4*(il/8)) & 0xf) + 1) * ((sh & 0x10) ? -1 : 1));
    const float dl4 = d * ((2*((x[i].scales_l[4*ib128+3] >> 4*(il/8)) & 0xf) + 1) * ((sh & 0x40) ? -1 : 1));
    const uint8_t * qs = x[i].qs + 32*ib128 + 2*il;
    const uint8_t * qh = x[i].qh + 2*il;
    const int16_t extra = x[i].extra >> (8*ib128 + (il/8));
    for (int j = 0; j < 2; ++j) {
        const uint8_t h = qh[j] >> (4*(ib128%2));
        y[j+ 0] = dl1 * iq3nl_values[(((qs[j] >> 0) & 0x03) | ((h & 0x01) << 2)) + ((extra << 3) & 8)];
        y[j+32] = dl2 * iq3nl_values[(((qs[j] >> 2) & 0x03) | ((h & 0x02) << 1)) + ((extra << 1) & 8)];
        y[j+64] = dl3 * iq3nl_values[(((qs[j] >> 4) & 0x03) | ((h & 0x04) >> 0)) + ((extra >> 1) & 8)];
        y[j+96] = dl4 * iq3nl_values[(((qs[j] >> 6) & 0x03) | ((h & 0x08) >> 1)) + ((extra >> 3) & 8)];
    }
}
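Unlike IQ2_K, whose 4-bit scales are decoded with a bias (2*s - 15), IQ3_K stores a 4-bit magnitude in scales_l and a per-sub-block sign bit in scales_h, giving effective scales of d times the odd integers -31..31. A hypothetical helper restating the decode above:

// nibble: 4-bit magnitude from scales_l; sign_bit: the matching scales_h bit
static inline float iq3k_scale(float d, uint8_t nibble, bool sign_bit) {
    return d * (2*(nibble & 0xf) + 1) * (sign_bit ? -1.0f : 1.0f);
}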

template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
    const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE);
@@ -672,12 +752,30 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq2_k_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int nb = (k + QK_K - 1) / QK_K;
    dequantize_block_iq2_k<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq3_k_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int nb = (k + QK_K - 1) / QK_K;
    dequantize_block_iq3_k<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq4_k_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int nb = (k + QK_K - 1) / QK_K;
    dequantize_block_iq4_k<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq5_k_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int nb = (k + QK_K - 1) / QK_K;
    dequantize_block_iq5_k<<<nb, 32, 0, stream>>>(vx, y);
}
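Each wrapper launches one 32-thread block per QK_K = 256 weights, matching the per-block indexing of the kernels above. A hedged sketch of a call site (the wrappers are static templates inside convert.cu, with dst_t deduced from the output pointer; names here are illustrative):

// d_vx: device pointer to k/QK_K packed block_iq5_k blocks
// d_y:  device buffer with room for k floats
dequantize_row_iq5_k_cuda(d_vx, d_y, k, stream);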

template <typename src_t, typename dst_t>
static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
    const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
@@ -742,8 +840,14 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
            return dequantize_row_iq4_nl_cuda;
        case GGML_TYPE_IQ4_XS:
            return dequantize_row_iq4_xs_cuda;
        case GGML_TYPE_IQ2_K:
            return dequantize_row_iq2_k_cuda;
        case GGML_TYPE_IQ3_K:
            return dequantize_row_iq3_k_cuda;
        case GGML_TYPE_IQ4_K:
            return dequantize_row_iq4_k_cuda;
        case GGML_TYPE_IQ5_K:
            return dequantize_row_iq5_k_cuda;
        case GGML_TYPE_IQ3_S:
            return dequantize_row_iq3_s_cuda;
        case GGML_TYPE_F32:
@@ -795,8 +899,14 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
            return dequantize_row_iq4_nl_cuda;
        case GGML_TYPE_IQ4_XS:
            return dequantize_row_iq4_xs_cuda;
        case GGML_TYPE_IQ2_K:
            return dequantize_row_iq2_k_cuda;
        case GGML_TYPE_IQ3_K:
            return dequantize_row_iq3_k_cuda;
        case GGML_TYPE_IQ4_K:
            return dequantize_row_iq4_k_cuda;
        case GGML_TYPE_IQ5_K:
            return dequantize_row_iq5_k_cuda;
        case GGML_TYPE_IQ3_S:
            return dequantize_row_iq3_s_cuda;
        case GGML_TYPE_F16: