diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml
index d57b9be4f3..8af218c053 100644
--- a/.github/workflows/python-tests.yml
+++ b/.github/workflows/python-tests.yml
@@ -88,7 +88,7 @@ jobs:
working-directory: ./src # For code coverage to work
run: |
source ../.venv/bin/activate
- PYTEST="python -m pytest --cov=./huggingface_hub --cov-report=xml:../coverage.xml --vcr-record=none --reruns 5 --reruns-delay 1 --only-rerun '(OSError|Timeout|HTTPError.*502|HTTPError.*504||not less than or equal to 0.01)'"
+ PYTEST="python -m pytest --cov=./huggingface_hub --cov-report=xml:../coverage.xml --vcr-record=none --reruns 8 --reruns-delay 2 --only-rerun '(OSError|Timeout|HTTPError.*502|HTTPError.*504||not less than or equal to 0.01)'"
case "${{ matrix.test_name }}" in
@@ -172,7 +172,7 @@ jobs:
working-directory: ./src # For code coverage to work
run: |
..\.venv\Scripts\activate
- python -m pytest -n 4 --cov=./huggingface_hub --cov-report=xml:../coverage.xml --vcr-record=none --reruns 5 --reruns-delay 1 --only-rerun '(OSError|Timeout|HTTPError.*502|HTTPError.*504|not less than or equal to 0.01)' ../tests
+ python -m pytest -n 4 --cov=./huggingface_hub --cov-report=xml:../coverage.xml --vcr-record=none --reruns 8 --reruns-delay 2 --only-rerun '(OSError|Timeout|HTTPError.*502|HTTPError.*504|not less than or equal to 0.01)' ../tests
# Upload code coverage
- name: Upload coverage reports to Codecov with GitHub Action
diff --git a/docs/source/en/package_reference/serialization.md b/docs/source/en/package_reference/serialization.md
index c2a7388091..841d9d3011 100644
--- a/docs/source/en/package_reference/serialization.md
+++ b/docs/source/en/package_reference/serialization.md
@@ -8,7 +8,11 @@ rendered properly in your Markdown viewer.
## Save torch state dict
-The main helper of the `serialization` module takes a state dictionary as input (e.g. a mapping between layer names and related tensors), splits it into several shards while creating a proper index in the process and save everything to disk. At the moment, only `torch` tensors are supported. Under the hood, it delegates the logic to split the state dictionary to [`split_torch_state_dict_into_shards`].
+The main helper of the `serialization` module takes a torch `nn.Module` as input and saves it to disk. It handles the logic to save shared tensors (see the [safetensors explanation](https://huggingface.co/docs/safetensors/torch_shared_tensors)) as well as the logic to split the state dictionary into shards, using [`split_torch_state_dict_into_shards`] under the hood. At the moment, only the `torch` framework is supported.
+
+If you want to save a state dictionary (e.g. a mapping between layer names and related tensors) instead of an `nn.Module`, you can use [`save_torch_state_dict`], which provides the same features. This is useful, for example, if you want to apply custom logic to the state dict before saving it.
+
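+For example, a minimal sketch (assuming `model` is any `torch.nn.Module`; the target folder is illustrative):
+
+```py
+from huggingface_hub import save_torch_model, save_torch_state_dict
+
+model = ...  # any torch.nn.Module
+
+# Save a full model: shared tensors are deduplicated and the weights are sharded if they exceed the maximum shard size
+save_torch_model(model, "path/to/folder")
+
+# Or work on the state dict directly before saving it
+state_dict = model.state_dict()
+save_torch_state_dict(state_dict, "path/to/folder")
+```
+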
+[[autodoc]] huggingface_hub.save_torch_model
[[autodoc]] huggingface_hub.save_torch_state_dict
@@ -34,4 +38,8 @@ This is the underlying factory from which each framework-specific helper is deri
### get_torch_storage_id
-[[autodoc]] huggingface_hub.get_torch_storage_id
\ No newline at end of file
+[[autodoc]] huggingface_hub.get_torch_storage_id
+
+### get_torch_storage_size
+
+[[autodoc]] huggingface_hub.get_torch_storage_size
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 29c831ee9f..e13aa28f88 100644
--- a/setup.py
+++ b/setup.py
@@ -63,13 +63,14 @@ def get_version() -> str:
+ [
"jedi",
"Jinja2",
- "pytest>=8.1.1",
+ "pytest>=8.1.1,<8.2.2", # at least until 8.2.3 is released with https://github.com/pytest-dev/pytest/pull/12436
"pytest-cov",
"pytest-env",
"pytest-xdist",
"pytest-vcr", # to mock Inference
"pytest-asyncio", # for AsyncInferenceClient
"pytest-rerunfailures", # to rerun flaky tests in CI
+ "pytest-mock",
"urllib3<2.0", # VCR.py broken with urllib3 2.0 (see https://urllib3.readthedocs.io/en/stable/v2-migration-guide.html)
"soundfile",
"Pillow",
diff --git a/src/huggingface_hub/__init__.py b/src/huggingface_hub/__init__.py
index 0b692f5b01..131496eb69 100644
--- a/src/huggingface_hub/__init__.py
+++ b/src/huggingface_hub/__init__.py
@@ -423,7 +423,10 @@
],
"serialization": [
"StateDictSplit",
+ "get_tf_storage_size",
"get_torch_storage_id",
+ "get_torch_storage_size",
+ "save_torch_model",
"save_torch_state_dict",
"split_state_dict_into_shards_factory",
"split_tf_state_dict_into_shards",
@@ -911,7 +914,10 @@ def __dir__():
from .repository import Repository # noqa: F401
from .serialization import (
StateDictSplit, # noqa: F401
+ get_tf_storage_size, # noqa: F401
get_torch_storage_id, # noqa: F401
+ get_torch_storage_size, # noqa: F401
+ save_torch_model, # noqa: F401
save_torch_state_dict, # noqa: F401
split_state_dict_into_shards_factory, # noqa: F401
split_tf_state_dict_into_shards, # noqa: F401
diff --git a/src/huggingface_hub/serialization/__init__.py b/src/huggingface_hub/serialization/__init__.py
index 2ae8f4aa1d..9e30ce175c 100644
--- a/src/huggingface_hub/serialization/__init__.py
+++ b/src/huggingface_hub/serialization/__init__.py
@@ -15,5 +15,11 @@
"""Contains helpers to serialize tensors."""
from ._base import StateDictSplit, split_state_dict_into_shards_factory
-from ._tensorflow import split_tf_state_dict_into_shards
-from ._torch import get_torch_storage_id, save_torch_state_dict, split_torch_state_dict_into_shards
+from ._tensorflow import get_tf_storage_size, split_tf_state_dict_into_shards
+from ._torch import (
+ get_torch_storage_id,
+ get_torch_storage_size,
+ save_torch_model,
+ save_torch_state_dict,
+ split_torch_state_dict_into_shards,
+)
diff --git a/src/huggingface_hub/serialization/_base.py b/src/huggingface_hub/serialization/_base.py
index c08d39b5ae..c30df3c324 100644
--- a/src/huggingface_hub/serialization/_base.py
+++ b/src/huggingface_hub/serialization/_base.py
@@ -49,7 +49,7 @@ def __post_init__(self):
def split_state_dict_into_shards_factory(
state_dict: Dict[str, TensorT],
*,
- get_tensor_size: TensorSizeFn_T,
+ get_storage_size: TensorSizeFn_T,
filename_pattern: str,
get_storage_id: StorageIDFn_T = lambda tensor: None,
max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
@@ -72,8 +72,8 @@ def split_state_dict_into_shards_factory(
Args:
state_dict (`Dict[str, Tensor]`):
The state dictionary to save.
- get_tensor_size (`Callable[[Tensor], int]`):
- A function that returns the size of a tensor in bytes.
+ get_storage_size (`Callable[[Tensor], int]`):
+            A function that returns the size of a tensor, in bytes, when saved on disk.
get_storage_id (`Callable[[Tensor], Optional[Any]]`, *optional*):
A function that returns a unique identifier to a tensor storage. Multiple different tensors can share the
same underlying storage. This identifier is guaranteed to be unique and constant for this tensor's storage
@@ -117,7 +117,7 @@ def split_state_dict_into_shards_factory(
storage_id_to_tensors[storage_id] = [key]
# Compute tensor size
- tensor_size = get_tensor_size(tensor)
+ tensor_size = get_storage_size(tensor)
# If this tensor is bigger than the maximal size, we put it in its own shard
if tensor_size > max_shard_size:
diff --git a/src/huggingface_hub/serialization/_tensorflow.py b/src/huggingface_hub/serialization/_tensorflow.py
index 943ff296b4..59ed8110b2 100644
--- a/src/huggingface_hub/serialization/_tensorflow.py
+++ b/src/huggingface_hub/serialization/_tensorflow.py
@@ -63,11 +63,11 @@ def split_tf_state_dict_into_shards(
state_dict,
max_shard_size=max_shard_size,
filename_pattern=filename_pattern,
- get_tensor_size=get_tensor_size,
+ get_storage_size=get_tf_storage_size,
)
-def get_tensor_size(tensor: "tf.Tensor") -> int:
+def get_tf_storage_size(tensor: "tf.Tensor") -> int:
# Return `math.ceil` since dtype byte size can be a float (e.g., 0.125 for tf.bool).
# Better to overestimate than underestimate.
return math.ceil(tensor.numpy().size * _dtype_byte_size_tf(tensor.dtype))
diff --git a/src/huggingface_hub/serialization/_torch.py b/src/huggingface_hub/serialization/_torch.py
index 36bac7b284..ae0112e76d 100644
--- a/src/huggingface_hub/serialization/_torch.py
+++ b/src/huggingface_hub/serialization/_torch.py
@@ -17,9 +17,10 @@
import json
import os
import re
+from collections import defaultdict
from functools import lru_cache
from pathlib import Path
-from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union
from .. import constants, logging
from ._base import MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory
@@ -31,27 +32,30 @@
import torch
-def split_torch_state_dict_into_shards(
- state_dict: Dict[str, "torch.Tensor"],
+def save_torch_model(
+ model: "torch.nn.Module",
+ save_directory: Union[str, Path],
*,
- filename_pattern: str = constants.SAFETENSORS_WEIGHTS_FILE_PATTERN,
+ filename_pattern: Optional[str] = None,
+ force_contiguous: bool = True,
max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
-) -> StateDictSplit:
+ metadata: Optional[Dict[str, str]] = None,
+ safe_serialization: bool = True,
+):
"""
- Split a model state dictionary in shards so that each shard is smaller than a given size.
-
- The shards are determined by iterating through the `state_dict` in the order of its keys. There is no optimization
- made to make each shard as close as possible to the maximum size passed. For example, if the limit is 10GB and we
- have tensors of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB], [6+2+2GB] and not
- [6+2+2GB], [6+2GB], [6GB].
+    Saves a given torch model to disk, handling sharding and shared tensor issues.
+ See also [`save_torch_state_dict`] to save a state dict with more flexibility.
-
+ For more information about tensor sharing, check out [this guide](https://huggingface.co/docs/safetensors/torch_shared_tensors).
- To save a model state dictionary to the disk, see [`save_torch_state_dict`]. This helper uses
- `split_torch_state_dict_into_shards` under the hood.
+ The model state dictionary is split into shards so that each shard is smaller than a given size. The shards are
+ saved in the `save_directory` with the given `filename_pattern`. If the model is too big to fit in a single shard,
+ an index file is saved in the `save_directory` to indicate where each tensor is saved. This helper uses
+ [`split_torch_state_dict_into_shards`] under the hood. If `safe_serialization` is `True`, the shards are saved as
+ safetensors (the default). Otherwise, the shards are saved as pickle.
-
+    Before saving the model, the `save_directory` is cleaned of any previous shard files.
@@ -61,49 +65,53 @@ def split_torch_state_dict_into_shards(
Args:
- state_dict (`Dict[str, torch.Tensor]`):
- The state dictionary to save.
+ model (`torch.nn.Module`):
+ The model to save on disk.
+ save_directory (`str` or `Path`):
+ The directory in which the model will be saved.
filename_pattern (`str`, *optional*):
The pattern to generate the files names in which the model will be saved. Pattern must be a string that
can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix`
- Defaults to `"model{suffix}.safetensors"`.
+ Defaults to `"model{suffix}.safetensors"` or `pytorch_model{suffix}.bin` depending on `safe_serialization`
+ parameter.
+        force_contiguous (`bool`, *optional*):
+            Whether to force the state_dict to be saved as contiguous tensors. This has no effect on the correctness
+            of the model, but it could potentially change performance if the layout of the tensor was chosen
+            specifically for that reason. Defaults to `True`.
max_shard_size (`int` or `str`, *optional*):
The maximum size of each shard, in bytes. Defaults to 5GB.
-
- Returns:
- [`StateDictSplit`]: A `StateDictSplit` object containing the shards and the index to retrieve them.
+ metadata (`Dict[str, str]`, *optional*):
+            Extra information to save along with the model. Some metadata will be added for each dropped tensor.
+            This information will not be enough to recover the entire shared structure but might help in
+            understanding it.
+ safe_serialization (`bool`, *optional*):
+ Whether to save as safetensors, which is the default behavior. If `False`, the shards are saved as pickle.
+ Safe serialization is recommended for security reasons. Saving as pickle is deprecated and will be removed
+ in a future version.
Example:
+
```py
- >>> import json
- >>> import os
- >>> from safetensors.torch import save_file as safe_save_file
- >>> from huggingface_hub import split_torch_state_dict_into_shards
+ >>> from huggingface_hub import save_torch_model
+ >>> model = ... # A PyTorch model
- >>> def save_state_dict(state_dict: Dict[str, torch.Tensor], save_directory: str):
- ... state_dict_split = split_torch_state_dict_into_shards(state_dict)
- ... for filename, tensors in state_dict_split.filename_to_tensors.items():
- ... shard = {tensor: state_dict[tensor] for tensor in tensors}
- ... safe_save_file(
- ... shard,
- ... os.path.join(save_directory, filename),
- ... metadata={"format": "pt"},
- ... )
- ... if state_dict_split.is_sharded:
- ... index = {
- ... "metadata": state_dict_split.metadata,
- ... "weight_map": state_dict_split.tensor_to_filename,
- ... }
- ... with open(os.path.join(save_directory, "model.safetensors.index.json"), "w") as f:
- ... f.write(json.dumps(index, indent=2))
+ # Save state dict to "path/to/folder". The model will be split into shards of 5GB each and saved as safetensors.
+ >>> save_torch_model(model, "path/to/folder")
+
+ # Load model back
+ >>> from huggingface_hub import load_torch_model # TODO
+ >>> load_torch_model(model, "path/to/folder")
```
"""
- return split_state_dict_into_shards_factory(
- state_dict,
- max_shard_size=max_shard_size,
+ save_torch_state_dict(
+ state_dict=model.state_dict(),
filename_pattern=filename_pattern,
- get_tensor_size=get_tensor_size,
- get_storage_id=get_torch_storage_id,
+ force_contiguous=force_contiguous,
+ max_shard_size=max_shard_size,
+ metadata=metadata,
+ safe_serialization=safe_serialization,
+ save_directory=save_directory,
)
@@ -111,12 +119,18 @@ def save_torch_state_dict(
state_dict: Dict[str, "torch.Tensor"],
save_directory: Union[str, Path],
*,
- safe_serialization: bool = True,
filename_pattern: Optional[str] = None,
+ force_contiguous: bool = True,
max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
+ metadata: Optional[Dict[str, str]] = None,
+ safe_serialization: bool = True,
) -> None:
"""
- Save a model state dictionary to the disk.
+    Save a model state dictionary to the disk, handling sharding and shared tensor issues.
+
+ See also [`save_torch_model`] to directly save a PyTorch model.
+
+ For more information about tensor sharing, check out [this guide](https://huggingface.co/docs/safetensors/torch_shared_tensors).
The model state dictionary is split into shards so that each shard is smaller than a given size. The shards are
saved in the `save_directory` with the given `filename_pattern`. If the model is too big to fit in a single shard,
@@ -138,17 +152,25 @@ def save_torch_state_dict(
The state dictionary to save.
save_directory (`str` or `Path`):
The directory in which the model will be saved.
- safe_serialization (`bool`, *optional*):
- Whether to save as safetensors, which is the default behavior. If `False`, the shards are saved as pickle.
- Safe serialization is recommended for security reasons. Saving as pickle is deprecated and will be removed
- in a future version.
filename_pattern (`str`, *optional*):
The pattern to generate the files names in which the model will be saved. Pattern must be a string that
can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix`
Defaults to `"model{suffix}.safetensors"` or `pytorch_model{suffix}.bin` depending on `safe_serialization`
parameter.
+        force_contiguous (`bool`, *optional*):
+            Whether to force the state_dict to be saved as contiguous tensors. This has no effect on the correctness
+            of the model, but it could potentially change performance if the layout of the tensor was chosen
+            specifically for that reason. Defaults to `True`.
max_shard_size (`int` or `str`, *optional*):
The maximum size of each shard, in bytes. Defaults to 5GB.
+ metadata (`Dict[str, str]`, *optional*):
+            Extra information to save along with the model. Some metadata will be added for each dropped tensor.
+            This information will not be enough to recover the entire shared structure but might help in
+            understanding it.
+ safe_serialization (`bool`, *optional*):
+ Whether to save as safetensors, which is the default behavior. If `False`, the shards are saved as pickle.
+ Safe serialization is recommended for security reasons. Saving as pickle is deprecated and will be removed
+ in a future version.
Example:
@@ -189,6 +211,12 @@ def save_torch_state_dict(
"using safe serialization by installing `safetensors` with `pip install safetensors`."
)
+ # Clean state dict for safetensors
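+    # Shared tensors are deduplicated and each removed name is recorded in `metadata`, aliased to the kept name.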
+ if metadata is None:
+ metadata = {}
+ if safe_serialization:
+ state_dict = _clean_state_dict_for_safetensors(state_dict, metadata, force_contiguous=force_contiguous)
+
# Split dict
state_dict_split = split_torch_state_dict_into_shards(
state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
@@ -205,7 +233,10 @@ def save_torch_state_dict(
logger.warning(f"Error when trying to remove existing '{filename}' from folder: {e}. Continuing...")
# Save each shard
- safe_file_kwargs = {"metadata": {"format": "pt"}} if safe_serialization else {}
+ per_file_metadata = {"format": "pt"}
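+    # Extra `metadata` is only embedded in the safetensors file for single-file checkpoints; for sharded
+    # checkpoints it is stored in the index file instead (see below).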
+ if not state_dict_split.is_sharded:
+ per_file_metadata.update(metadata)
+ safe_file_kwargs = {"metadata": per_file_metadata} if safe_serialization else {}
for filename, tensors in state_dict_split.filename_to_tensors.items():
shard = {tensor: state_dict[tensor] for tensor in tensors}
save_file_fn(shard, os.path.join(save_directory, filename), **safe_file_kwargs)
@@ -214,7 +245,10 @@ def save_torch_state_dict(
# Save the index (if any)
if state_dict_split.is_sharded:
index_path = filename_pattern.format(suffix="") + ".index.json"
- index = {"metadata": state_dict_split.metadata, "weight_map": state_dict_split.tensor_to_filename}
+ index = {
+ "metadata": {**state_dict_split.metadata, **metadata},
+ "weight_map": state_dict_split.tensor_to_filename,
+ }
with open(os.path.join(save_directory, index_path), "w") as f:
json.dump(index, f, indent=2)
logger.info(
@@ -226,6 +260,82 @@ def save_torch_state_dict(
logger.info(f"Model weights successfully saved to {save_directory}!")
+def split_torch_state_dict_into_shards(
+ state_dict: Dict[str, "torch.Tensor"],
+ *,
+ filename_pattern: str = constants.SAFETENSORS_WEIGHTS_FILE_PATTERN,
+ max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
+) -> StateDictSplit:
+ """
+ Split a model state dictionary in shards so that each shard is smaller than a given size.
+
+ The shards are determined by iterating through the `state_dict` in the order of its keys. There is no optimization
+ made to make each shard as close as possible to the maximum size passed. For example, if the limit is 10GB and we
+ have tensors of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB], [6+2+2GB] and not
+ [6+2+2GB], [6+2GB], [6GB].
+
+ To save a model state dictionary to the disk, see [`save_torch_state_dict`]. This helper uses
+ `split_torch_state_dict_into_shards` under the hood.
+
+    If one of the model's tensors is bigger than `max_shard_size`, it will end up in its own shard which will have a
+ size greater than `max_shard_size`.
+
+ Args:
+ state_dict (`Dict[str, torch.Tensor]`):
+ The state dictionary to save.
+ filename_pattern (`str`, *optional*):
+ The pattern to generate the files names in which the model will be saved. Pattern must be a string that
+ can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix`
+ Defaults to `"model{suffix}.safetensors"`.
+ max_shard_size (`int` or `str`, *optional*):
+ The maximum size of each shard, in bytes. Defaults to 5GB.
+
+ Returns:
+ [`StateDictSplit`]: A `StateDictSplit` object containing the shards and the index to retrieve them.
+
+ Example:
+ ```py
+ >>> import json
+ >>> import os
+ >>> from safetensors.torch import save_file as safe_save_file
+ >>> from huggingface_hub import split_torch_state_dict_into_shards
+
+ >>> def save_state_dict(state_dict: Dict[str, torch.Tensor], save_directory: str):
+ ... state_dict_split = split_torch_state_dict_into_shards(state_dict)
+ ... for filename, tensors in state_dict_split.filename_to_tensors.items():
+ ... shard = {tensor: state_dict[tensor] for tensor in tensors}
+ ... safe_save_file(
+ ... shard,
+ ... os.path.join(save_directory, filename),
+ ... metadata={"format": "pt"},
+ ... )
+ ... if state_dict_split.is_sharded:
+ ... index = {
+ ... "metadata": state_dict_split.metadata,
+ ... "weight_map": state_dict_split.tensor_to_filename,
+ ... }
+ ... with open(os.path.join(save_directory, "model.safetensors.index.json"), "w") as f:
+ ... f.write(json.dumps(index, indent=2))
+ ```
+ """
+ return split_state_dict_into_shards_factory(
+ state_dict,
+ max_shard_size=max_shard_size,
+ filename_pattern=filename_pattern,
+ get_storage_size=get_torch_storage_size,
+ get_storage_id=get_torch_storage_id,
+ )
+
+
def get_torch_storage_id(tensor: "torch.Tensor") -> Tuple["torch.device", int, int]:
"""
Return unique identifier to a tensor storage.
@@ -248,11 +358,23 @@ def get_torch_storage_id(tensor: "torch.Tensor") -> Tuple["torch.device", int, i
else:
unique_id = storage_ptr(tensor)
- return tensor.device, unique_id, get_storage_size(tensor)
+ return tensor.device, unique_id, get_torch_storage_size(tensor)
-def get_tensor_size(tensor: "torch.Tensor") -> int:
- return tensor.numel() * tensor.element_size()
+def get_torch_storage_size(tensor: "torch.Tensor") -> int:
+ """
+ Taken from https://github.com/huggingface/safetensors/blob/08db34094e9e59e2f9218f2df133b7b4aaff5a99/bindings/python/py_src/safetensors/torch.py#L31C1-L41C59
+ """
+ try:
+ return tensor.untyped_storage().nbytes()
+ except AttributeError:
+ # Fallback for torch==1.10
+ try:
+ return tensor.storage().size() * _get_dtype_size(tensor.dtype)
+ except NotImplementedError:
+ # Fallback for meta storage
+ # On torch >=2.0 this is the tensor size
+ return tensor.nelement() * _get_dtype_size(tensor.dtype)
@lru_cache()
@@ -278,7 +400,7 @@ def is_torch_tpu_available(check_device=True):
def storage_ptr(tensor: "torch.Tensor") -> int:
"""
- Taken from https://github.com/huggingface/safetensors/blob/08db34094e9e59e2f9218f2df133b7b4aaff5a99/bindings/python/py_src/safetensors/torch.py#L11C1-L20C21.
+ Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L11.
"""
try:
return tensor.untyped_storage().data_ptr()
@@ -291,20 +413,141 @@ def storage_ptr(tensor: "torch.Tensor") -> int:
return 0
-def get_storage_size(tensor: "torch.Tensor") -> int:
+def _clean_state_dict_for_safetensors(
+ state_dict: Dict[str, "torch.Tensor"], metadata: Dict[str, str], force_contiguous: bool = True
+):
+ """Remove shared tensors from state_dict and update metadata accordingly (for reloading).
+
+ Warning: `state_dict` and `metadata` are mutated in-place!
+
+ Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L155.
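+
+    Illustrative example (names are hypothetical): with `state_dict = {"a": t, "b": t}` where both entries point to
+    the same storage, the returned dict only keeps `"a"` and `metadata["b"]` is set to `"a"`, so that the alias can
+    be re-created when reloading the checkpoint.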
"""
- Taken from https://github.com/huggingface/safetensors/blob/08db34094e9e59e2f9218f2df133b7b4aaff5a99/bindings/python/py_src/safetensors/torch.py#L31C1-L41C59
+ to_removes = _remove_duplicate_names(state_dict)
+ for kept_name, to_remove_group in to_removes.items():
+ for to_remove in to_remove_group:
+ if metadata is None:
+ metadata = {}
+
+ if to_remove not in metadata:
+ # Do not override user data
+ metadata[to_remove] = kept_name
+ del state_dict[to_remove]
+ if force_contiguous:
+ state_dict = {k: v.contiguous() for k, v in state_dict.items()}
+ return state_dict
+
+
+def _end_ptr(tensor: "torch.Tensor") -> int:
"""
- try:
- return tensor.untyped_storage().nbytes()
- except AttributeError:
- # Fallback for torch==1.10
- try:
- return tensor.storage().size() * _get_dtype_size(tensor.dtype)
- except NotImplementedError:
- # Fallback for meta storage
- # On torch >=2.0 this is the tensor size
- return tensor.nelement() * _get_dtype_size(tensor.dtype)
+ Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L23.
+ """
+ if tensor.nelement():
+ stop = tensor.view(-1)[-1].data_ptr() + _get_dtype_size(tensor.dtype)
+ else:
+ stop = tensor.data_ptr()
+ return stop
+
+
+def _filter_shared_not_shared(tensors: List[Set[str]], state_dict: Dict[str, "torch.Tensor"]) -> List[Set[str]]:
+ """
+ Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L44
+ """
+ filtered_tensors = []
+ for shared in tensors:
+ if len(shared) < 2:
+ filtered_tensors.append(shared)
+ continue
+
+ areas = []
+ for name in shared:
+ tensor = state_dict[name]
+ areas.append((tensor.data_ptr(), _end_ptr(tensor), name))
+ areas.sort()
+
+ _, last_stop, last_name = areas[0]
+ filtered_tensors.append({last_name})
+ for start, stop, name in areas[1:]:
+ if start >= last_stop:
+ filtered_tensors.append({name})
+ else:
+ filtered_tensors[-1].add(name)
+ last_stop = stop
+
+ return filtered_tensors
+
+
+def _find_shared_tensors(state_dict: Dict[str, "torch.Tensor"]) -> List[Set[str]]:
+ """
+ Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L69.
+ """
+ import torch
+
+ tensors_dict = defaultdict(set)
+ for k, v in state_dict.items():
+ if v.device != torch.device("meta") and storage_ptr(v) != 0 and get_torch_storage_size(v) != 0:
+ # Need to add device as key because of multiple GPU.
+ tensors_dict[(v.device, storage_ptr(v), get_torch_storage_size(v))].add(k)
+ tensors = list(sorted(tensors_dict.values()))
+ tensors = _filter_shared_not_shared(tensors, state_dict)
+ return tensors
+
+
+def _is_complete(tensor: "torch.Tensor") -> bool:
+ """
+ Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L80
+ """
+ return tensor.data_ptr() == storage_ptr(tensor) and tensor.nelement() * _get_dtype_size(
+ tensor.dtype
+ ) == get_torch_storage_size(tensor)
+
+
+def _remove_duplicate_names(
+ state_dict: Dict[str, "torch.Tensor"],
+ *,
+ preferred_names: Optional[List[str]] = None,
+ discard_names: Optional[List[str]] = None,
+) -> Dict[str, List[str]]:
+ """
+ Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L80
+ """
+ if preferred_names is None:
+ preferred_names = []
+ unique_preferred_names = set(preferred_names)
+ if discard_names is None:
+ discard_names = []
+ unique_discard_names = set(discard_names)
+
+ shareds = _find_shared_tensors(state_dict)
+ to_remove = defaultdict(list)
+ for shared in shareds:
+ complete_names = set([name for name in shared if _is_complete(state_dict[name])])
+ if not complete_names:
+ raise RuntimeError(
+ "Error while trying to find names to remove to save state dict, but found no suitable name to keep"
+ f" for saving amongst: {shared}. None is covering the entire storage. Refusing to save/load the model"
+ " since you could be storing much more memory than needed. Please refer to"
+ " https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an"
+ " issue."
+ )
+
+ keep_name = sorted(list(complete_names))[0]
+
+ # Mechanism to preferentially select keys to keep
+ # coming from the on-disk file to allow
+ # loading models saved with a different choice
+ # of keep_name
+ preferred = complete_names.difference(unique_discard_names)
+ if preferred:
+ keep_name = sorted(list(preferred))[0]
+
+ if unique_preferred_names:
+ preferred = unique_preferred_names.intersection(complete_names)
+ if preferred:
+ keep_name = sorted(list(preferred))[0]
+ for name in sorted(shared):
+ if name != keep_name:
+ to_remove[keep_name].append(name)
+ return to_remove
@lru_cache()
diff --git a/tests/test_serialization.py b/tests/test_serialization.py
index 9af9256389..d954fce99b 100644
--- a/tests/test_serialization.py
+++ b/tests/test_serialization.py
@@ -1,13 +1,20 @@
import json
+import struct
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List
+from unittest.mock import Mock
import pytest
-
-from huggingface_hub.serialization import save_torch_state_dict, split_state_dict_into_shards_factory
+from pytest_mock import MockerFixture
+
+from huggingface_hub.serialization import (
+ get_tf_storage_size,
+ get_torch_storage_size,
+ save_torch_model,
+ save_torch_state_dict,
+ split_state_dict_into_shards_factory,
+)
from huggingface_hub.serialization._base import parse_size_to_int
-from huggingface_hub.serialization._tensorflow import get_tensor_size as get_tensor_size_tensorflow
-from huggingface_hub.serialization._torch import get_tensor_size as get_tensor_size_torch
from .testing_utils import requires
@@ -20,7 +27,7 @@ def _dummy_get_storage_id(item):
return None
-def _dummy_get_tensor_size(item):
+def _dummy_get_storage_size(item):
return sum(item)
@@ -51,11 +58,28 @@ def torch_state_dict() -> Dict[str, "torch.Tensor"]:
pytest.skip("torch is not available")
+@pytest.fixture
+def torch_state_dict_shared_layers() -> Dict[str, "torch.Tensor"]:
+ try:
+ import torch
+
+ shared_layer = torch.tensor([4])
+
+ return {
+ "shared_1": shared_layer,
+ "unique_1": torch.tensor([10]),
+ "unique_2": torch.tensor([30]),
+ "shared_2": shared_layer,
+ }
+ except ImportError:
+ pytest.skip("torch is not available")
+
+
def test_single_shard(dummy_state_dict):
state_dict_split = split_state_dict_into_shards_factory(
dummy_state_dict,
get_storage_id=_dummy_get_storage_id,
- get_tensor_size=_dummy_get_tensor_size,
+ get_storage_size=_dummy_get_storage_size,
max_shard_size=100, # large shard size => only one shard
filename_pattern="file{suffix}.dummy",
)
@@ -78,7 +102,7 @@ def test_multiple_shards(dummy_state_dict):
state_dict_split = split_state_dict_into_shards_factory(
dummy_state_dict,
get_storage_id=_dummy_get_storage_id,
- get_tensor_size=_dummy_get_tensor_size,
+ get_storage_size=_dummy_get_storage_size,
max_shard_size=10, # small shard size => multiple shards
filename_pattern="file{suffix}.dummy",
)
@@ -111,7 +135,7 @@ def test_tensor_same_storage():
"layer_5": [1],
},
get_storage_id=lambda x: (x[0]), # dummy for test: storage id based on first element
- get_tensor_size=_dummy_get_tensor_size,
+ get_storage_size=_dummy_get_storage_size,
max_shard_size=1,
filename_pattern="model{suffix}.safetensors",
)
@@ -131,19 +155,19 @@ def test_tensor_same_storage():
@requires("tensorflow")
-def test_get_tensor_size_tensorflow():
+def test_get_tf_storage_size():
import tensorflow as tf
- assert get_tensor_size_tensorflow(tf.constant([1, 2, 3, 4, 5], dtype=tf.float64)) == 5 * 8
- assert get_tensor_size_tensorflow(tf.constant([1, 2, 3, 4, 5], dtype=tf.float16)) == 5 * 2
+ assert get_tf_storage_size(tf.constant([1, 2, 3, 4, 5], dtype=tf.float64)) == 5 * 8
+ assert get_tf_storage_size(tf.constant([1, 2, 3, 4, 5], dtype=tf.float16)) == 5 * 2
@requires("torch")
-def test_get_tensor_size_torch():
+def test_get_torch_storage_size():
import torch
- assert get_tensor_size_torch(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float64)) == 5 * 8
- assert get_tensor_size_torch(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float16)) == 5 * 2
+ assert get_torch_storage_size(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float64)) == 5 * 8
+ assert get_torch_storage_size(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float16)) == 5 * 2
def test_parse_size_to_int():
@@ -160,6 +184,30 @@ def test_parse_size_to_int():
parse_size_to_int("1ooKB") # not a float
+def test_save_torch_model(mocker: MockerFixture, tmp_path: Path) -> None:
+ """Test `save_torch_model` is only a wrapper around `save_torch_state_dict`."""
+ model_mock = Mock()
+ safe_state_dict_mock = mocker.patch("huggingface_hub.serialization._torch.save_torch_state_dict")
+ save_torch_model(
+ model_mock,
+ save_directory=tmp_path,
+ filename_pattern="my-pattern",
+ force_contiguous=True,
+ max_shard_size="3GB",
+ metadata={"foo": "bar"},
+ safe_serialization=True,
+ )
+ safe_state_dict_mock.assert_called_once_with(
+ state_dict=model_mock.state_dict.return_value,
+ save_directory=tmp_path,
+ filename_pattern="my-pattern",
+ force_contiguous=True,
+ max_shard_size="3GB",
+ metadata={"foo": "bar"},
+ safe_serialization=True,
+ )
+
+
def test_save_torch_state_dict_not_sharded(tmp_path: Path, torch_state_dict: Dict[str, "torch.Tensor"]) -> None:
"""Save as safetensors without sharding."""
save_torch_state_dict(torch_state_dict, tmp_path, max_shard_size="1GB")
@@ -225,6 +273,47 @@ def test_save_torch_state_dict_unsafe_sharded(
}
+def test_save_torch_state_dict_shared_layers_not_sharded(
+ tmp_path: Path, torch_state_dict_shared_layers: Dict[str, "torch.Tensor"]
+) -> None:
+ from safetensors.torch import load_file
+
+ save_torch_state_dict(torch_state_dict_shared_layers, tmp_path, safe_serialization=True)
+ safetensors_file = tmp_path / "model.safetensors"
+ assert safetensors_file.is_file()
+
+ # Check shared layer not duplicated in file
+ state_dict = load_file(safetensors_file)
+ assert "shared_1" in state_dict
+ assert "shared_2" not in state_dict
+
+ # Check shared layer info in metadata
+ file_bytes = safetensors_file.read_bytes()
+    metadata_str = file_bytes[
+        8 : struct.unpack("<Q", file_bytes[:8])[0] + 8
+    ].decode()
+    assert json.loads(metadata_str)["__metadata__"]["shared_2"] == "shared_1"
+
+
+def test_save_torch_state_dict_shared_layers_sharded(
+    tmp_path: Path, torch_state_dict_shared_layers: Dict[str, "torch.Tensor"]
+) -> None:
+ from safetensors.torch import load_file
+
+ save_torch_state_dict(torch_state_dict_shared_layers, tmp_path, max_shard_size=2, safe_serialization=True)
+ index_file = tmp_path / "model.safetensors.index.json"
+ assert index_file.is_file()
+
+ # Check shared layer info in index metadata
+ index = json.loads(index_file.read_text())
+ assert index["metadata"]["shared_2"] == "shared_1"
+
+ # Check shared layer not duplicated in files
+ for filename in index["weight_map"].values():
+ state_dict = load_file(tmp_path / filename)
+ assert "shared_2" not in state_dict
+
+
def test_save_torch_state_dict_custom_filename(tmp_path: Path, torch_state_dict: Dict[str, "torch.Tensor"]) -> None:
"""Custom filename pattern is respected."""
# Not sharded