Allow convert.py to convert directly to q8_0 #2753

Merged (4 commits, Aug 26, 2023)
Changes from 3 commits
convert.py: 119 changes (94 additions, 25 deletions)
@@ -3,6 +3,7 @@
import gguf
import argparse
import concurrent.futures
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
import copy
import enum
import faulthandler
@@ -17,6 +18,7 @@
import signal
import struct
import sys
import time
import zipfile
import numpy as np

@@ -37,6 +39,7 @@
ARCH=gguf.MODEL_ARCH.LLAMA
NAMES=gguf.MODEL_TENSOR_NAMES[ARCH]

DEFAULT_CONCURRENCY = 8
#
# data types
#
@@ -50,7 +53,13 @@ class UnquantizedDataType:
DT_I32 = UnquantizedDataType('I32')
DT_BF16 = UnquantizedDataType('BF16')

DataType = Union[UnquantizedDataType]
@dataclass(frozen=True)
class QuantizedDataType:
name: str

DT_Q8_0 = QuantizedDataType('Q8_0')

DataType = Union[UnquantizedDataType, QuantizedDataType]

DATA_TYPE_TO_NUMPY: Dict[DataType, 'np.dtype[Any]'] = {
DT_BF16: np.dtype(np.uint16),
@@ -73,8 +82,9 @@ class UnquantizedDataType:
# TODO: rename to LLAMAFileType
# TODO: move to `gguf.py`
class GGMLFileType(enum.IntEnum):
AllF32 = 0
MostlyF16 = 1 # except 1d tensors
AllF32 = 0
MostlyF16 = 1 # except 1d tensors
MostlyQ8_0 = 7 # except 1d tensors

def type_for_tensor(self, name: str, tensor: 'LazyTensor') -> DataType:
if len(tensor.shape) == 1:
@@ -84,6 +94,8 @@ def type_for_tensor(self, name: str, tensor: 'LazyTensor') -> DataType:
return DT_F32
elif self == GGMLFileType.MostlyF16:
return DT_F16
elif self == GGMLFileType.MostlyQ8_0:
return DT_Q8_0
else:
raise ValueError(self)

@@ -391,7 +403,10 @@ def __init__(self, ndarray: NDArray) -> None:
self.data_type = NUMPY_TYPE_TO_DATA_TYPE[ndarray.dtype]

def astype(self, data_type: DataType) -> Tensor:
dtype = DATA_TYPE_TO_NUMPY[data_type]
if data_type == DT_Q8_0:
dtype = DATA_TYPE_TO_NUMPY[DT_F32]
else:
dtype = DATA_TYPE_TO_NUMPY[data_type]
if self.data_type == DT_BF16:
self.ndarray = bf16_to_fp32(self.ndarray)
return UnquantizedTensor(self.ndarray.astype(dtype))
@@ -455,7 +470,7 @@ class LazyTensor:

def load(self) -> Tensor:
ret = self._load()
assert ret.data_type == self.data_type, (self.data_type, ret.data_type, self.description)
assert ret.data_type == self.data_type or (self.data_type is DT_Q8_0 and ret.data_type is DT_F32), (self.data_type, ret.data_type, self.description)
return ret

def astype(self, data_type: DataType) -> 'LazyTensor':
@@ -699,23 +714,32 @@ def lazy_load_file(path: Path) -> ModelPlus:
In = TypeVar('In')
Out = TypeVar('Out')

def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], concurrency: int) -> Iterable[Out]:
def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], concurrency: int, max_workers: Optional[int] = None, factory: Callable = ThreadPoolExecutor) -> Iterable[Out]:
'''Parallel map, but with backpressure. If the caller doesn't call `next`
fast enough, this will stop calling `func` at some point rather than
letting results pile up in memory. Specifically, there is a max of one
output value buffered per thread.'''
with concurrent.futures.ThreadPoolExecutor() as executor:
iterable = iter(iterable)
with factory(max_workers = max_workers) as executor:
futures: List[concurrent.futures.Future[Out]] = []
items_rev = list(iterable)[::-1]
for i in range(min(concurrency, len(items_rev))):
futures.append(executor.submit(func, items_rev.pop()))
done = False
for _ in range(concurrency):
try:
futures.append(executor.submit(func, next(iterable)))
except StopIteration:
done = True
break

while futures:
result = futures.pop(0).result()
if items_rev:
futures.append(executor.submit(func, items_rev.pop()))
while not done and len(futures) < concurrency:
try:
futures.append(executor.submit(func, next(iterable)))
except StopIteration:
done = True
break
yield result


def check_vocab_size(params: Params, vocab: Vocab) -> None:
if params.n_vocab != vocab.vocab_size:
assert isinstance(vocab, BpeVocab) or isinstance(vocab, SentencePieceVocab)
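As an editorial aside (not part of the diff): the reworked bounded_parallel_map keeps at most `concurrency` futures in flight and yields results in submission order, so a slow consumer throttles the producers instead of letting loaded tensors pile up in memory. A minimal usage sketch, assuming the function as defined in the hunk above; `slow_square` is a hypothetical stand-in for an expensive per-tensor load:

import time
from concurrent.futures import ThreadPoolExecutor

def slow_square(x: int) -> int:
    time.sleep(0.1)  # simulate an expensive load
    return x * x

# At most four calls are outstanding at any moment; each `next()` taken by the
# consumer frees a slot and lets one more item be submitted.
for result in bounded_parallel_map(slow_square, range(10), concurrency=4,
                                   max_workers=4, factory=ThreadPoolExecutor):
    print(result)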
@@ -732,6 +756,24 @@ def check_vocab_size(params: Params, vocab: Vocab) -> None:
msg += f" Most likely you are missing added_tokens.json (should be in {vocab.fname_tokenizer.parent})."
raise Exception(msg)

#### Mini Q8_0 quantization in Python
QK8_0 = 32
BLOCK_Q8_0 = np.dtype([('d', '<f2'), ('qs', 'i1', (QK8_0,))])
def quantize_array_q8_0(arr):
assert arr.size % QK8_0 == 0 and arr.size != 0, f'Bad array size {arr.size}'
assert arr.dtype == np.float32, f'Bad array type {arr.dtype}'
n_blocks = arr.size // QK8_0
blocks = arr.reshape((n_blocks, QK8_0))
return np.fromiter(quantize_blocks_q8_0(blocks), count = n_blocks, dtype = BLOCK_Q8_0)

# Much faster implementation of block quantization contributed by @Cebtenzzre
def quantize_blocks_q8_0(blocks):
d = abs(blocks).max(axis = 1) / np.float32(127)
with np.errstate(divide = 'ignore'):
qs = (blocks / d[:, None]).round()
qs[d == 0] = 0
yield from zip(np.float16(d), qs)


class OutputFile:
def __init__(self, fname_out: Path) -> None:
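For orientation (an editorial sketch, not part of the PR): each Q8_0 block packs one float16 scale `d` together with 32 signed int8 quants, and dequantization is simply `d * q` per element. The helper `dequantize_array_q8_0` below is hypothetical, written only to sanity-check the quantizer from the hunk above:

import numpy as np

QK8_0 = 32

def dequantize_array_q8_0(blocks: np.ndarray) -> np.ndarray:
    # Hypothetical inverse of quantize_array_q8_0: scale each int8 quant by its block's d.
    d = blocks['d'].astype(np.float32)[:, None]   # (n_blocks, 1) scales
    qs = blocks['qs'].astype(np.float32)          # (n_blocks, 32) quants
    return (d * qs).reshape(-1)

x = np.random.randn(4 * QK8_0).astype(np.float32)
q = quantize_array_q8_0(x)            # structured array with dtype BLOCK_Q8_0
x_hat = dequantize_array_q8_0(q)
# Per-element error is bounded by roughly max|x| / 254 within each block.
print(np.abs(x - x_hat).max())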
@@ -777,9 +819,16 @@ def add_tensor_info(self, name: str, tensor: LazyTensor) -> None:
n_elements = 1
for dim in tensor.shape:
n_elements *= dim
data_type = DATA_TYPE_TO_NUMPY[tensor.data_type]
data_nbytes = n_elements * data_type.itemsize
self.gguf.add_tensor_info(name, tensor.shape, data_type, data_nbytes)
if tensor.data_type == DT_Q8_0:
assert n_elements > 0 and n_elements % QK8_0 == 0, f'Cannot quantize as Q8_0, {n_elements} not a multiple of block size {QK8_0}'
data_type = BLOCK_Q8_0
raw_dtype = gguf.GGMLQuantizationType.Q8_0
data_nbytes = n_elements + (n_elements // QK8_0) * 2
else:
data_type = DATA_TYPE_TO_NUMPY[tensor.data_type]
data_nbytes = n_elements * data_type.itemsize
raw_dtype = None
self.gguf.add_tensor_info(name, tensor.shape, data_type, data_nbytes, raw_dtype = raw_dtype)

def write_meta(self) -> None:
self.gguf.write_header_to_file()
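A quick editorial check of the Q8_0 size computation in add_tensor_info above: every block of 32 elements is stored as 32 int8 quants plus one float16 scale, i.e. 34 bytes per 32 weights (8.5 bits per weight). For a hypothetical 4096x4096 weight matrix:

QK8_0 = 32
n_elements = 4096 * 4096
q8_0_bytes = n_elements + (n_elements // QK8_0) * 2   # 17,825,792 bytes
f16_bytes = n_elements * 2                            # 33,554,432 bytes
# Q8_0 is therefore a bit over half the size of an F16 tensor of the same shape.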
@@ -805,7 +854,19 @@ def write_vocab_only(fname_out: Path, params: Params, vocab: Vocab) -> None:
of.close()

@staticmethod
def write_all(fname_out: Path, params: Params, model: LazyModel, vocab: Vocab) -> None:
def do_item(item: Tuple[str, LazyTensor]) -> Tuple[DataType, NDArray]:
name, lazy_tensor = item
tensor = lazy_tensor.load().to_ggml()
return (lazy_tensor.data_type, tensor.ndarray)

@staticmethod
def maybe_do_quantize(item: Tuple[DataType, NDArray]) -> NDArray:
if item[0] == DT_Q8_0:
return quantize_array_q8_0(item[1])
return item[1]

@staticmethod
def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, concurrency: int = DEFAULT_CONCURRENCY) -> None:
check_vocab_size(params, vocab)

of = OutputFile(fname_out)
@@ -821,16 +882,19 @@ def write_all(fname_out: Path, params: Params, model: LazyModel, vocab: Vocab) -
of.write_meta()
of.write_tensor_info()

def do_item(item: Tuple[str, LazyTensor]) -> NDArray:
name, lazy_tensor = item
return lazy_tensor.load().to_ggml().ndarray

# tensor data
ndarrays = bounded_parallel_map(do_item, model.items(), concurrency=8)
ndarrays = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency = concurrency)
if ftype == GGMLFileType.MostlyQ8_0:
ndarrays = bounded_parallel_map(OutputFile.maybe_do_quantize, ndarrays, concurrency = concurrency, max_workers = concurrency, factory = ProcessPoolExecutor)
else:
ndarrays = map(OutputFile.maybe_do_quantize, ndarrays)

start = time.time()
for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)):
elapsed = time.time() - start
size = ' x '.join(f"{dim:6d}" for dim in lazy_tensor.shape)
padi = len(str(len(model)))
print(f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type}")
print(f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4} | T+{int(elapsed):4}")
of.gguf.write_tensor_data(ndarray)

of.close()
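Conceptually (an editorial sketch, not the PR's literal code), the q8_0 path in write_all chains two backpressured stages: a thread pool loads tensors to float32, and a process pool quantizes them, presumably because the block quantization is CPU-bound Python/NumPy work. At most `concurrency` items are in flight at each stage before the writer consumes them:

# Sketch of the two-stage pipeline behind --outtype q8_0, assuming
# bounded_parallel_map, do_item and maybe_do_quantize from the hunks above.
loaded = bounded_parallel_map(OutputFile.do_item, model.items(),
                              concurrency = concurrency)              # threads: lazy load -> f32
quantized = bounded_parallel_map(OutputFile.maybe_do_quantize, loaded,
                                 concurrency = concurrency,
                                 max_workers = concurrency,
                                 factory = ProcessPoolExecutor)       # processes: f32 -> Q8_0 blocks
for ndarray in quantized:
    of.gguf.write_tensor_data(ndarray)                                # writer paces both stages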
@@ -842,6 +906,8 @@ def pick_output_type(model: LazyModel, output_type_str: Optional[str]) -> GGMLFi
return GGMLFileType.AllF32
if output_type_str == "f16" or (output_type_str is None and wq_type in (DT_F16, DT_BF16)):
return GGMLFileType.MostlyF16
if output_type_str == "q8_0":
return GGMLFileType.MostlyQ8_0

name_to_type = {name: lazy_tensor.data_type for (name, lazy_tensor) in model.items()}

@@ -993,6 +1059,7 @@ def default_outfile(model_paths: List[Path], file_type: GGMLFileType) -> Path:
namestr = {
GGMLFileType.AllF32: "f32",
GGMLFileType.MostlyF16: "f16",
GGMLFileType.MostlyQ8_0:"q8_0",
}[file_type]
ret = model_paths[0].parent / f"ggml-model-{namestr}.gguf"
if ret in model_paths:
@@ -1016,12 +1083,13 @@ def main(args_in: Optional[List[str]] = None) -> None:
parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model")
parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file")
parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab")
parser.add_argument("--outtype", choices=["f32", "f16"], help="output format (default: based on input)")
parser.add_argument("--outtype", choices=["f32", "f16", "q8_0"], help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)")
parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file")
parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)")
parser.add_argument("--vocabtype", choices=["spm", "bpe"], help="vocab format (default: spm)", default="spm")
parser.add_argument("--ctx", type=int, help="model training context (default: based on input)")
parser.add_argument("--concurrency", type=int, help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default = DEFAULT_CONCURRENCY)
args = parser.parse_args(args_in)

if args.dump_single:
Expand All @@ -1043,6 +1111,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
params.ftype = {
"f32": GGMLFileType.AllF32,
"f16": GGMLFileType.MostlyF16,
"q8_0": GGMLFileType.MostlyQ8_0,
}[args.outtype]

print(f"params = {params}")
@@ -1074,7 +1143,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
params.ftype = ftype
print(f"Writing {outfile}, format {ftype}")

OutputFile.write_all(outfile, params, model, vocab)
OutputFile.write_all(outfile, ftype, params, model, vocab, concurrency = args.concurrency)
print(f"Wrote {outfile}")


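Finally, a usage sketch for the new options (the model path is illustrative; the output name follows default_outfile when --outfile is omitted):

python convert.py models/7B/ --outtype q8_0 --concurrency 8
# writes models/7B/ggml-model-q8_0.gguf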