Add support for Q2_K, Q3_K, Q5_K
99991 committed Apr 21, 2024
1 parent 351e9b2 commit a417edb
Showing 2 changed files with 177 additions and 11 deletions.
gguf.py: 158 additions, 2 deletions
@@ -7,13 +7,19 @@
GGML_TYPES = {
    "F32": 0,
    "Q8_0": 8,
    "Q2_K": 10,
    "Q3_K": 11,
    "Q4_K": 12,
    "Q5_K": 13,
    "Q6_K": 14,
}

Q8_0_BLOCK_SIZE = 2 + 32
Q2_K_BLOCK_SIZE = 256 // 16 + 256 // 4 + 2 + 2
Q3_K_BLOCK_SIZE = 256 // 8 + 256 // 4 + 12 + 2
Q4_K_BLOCK_SIZE = 2 + 2 + 12 + 256 // 2
Q5_K_BLOCK_SIZE = 2 + 2 + 12 + 256 // 8 + 256 // 2
Q6_K_BLOCK_SIZE = 256 // 2 + 256 // 4 + 256 // 16 + 2
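
# Editor's note (illustrative, not part of the commit): every K-quant block
# packs 256 weights, so the block size in bytes maps directly to an effective
# bit width per weight:
#
#     for name, size in [("Q2_K", 84), ("Q3_K", 110), ("Q4_K", 144),
#                        ("Q5_K", 176), ("Q6_K", 210)]:
#         print(name, size * 8 / 256, "bits per weight")
#
# which prints 2.625, 3.4375, 4.5, 5.5 and 6.5625 bits respectively.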

DATA_TYPES = {
    4: "uint32",
@@ -96,6 +102,83 @@ def load_gguf(f):

    return info, tensorinfo

def dequantize_q2_k(data):
    # C implementation
    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1547
    # C struct definition
    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L74
    num_blocks = len(data) // Q2_K_BLOCK_SIZE

    data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, Q2_K_BLOCK_SIZE // 2)
    data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, Q2_K_BLOCK_SIZE)

    # Block layout: 16 scale/min bytes, 64 bytes of 2-bit quants, then the
    # float16 super-block factors d and dmin.
    dmin = data_f16[:, -1].reshape(num_blocks, 1, 1).astype(np.float32)
    d = data_f16[:, -2].reshape(num_blocks, 1, 1).astype(np.float32)
    scales = data_u8[:, :16].reshape(num_blocks, 16, 1)
    qs = data_u8[:, 16:80].reshape(num_blocks, 64)

    # Unpack four 2-bit values from each byte in ggml's interleaved order.
    tmp = np.stack([
        qs[:, 00:16] >> 0,
        qs[:, 16:32] >> 0,
        qs[:, 00:16] >> 2,
        qs[:, 16:32] >> 2,
        qs[:, 00:16] >> 4,
        qs[:, 16:32] >> 4,
        qs[:, 00:16] >> 6,
        qs[:, 16:32] >> 6,
        qs[:, 32:48] >> 0,
        qs[:, 48:64] >> 0,
        qs[:, 32:48] >> 2,
        qs[:, 48:64] >> 2,
        qs[:, 32:48] >> 4,
        qs[:, 48:64] >> 4,
        qs[:, 32:48] >> 6,
        qs[:, 48:64] >> 6,
    ], axis=1)

    # weight = d * scale (low nibble) * quant - dmin * min (high nibble)
    return d * (scales & 15) * (tmp & 3) - dmin * (scales >> 4)
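
# Editor's note (usage sketch, not part of the commit): the array returned
# above has shape (num_blocks, 16, 16), i.e. 256 weights per super-block;
# the caller reshapes it to the tensor's true shape, e.g.:
#
#     values = dequantize_q2_k(raw_bytes)    # raw_bytes is hypothetical
#     weights = values.reshape(shape)        # tensor shape from tensorinfo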

def dequantize_q3_k(data):
    # C implementation
    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1723
    # C struct definition
    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L95
    num_blocks = len(data) // Q3_K_BLOCK_SIZE

    data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, Q3_K_BLOCK_SIZE // 2)
    data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, Q3_K_BLOCK_SIZE)

    d = data_f16[:, -1].reshape(num_blocks, 1, 1).astype(np.float32)

    # High bits of the 3-bit quants come from the 32-byte hmask; see the
    # note after this function for the 4 ^ (bits << 2) trick.
    bits = np.unpackbits(data_u8[:, :32].reshape(num_blocks, 32, 1), axis=-1, bitorder="little")
    bits = 4 ^ (bits << 2)

    qs = data_u8[:, 32:32 + 64].astype(np.int16)

    # Reassemble the 16 six-bit scales packed into 12 bytes: low/high nibbles
    # of the first 8 bytes plus two-bit groups from the last 4 bytes.
    a, b, c = data_u8[:, 96:96 + 12].reshape(num_blocks, 3, 4).transpose(1, 0, 2)
    scales = np.zeros((num_blocks, 4, 4), dtype=np.uint8)
    scales[:, 0] = (a & 15) | ((c & 3) << 4)
    scales[:, 1] = (b & 15) | (((c >> 2) & 3) << 4)
    scales[:, 2] = (a >> 4) | (((c >> 4) & 3) << 4)
    scales[:, 3] = (b >> 4) | ((c >> 6) << 4)
    scales = scales.reshape(num_blocks, 16, 1).astype(np.int16)

    # Scales are stored with an offset of 32.
    return d * (scales - 32) * np.stack([
        (((qs[:, 00:16] >> 0) & 3) - bits[:, :16, 0]),
        (((qs[:, 16:32] >> 0) & 3) - bits[:, 16:, 0]),
        (((qs[:, 00:16] >> 2) & 3) - bits[:, :16, 1]),
        (((qs[:, 16:32] >> 2) & 3) - bits[:, 16:, 1]),
        (((qs[:, 00:16] >> 4) & 3) - bits[:, :16, 2]),
        (((qs[:, 16:32] >> 4) & 3) - bits[:, 16:, 2]),
        (((qs[:, 00:16] >> 6) & 3) - bits[:, :16, 3]),
        (((qs[:, 16:32] >> 6) & 3) - bits[:, 16:, 3]),
        (((qs[:, 32:48] >> 0) & 3) - bits[:, :16, 4]),
        (((qs[:, 48:64] >> 0) & 3) - bits[:, 16:, 4]),
        (((qs[:, 32:48] >> 2) & 3) - bits[:, :16, 5]),
        (((qs[:, 48:64] >> 2) & 3) - bits[:, 16:, 5]),
        (((qs[:, 32:48] >> 4) & 3) - bits[:, :16, 6]),
        (((qs[:, 48:64] >> 4) & 3) - bits[:, 16:, 6]),
        (((qs[:, 32:48] >> 6) & 3) - bits[:, :16, 7]),
        (((qs[:, 48:64] >> 6) & 3) - bits[:, 16:, 7])
    ], axis=1)
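
# Editor's note (illustration, not part of the commit): 4 ^ (bits << 2)
# above reproduces ggml's "(hm & m) ? 0 : 4" offset, i.e. a set high bit
# contributes 0 and a cleared one contributes 4:
#
#     assert 4 ^ (1 << 2) == 0
#     assert 4 ^ (0 << 2) == 4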

def dequantize_q4_k(data):
    # C implementation
    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1929
@@ -121,6 +204,61 @@ def dequantize_q4_k(data):
    # Dequantize final weights using scales and offsets
    return factors * qs2 - offsets

def dequantize_q5_k(data):
    # C implementation
    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L2129
    # C struct definition
    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L138
    num_blocks = len(data) // Q5_K_BLOCK_SIZE

    data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, Q5_K_BLOCK_SIZE // 2)
    data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, Q5_K_BLOCK_SIZE)

    d = data_f16[:, 0].reshape(num_blocks, 1).astype(np.float32)
    dmin = data_f16[:, 1].reshape(num_blocks, 1).astype(np.float32)
    scales = data_u8[:, 4:16].reshape(num_blocks, 12, 1)
    qh = data_u8[:, 16:16 + 32].reshape(num_blocks, 32, 1)
    qs = data_u8[:, 48:48 + 128].reshape(num_blocks, 4, 32)

    # The fifth bit of each 5-bit quant lives in the 32-byte qh array.
    bits = np.unpackbits(qh, axis=-1, bitorder="little")

    qs_hi_4 = qs >> 4
    qs_lo_4 = qs & 15

    # Unpack the 6-bit scales and mins: bytes 0-7 hold the low six bits of
    # the first four scales and mins; bytes 8-11 hold nibbles that the top
    # two bits of bytes 0-7 complete to six bits.
    scales_lo_6 = scales[:, :8] & 63
    scales_hi_6 = scales[:, :8] >> 6
    scales_lo_4 = scales[:, 8:] & 15
    scales_hi_4 = scales[:, 8:] >> 4

    m1 = dmin * scales_lo_6[:, 4]
    m2 = dmin * scales_lo_6[:, 5]
    m3 = dmin * scales_lo_6[:, 6]
    m4 = dmin * scales_lo_6[:, 7]
    m5 = dmin * (scales_hi_4[:, 0] | (scales_hi_6[:, 4] << 4))
    m6 = dmin * (scales_hi_4[:, 1] | (scales_hi_6[:, 5] << 4))
    m7 = dmin * (scales_hi_4[:, 2] | (scales_hi_6[:, 6] << 4))
    m8 = dmin * (scales_hi_4[:, 3] | (scales_hi_6[:, 7] << 4))

    d1 = d * scales_lo_6[:, 0]
    d2 = d * scales_lo_6[:, 1]
    d3 = d * scales_lo_6[:, 2]
    d4 = d * scales_lo_6[:, 3]
    d5 = d * (scales_lo_4[:, 0] | (scales_hi_6[:, 0] << 4))
    d6 = d * (scales_lo_4[:, 1] | (scales_hi_6[:, 1] << 4))
    d7 = d * (scales_lo_4[:, 2] | (scales_hi_6[:, 2] << 4))
    d8 = d * (scales_lo_4[:, 3] | (scales_hi_6[:, 3] << 4))

    # Combine low/high nibbles with the extracted fifth bits, one group of
    # 32 weights at a time.
    return np.concatenate([
        d1 * (qs_lo_4[:, 0] + (bits[:, :, 0] << 4)) - m1,
        d2 * (qs_hi_4[:, 0] + (bits[:, :, 1] << 4)) - m2,
        d3 * (qs_lo_4[:, 1] + (bits[:, :, 2] << 4)) - m3,
        d4 * (qs_hi_4[:, 1] + (bits[:, :, 3] << 4)) - m4,
        d5 * (qs_lo_4[:, 2] + (bits[:, :, 4] << 4)) - m5,
        d6 * (qs_hi_4[:, 2] + (bits[:, :, 5] << 4)) - m6,
        d7 * (qs_lo_4[:, 3] + (bits[:, :, 6] << 4)) - m7,
        d8 * (qs_hi_4[:, 3] + (bits[:, :, 7] << 4)) - m8,
    ], axis=1)
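
# Editor's note (sanity check, not part of the commit): the eight groups of
# 32 values concatenate to shape (num_blocks, 256), so each Q5_K super-block
# again decodes to exactly 256 weights:
#
#     out = dequantize_q5_k(raw_bytes)    # raw_bytes holds whole blocks
#     assert out.shape == (len(raw_bytes) // Q5_K_BLOCK_SIZE, 256)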

def dequantize_q6_k(data):
    # C implementation
    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L2275
@@ -196,12 +334,30 @@ def load_gguf_tensor(f, tensorinfo, name):

        values = dequantize_q8_0(data)

    # Each super-block of 256 elements occupies BLOCK_SIZE bytes, so a whole
    # tensor takes num_elements * BLOCK_SIZE // 256 bytes.
    elif ggml_type == GGML_TYPES["Q2_K"]:
        size = num_elements * Q2_K_BLOCK_SIZE // 256
        data = f.read(size)

        values = dequantize_q2_k(data)

    elif ggml_type == GGML_TYPES["Q3_K"]:
        size = num_elements * Q3_K_BLOCK_SIZE // 256
        data = f.read(size)

        values = dequantize_q3_k(data)

    elif ggml_type == GGML_TYPES["Q4_K"]:
        size = num_elements * Q4_K_BLOCK_SIZE // 256
        data = f.read(size)

        values = dequantize_q4_k(data)

    elif ggml_type == GGML_TYPES["Q5_K"]:
        size = num_elements * Q5_K_BLOCK_SIZE // 256
        data = f.read(size)

        values = dequantize_q5_k(data)

    elif ggml_type == GGML_TYPES["Q6_K"]:
        size = num_elements * Q6_K_BLOCK_SIZE // 256
        data = f.read(size)
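
A minimal usage sketch of the reader API shown above (file path and tensor name are hypothetical):

    with open("model.gguf", "rb") as f:
        info, tensorinfo = load_gguf(f)
        weights = load_gguf_tensor(f, tensorinfo, "token_embd.weight")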
test.py: 19 additions, 9 deletions
@@ -26,6 +26,7 @@ def translate_name(name):
    return name

def main():
    import os
    import time
    from safetensors.torch import load_file

@@ -37,11 +38,23 @@ def main():
print(f"{key:30} {value.shape}")
print()

    gguf_dir = "data/TinyLlama-1.1B-Chat-v1.0-GGUF/"

    # Largest acceptable mean squared error versus the reference safetensors
    # weights, per quantization level.
    max_mses = {
        "tinyllama-1.1b-chat-v1.0.Q2_K.gguf": 0.0002846,
        "tinyllama-1.1b-chat-v1.0.Q3_K_L.gguf": 7.652e-05,
        "tinyllama-1.1b-chat-v1.0.Q3_K_M.gguf": 7.652e-05,
        "tinyllama-1.1b-chat-v1.0.Q3_K_S.gguf": 7.652e-05,
        "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf": 1.705e-05,
        "tinyllama-1.1b-chat-v1.0.Q4_K_S.gguf": 1.705e-05,
        "tinyllama-1.1b-chat-v1.0.Q5_K_M.gguf": 4.371e-06,
        "tinyllama-1.1b-chat-v1.0.Q5_K_S.gguf": 4.371e-06,
        "tinyllama-1.1b-chat-v1.0.Q6_K.gguf": 1.090e-06,
        "tinyllama-1.1b-chat-v1.0.Q8_0.gguf": 1.034e-07,
    }
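
    # Editor's note (observation, not part of the commit): each extra bit of
    # precision tightens the measured MSE ceiling by roughly 4x, from 2.8e-4
    # at Q2_K down to 1.0e-7 at Q8_0.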

    for filename, max_mse in max_mses.items():
        with open(os.path.join(gguf_dir, filename), "r+b") as f:
            # also works with mmap (at least on Linux)
            # import mmap
            # f = mmap.mmap(f.fileno(), 0)
@@ -85,10 +98,7 @@ def main():

print(f"MSE {mse:.10f} {name:30} ggml_type {ggml_type:2} {str(shape):13} {ms:7.3f} ms")

if "Q8_0" in filename:
assert mse < 2e-6
else:
assert mse < 2e-5
assert mse < max_mse, f"Error too large, should be less than {max_mse}, but is {mse} for {filename}"

print("Tests passed :)")

