Commit 8047aa1

Replay changes from #3871
Credit to @cebtenzzre for that pull
1 parent b8c80df

2 files changed (+54 -31 lines)
gguf-py/gguf/gguf_writer.py

Lines changed: 49 additions & 26 deletions
@@ -5,7 +5,8 @@
 import struct
 import tempfile
 from io import BufferedWriter
-from typing import Any, BinaryIO, Sequence
+from enum import Enum, auto
+from typing import Any, IO, Sequence
 
 import numpy as np
 
@@ -21,18 +22,16 @@
     TokenType,
 )
 
+class WriterState(Enum):
+    EMPTY = auto()
+    HEADER = auto()
+    KV_DATA = auto()
+    TI_DATA = auto()
+
 class GGUFWriter:
     fout: BufferedWriter
-    arch: str
-    offset_tensor = 0
-    data_alignment = GGUF_DEFAULT_ALIGNMENT
-    kv_data = b""
-    kv_data_count = 0
-    ti_data = b""
-    ti_data_count = 0
-    use_temp_file: bool
-    temp_file: tempfile.SpooledTemporaryFile[bytes] | None = None
-    tensors: list[tuple[np.ndarray[Any, Any], int]]
+    temp_file: tempfile.SpooledTemporaryFile[bytes] | None
+    tensors: list[np.ndarray[Any, Any]]
     _simple_value_packing = {
         GGUFValueType.UINT8: "B",
         GGUFValueType.INT8: "b",
@@ -60,27 +59,47 @@ def __init__(self, path: os.PathLike[str] | str, arch: str, use_temp_file: bool
         self.fout = open(path, "wb")
         self.arch = arch
         self.endianess = endianess
-        self.add_architecture()
+        self.offset_tensor = 0
+        self.data_alignment = GGUF_DEFAULT_ALIGNMENT
+        self.kv_data = b""
+        self.kv_data_count = 0
+        self.ti_data = b""
+        self.ti_data_count = 0
         self.use_temp_file = use_temp_file
+        self.temp_file = None
         self.tensors = []
         print("gguf: This GGUF file is for {0} Endian only"
               .format("Big" if self.endianess == GGUFEndian.BIG else "Little"))
+        self.state = WriterState.EMPTY
+
+        self.add_architecture()
 
     def write_header_to_file(self) -> None:
+        if self.state is not WriterState.EMPTY:
+            raise ValueError(f'Expected output file to be empty, got {self.state}')
+
         self._write_packed("<I", GGUF_MAGIC, skip_pack_prefix = True)
         self._write_packed("I", GGUF_VERSION)
         self._write_packed("Q", self.ti_data_count)
         self._write_packed("Q", self.kv_data_count)
         self.flush()
-        # print("tensors " + str(self.ti_data_count) + " kv " + str(self.kv_data_count))
+        self.state = WriterState.HEADER
 
     def write_kv_data_to_file(self) -> None:
+        if self.state is not WriterState.HEADER:
+            raise ValueError(f'Expected output file to contain the header, got {self.state}')
+
         self.fout.write(self.kv_data)
         self.flush()
+        self.state = WriterState.KV_DATA
 
     def write_ti_data_to_file(self) -> None:
+        if self.state is not WriterState.KV_DATA:
+            raise ValueError(f'Expected output file to contain KV data, got {self.state}')
+
         self.fout.write(self.ti_data)
         self.flush()
+        self.state = WriterState.TI_DATA
 
     def add_key(self, key: str) -> None:
         self.add_val(key, GGUFValueType.STRING, add_vtype=False)
@@ -173,6 +192,9 @@ def ggml_pad(x: int, n: int) -> int:
         return ((x + n - 1) // n) * n
 
     def add_tensor_info(self, name: str, tensor_shape: Sequence[int], tensor_dtype: np.dtype[np.float16] | np.dtype[np.float32], tensor_nbytes: int, raw_dtype: GGMLQuantizationType | None = None) -> None:
+        if self.state is not WriterState.EMPTY:
+            raise ValueError(f'Expected output file to be empty, got {self.state}')
+
         if raw_dtype is None and tensor_dtype not in (np.float32, np.float16):
             raise ValueError("Only F32 and F16 tensors are supported for now")
 
@@ -203,23 +225,21 @@ def add_tensor(self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequenc
         shape: Sequence[int] = raw_shape if raw_shape is not None else tensor.shape
         self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype = raw_dtype)
 
-        pad = GGUFWriter.ggml_pad(tensor.nbytes, self.data_alignment) - tensor.nbytes
-
-        if self.temp_file is None:
-            self.tensors.append((tensor, pad))
-            return
+        if self.temp_file is None:
+            self.tensors.append(tensor)
 
         tensor.tofile(self.temp_file)
+        self.write_padding(self.temp_file, tensor.nbytes)
 
-        if pad != 0:
-            self.temp_file.write(bytes([0] * pad))
-
-    def write_padding(self, fp: BinaryIO, n: int, align: int | None = None) -> None:
+    def write_padding(self, fp: IO[bytes], n: int, align: int | None = None):
         pad = GGUFWriter.ggml_pad(n, align if align is not None else self.data_alignment) - n
         if pad != 0:
             fp.write(bytes([0] * pad))
 
     def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None:
+        if self.state is not WriterState.TI_DATA:
+            raise ValueError(f'Expected output file to contain tensor info, got {self.state}')
+
         if self.endianess==GGUFEndian.BIG:
             tensor.byteswap(inplace=True)
         self.write_padding(self.fout, self.fout.tell())
@@ -232,10 +252,13 @@ def write_tensors_to_file(self) -> None:
         self.write_padding(self.fout, self.fout.tell())
 
         if self.temp_file is None:
-            for (currtensor, currpad) in self.tensors:
-                currtensor.tofile(self.fout)
-                if currpad != 0:
-                    self.fout.write(bytes([0] * currpad))
+            while True:
+                try:
+                    tensor = self.tensors.pop(0)
+                except IndexError:
+                    break
+                tensor.tofile(self.fout)
+                self.write_padding(self.fout, tensor.nbytes)
             return
 
         self.temp_file.seek(0)
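For reviewers unfamiliar with the writer, here is a minimal sketch of the call order the new WriterState checks enforce. The key name, tensor contents, and "llama" arch are made up for illustration, and it assumes write_tensors_to_file still emits the tensor-info section before the tensor data, as in the unchanged part of the file:

    import numpy as np
    from gguf import GGUFWriter

    writer = GGUFWriter("example.gguf", arch="llama")  # state: EMPTY
    writer.add_uint32("example.count", 1)              # KV pairs are only buffered here
    writer.add_tensor("weights", np.ones((4, 4), dtype=np.float32))

    # Calling a write method out of order now fails fast instead of silently
    # producing a corrupt file; writer.write_kv_data_to_file() here would raise:
    #   ValueError: Expected output file to contain the header, got WriterState.EMPTY

    writer.write_header_to_file()    # EMPTY -> HEADER
    writer.write_kv_data_to_file()   # HEADER -> KV_DATA
    writer.write_tensors_to_file()   # tensor info (KV_DATA -> TI_DATA), then tensor data
    writer.close()

Note also that write_tensors_to_file now pops tensors off the front of the list as it writes them, so buffered arrays can be garbage-collected instead of staying alive for the whole conversion.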

gguf-py/gguf/vocab.py

Lines changed: 5 additions & 5 deletions
@@ -9,11 +9,8 @@
 from .gguf_writer import GGUFWriter
 
 class SpecialVocab:
-    load_merges: bool = False
-    merges: list[str] = []
-    special_token_types: tuple[str, ...] = ('bos', 'eos', 'unk', 'sep', 'pad')
-    special_token_ids: dict[str, int] = {}
-    n_vocab: int | None = None
+    merges: list[str]
+    special_token_ids: dict[str, int]
 
     def __init__(
         self, path: str | os.PathLike[str], load_merges: bool = False,
@@ -23,8 +20,11 @@ def __init__(
         self.special_token_ids = {}
         self.n_vocab = n_vocab
         self.load_merges = load_merges
+        self.merges = []
         if special_token_types is not None:
             self.special_token_types = special_token_types
+        else:
+            self.special_token_types = ('bos', 'eos', 'unk', 'sep', 'pad')
         self._load(Path(path))
 
     def _load(self, path: Path) -> None:
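The vocab.py side of the change moves SpecialVocab's mutable defaults out of the class body and into __init__. A class-level default like merges: list[str] = [] is evaluated once and shared by every instance, so one vocab's mutations would leak into the next. A standalone sketch of the pitfall being fixed (toy classes, not the real SpecialVocab):

    class Shared:
        items: list[str] = []           # one list object, bound to the class

    class PerInstance:
        def __init__(self) -> None:
            self.items: list[str] = []  # a fresh list for each instance

    a, b = Shared(), Shared()
    a.items.append("x")
    print(b.items)   # ['x'] -- b sees a's mutation through the shared class attribute

    c, d = PerInstance(), PerInstance()
    c.items.append("x")
    print(d.items)   # [] -- instances stay independent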

0 commit comments

Comments
 (0)