Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use mmap in torch load #4202

Merged
merged 2 commits into from
Nov 25, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Use mmap in torch load, prefer .bin files when loading
  • Loading branch information
Galunid committed Nov 24, 2023
commit b9276f3b59a5813d432cde53fdb6c9fcb3450122
32 changes: 16 additions & 16 deletions convert-hf-to-gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian:
self.fname_out = fname_out
self.is_big_endian = is_big_endian
self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
self.is_safetensors = self._is_model_safetensors()
self.num_parts = Model.count_model_parts(self.dir_model, ".safetensors" if self.is_safetensors else ".bin")
self.is_pytorch = self._is_model_pytorch()
self.num_parts = Model.count_model_parts(self.dir_model, ".bin" if self.is_pytorch else ".safetensors")
self.part_names = self._get_part_names()
self.hparams = Model.load_hparams(self.dir_model)
self.model_arch = self._get_model_architecture()
Expand All @@ -55,15 +55,15 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
for part_name in self.part_names:
print(f"gguf: loading model part '{part_name}'")
ctx: ContextManager[Any]
if self.is_safetensors:
if self.is_pytorch:
ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))
else:
from safetensors import safe_open
ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
else:
ctx = contextlib.nullcontext(torch.load(self.dir_model / part_name, map_location="cpu"))

with ctx as model_part:
for name in model_part.keys():
data = model_part.get_tensor(name) if self.is_safetensors else model_part[name]
data = model_part[name] if self.is_pytorch else model_part.get_tensor(name)
yield name, data

def set_gguf_parameters(self):
Expand Down Expand Up @@ -170,18 +170,18 @@ def from_model_architecture(model_architecture):
return StableLMModel
return Model

def _is_model_safetensors(self) -> bool:
return Model.count_model_parts(self.dir_model, ".safetensors") > 0
def _is_model_pytorch(self) -> bool:
return Model.count_model_parts(self.dir_model, ".bin") > 0

def _get_part_names(self):
if self.is_safetensors:
if self.num_parts == 1: # there's only one .safetensors file
return ("model.safetensors",)
return (f"model-{n:05}-of-{self.num_parts:05}.safetensors" for n in range(1, self.num_parts + 1))

if self.num_parts == 1: # there's only one .bin file
return ("pytorch_model.bin",)
return (f"pytorch_model-{n:05}-of-{self.num_parts:05}.bin" for n in range(1, self.num_parts + 1))
if self.is_pytorch:
if self.num_parts == 1: # there's only one .bin file
return ("pytorch_model.bin",)
return (f"pytorch_model-{n:05}-of-{self.num_parts:05}.bin" for n in range(1, self.num_parts + 1))

if self.num_parts == 1: # there's only one .safetensors file
return ("model.safetensors",)
return (f"model-{n:05}-of-{self.num_parts:05}.safetensors" for n in range(1, self.num_parts + 1))

def _get_model_architecture(self) -> gguf.MODEL_ARCH:
arch = self.hparams["architectures"][0]
Expand Down