diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index d57b9be4f3..8af218c053 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -88,7 +88,7 @@ jobs: working-directory: ./src # For code coverage to work run: | source ../.venv/bin/activate - PYTEST="python -m pytest --cov=./huggingface_hub --cov-report=xml:../coverage.xml --vcr-record=none --reruns 5 --reruns-delay 1 --only-rerun '(OSError|Timeout|HTTPError.*502|HTTPError.*504||not less than or equal to 0.01)'" + PYTEST="python -m pytest --cov=./huggingface_hub --cov-report=xml:../coverage.xml --vcr-record=none --reruns 8 --reruns-delay 2 --only-rerun '(OSError|Timeout|HTTPError.*502|HTTPError.*504||not less than or equal to 0.01)'" case "${{ matrix.test_name }}" in @@ -172,7 +172,7 @@ jobs: working-directory: ./src # For code coverage to work run: | ..\.venv\Scripts\activate - python -m pytest -n 4 --cov=./huggingface_hub --cov-report=xml:../coverage.xml --vcr-record=none --reruns 5 --reruns-delay 1 --only-rerun '(OSError|Timeout|HTTPError.*502|HTTPError.*504|not less than or equal to 0.01)' ../tests + python -m pytest -n 4 --cov=./huggingface_hub --cov-report=xml:../coverage.xml --vcr-record=none --reruns 8 --reruns-delay 2 --only-rerun '(OSError|Timeout|HTTPError.*502|HTTPError.*504|not less than or equal to 0.01)' ../tests # Upload code coverage - name: Upload coverage reports to Codecov with GitHub Action diff --git a/docs/source/en/package_reference/serialization.md b/docs/source/en/package_reference/serialization.md index c2a7388091..841d9d3011 100644 --- a/docs/source/en/package_reference/serialization.md +++ b/docs/source/en/package_reference/serialization.md @@ -8,7 +8,11 @@ rendered properly in your Markdown viewer. ## Save torch state dict -The main helper of the `serialization` module takes a state dictionary as input (e.g. a mapping between layer names and related tensors), splits it into several shards while creating a proper index in the process and save everything to disk. At the moment, only `torch` tensors are supported. Under the hood, it delegates the logic to split the state dictionary to [`split_torch_state_dict_into_shards`]. +The main helper of the `serialization` module takes a torch `nn.Module` as input and saves it to disk. It handles the logic to save shared tensors (see [safetensors explanation](https://huggingface.co/docs/safetensors/torch_shared_tensors)) as well as logic to split the state dictionary into shards, using [`split_torch_state_dict_into_shards`] under the hood. At the moment, only `torch` framework is supported. + +If you want to save a state dictionary (e.g. a mapping between layer names and related tensors) instead of a `nn.Module`, you can use [`save_torch_state_dict`] which provides the same features. This is useful for example if you want to apply custom logic to the state dict before saving it. 
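A minimal usage sketch of the two helpers described in the paragraph above (illustrative only, not part of the patch; the model and the output folder are placeholders):

```py
import torch
from huggingface_hub import save_torch_model, save_torch_state_dict

model = torch.nn.Linear(4, 2)  # any nn.Module; placeholder for illustration

# Preferred path: let the helper pull the state dict, handle shared tensors and sharding.
save_torch_model(model, "path/to/folder")

# Alternative path: apply custom logic to the state dict first, then save it directly.
state_dict = {k: v.half() for k, v in model.state_dict().items()}  # example of custom logic
save_torch_state_dict(state_dict, "path/to/folder")
```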
+ +[[autodoc]] huggingface_hub.save_torch_model [[autodoc]] huggingface_hub.save_torch_state_dict @@ -34,4 +38,8 @@ This is the underlying factory from which each framework-specific helper is deri ### get_torch_storage_id -[[autodoc]] huggingface_hub.get_torch_storage_id \ No newline at end of file +[[autodoc]] huggingface_hub.get_torch_storage_id + +### get_torch_storage_size + +[[autodoc]] huggingface_hub.get_torch_storage_size \ No newline at end of file diff --git a/setup.py b/setup.py index 29c831ee9f..e13aa28f88 100644 --- a/setup.py +++ b/setup.py @@ -63,13 +63,14 @@ def get_version() -> str: + [ "jedi", "Jinja2", - "pytest>=8.1.1", + "pytest>=8.1.1,<8.2.2", # at least until 8.2.3 is released with https://github.com/pytest-dev/pytest/pull/12436 "pytest-cov", "pytest-env", "pytest-xdist", "pytest-vcr", # to mock Inference "pytest-asyncio", # for AsyncInferenceClient "pytest-rerunfailures", # to rerun flaky tests in CI + "pytest-mock", "urllib3<2.0", # VCR.py broken with urllib3 2.0 (see https://urllib3.readthedocs.io/en/stable/v2-migration-guide.html) "soundfile", "Pillow", diff --git a/src/huggingface_hub/__init__.py b/src/huggingface_hub/__init__.py index 0b692f5b01..131496eb69 100644 --- a/src/huggingface_hub/__init__.py +++ b/src/huggingface_hub/__init__.py @@ -423,7 +423,10 @@ ], "serialization": [ "StateDictSplit", + "get_tf_storage_size", "get_torch_storage_id", + "get_torch_storage_size", + "save_torch_model", "save_torch_state_dict", "split_state_dict_into_shards_factory", "split_tf_state_dict_into_shards", @@ -911,7 +914,10 @@ def __dir__(): from .repository import Repository # noqa: F401 from .serialization import ( StateDictSplit, # noqa: F401 + get_tf_storage_size, # noqa: F401 get_torch_storage_id, # noqa: F401 + get_torch_storage_size, # noqa: F401 + save_torch_model, # noqa: F401 save_torch_state_dict, # noqa: F401 split_state_dict_into_shards_factory, # noqa: F401 split_tf_state_dict_into_shards, # noqa: F401 diff --git a/src/huggingface_hub/serialization/__init__.py b/src/huggingface_hub/serialization/__init__.py index 2ae8f4aa1d..9e30ce175c 100644 --- a/src/huggingface_hub/serialization/__init__.py +++ b/src/huggingface_hub/serialization/__init__.py @@ -15,5 +15,11 @@ """Contains helpers to serialize tensors.""" from ._base import StateDictSplit, split_state_dict_into_shards_factory -from ._tensorflow import split_tf_state_dict_into_shards -from ._torch import get_torch_storage_id, save_torch_state_dict, split_torch_state_dict_into_shards +from ._tensorflow import get_tf_storage_size, split_tf_state_dict_into_shards +from ._torch import ( + get_torch_storage_id, + get_torch_storage_size, + save_torch_model, + save_torch_state_dict, + split_torch_state_dict_into_shards, +) diff --git a/src/huggingface_hub/serialization/_base.py b/src/huggingface_hub/serialization/_base.py index c08d39b5ae..c30df3c324 100644 --- a/src/huggingface_hub/serialization/_base.py +++ b/src/huggingface_hub/serialization/_base.py @@ -49,7 +49,7 @@ def __post_init__(self): def split_state_dict_into_shards_factory( state_dict: Dict[str, TensorT], *, - get_tensor_size: TensorSizeFn_T, + get_storage_size: TensorSizeFn_T, filename_pattern: str, get_storage_id: StorageIDFn_T = lambda tensor: None, max_shard_size: Union[int, str] = MAX_SHARD_SIZE, @@ -72,8 +72,8 @@ def split_state_dict_into_shards_factory( Args: state_dict (`Dict[str, Tensor]`): The state dictionary to save. - get_tensor_size (`Callable[[Tensor], int]`): - A function that returns the size of a tensor in bytes. 
+ get_storage_size (`Callable[[Tensor], int]`): + A function that returns the size of a tensor when saved on disk in bytes. get_storage_id (`Callable[[Tensor], Optional[Any]]`, *optional*): A function that returns a unique identifier to a tensor storage. Multiple different tensors can share the same underlying storage. This identifier is guaranteed to be unique and constant for this tensor's storage @@ -117,7 +117,7 @@ def split_state_dict_into_shards_factory( storage_id_to_tensors[storage_id] = [key] # Compute tensor size - tensor_size = get_tensor_size(tensor) + tensor_size = get_storage_size(tensor) # If this tensor is bigger than the maximal size, we put it in its own shard if tensor_size > max_shard_size: diff --git a/src/huggingface_hub/serialization/_tensorflow.py b/src/huggingface_hub/serialization/_tensorflow.py index 943ff296b4..59ed8110b2 100644 --- a/src/huggingface_hub/serialization/_tensorflow.py +++ b/src/huggingface_hub/serialization/_tensorflow.py @@ -63,11 +63,11 @@ def split_tf_state_dict_into_shards( state_dict, max_shard_size=max_shard_size, filename_pattern=filename_pattern, - get_tensor_size=get_tensor_size, + get_storage_size=get_tf_storage_size, ) -def get_tensor_size(tensor: "tf.Tensor") -> int: +def get_tf_storage_size(tensor: "tf.Tensor") -> int: # Return `math.ceil` since dtype byte size can be a float (e.g., 0.125 for tf.bool). # Better to overestimate than underestimate. return math.ceil(tensor.numpy().size * _dtype_byte_size_tf(tensor.dtype)) diff --git a/src/huggingface_hub/serialization/_torch.py b/src/huggingface_hub/serialization/_torch.py index 36bac7b284..ae0112e76d 100644 --- a/src/huggingface_hub/serialization/_torch.py +++ b/src/huggingface_hub/serialization/_torch.py @@ -17,9 +17,10 @@ import json import os import re +from collections import defaultdict from functools import lru_cache from pathlib import Path -from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union from .. import constants, logging from ._base import MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory @@ -31,27 +32,30 @@ import torch -def split_torch_state_dict_into_shards( - state_dict: Dict[str, "torch.Tensor"], +def save_torch_model( + model: "torch.nn.Module", + save_directory: Union[str, Path], *, - filename_pattern: str = constants.SAFETENSORS_WEIGHTS_FILE_PATTERN, + filename_pattern: Optional[str] = None, + force_contiguous: bool = True, max_shard_size: Union[int, str] = MAX_SHARD_SIZE, -) -> StateDictSplit: + metadata: Optional[Dict[str, str]] = None, + safe_serialization: bool = True, +): """ - Split a model state dictionary in shards so that each shard is smaller than a given size. - - The shards are determined by iterating through the `state_dict` in the order of its keys. There is no optimization - made to make each shard as close as possible to the maximum size passed. For example, if the limit is 10GB and we - have tensors of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB], [6+2+2GB] and not - [6+2+2GB], [6+2GB], [6GB]. + Saves a given torch model to disk, handling sharding and shared tensors issues. + See also [`save_torch_state_dict`] to save a state dict with more flexibility. - + For more information about tensor sharing, check out [this guide](https://huggingface.co/docs/safetensors/torch_shared_tensors). - To save a model state dictionary to the disk, see [`save_torch_state_dict`]. 
This helper uses - `split_torch_state_dict_into_shards` under the hood. + The model state dictionary is split into shards so that each shard is smaller than a given size. The shards are + saved in the `save_directory` with the given `filename_pattern`. If the model is too big to fit in a single shard, + an index file is saved in the `save_directory` to indicate where each tensor is saved. This helper uses + [`split_torch_state_dict_into_shards`] under the hood. If `safe_serialization` is `True`, the shards are saved as + safetensors (the default). Otherwise, the shards are saved as pickle. - + Before saving the model, the `save_directory` is cleaned from any previous shard files. @@ -61,49 +65,53 @@ def split_torch_state_dict_into_shards( Args: - state_dict (`Dict[str, torch.Tensor]`): - The state dictionary to save. + model (`torch.nn.Module`): + The model to save on disk. + save_directory (`str` or `Path`): + The directory in which the model will be saved. filename_pattern (`str`, *optional*): The pattern to generate the files names in which the model will be saved. Pattern must be a string that can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix` - Defaults to `"model{suffix}.safetensors"`. + Defaults to `"model{suffix}.safetensors"` or `pytorch_model{suffix}.bin` depending on `safe_serialization` + parameter. + force_contiguous (`boolean`, *optional*): + Forcing the state_dict to be saved as contiguous tensors. This has no effect on the correctness of the + model, but it could potentially change performance if the layout of the tensor was chosen specifically for + that reason. Defaults to `True`. max_shard_size (`int` or `str`, *optional*): The maximum size of each shard, in bytes. Defaults to 5GB. - - Returns: - [`StateDictSplit`]: A `StateDictSplit` object containing the shards and the index to retrieve them. + metadata (`Dict[str, str]`, *optional*): + Extra information to save along with the model. Some metadata will be added for each dropped tensors. + This information will not be enough to recover the entire shared structure but might help understanding + things. + safe_serialization (`bool`, *optional*): + Whether to save as safetensors, which is the default behavior. If `False`, the shards are saved as pickle. + Safe serialization is recommended for security reasons. Saving as pickle is deprecated and will be removed + in a future version. Example: + ```py - >>> import json - >>> import os - >>> from safetensors.torch import save_file as safe_save_file - >>> from huggingface_hub import split_torch_state_dict_into_shards + >>> from huggingface_hub import save_torch_model + >>> model = ... # A PyTorch model - >>> def save_state_dict(state_dict: Dict[str, torch.Tensor], save_directory: str): - ... state_dict_split = split_torch_state_dict_into_shards(state_dict) - ... for filename, tensors in state_dict_split.filename_to_tensors.items(): - ... shard = {tensor: state_dict[tensor] for tensor in tensors} - ... safe_save_file( - ... shard, - ... os.path.join(save_directory, filename), - ... metadata={"format": "pt"}, - ... ) - ... if state_dict_split.is_sharded: - ... index = { - ... "metadata": state_dict_split.metadata, - ... "weight_map": state_dict_split.tensor_to_filename, - ... } - ... with open(os.path.join(save_directory, "model.safetensors.index.json"), "w") as f: - ... f.write(json.dumps(index, indent=2)) + # Save state dict to "path/to/folder". The model will be split into shards of 5GB each and saved as safetensors. 
+ >>> save_torch_model(model, "path/to/folder") + + # Load model back + >>> from huggingface_hub import load_torch_model # TODO + >>> load_torch_model(model, "path/to/folder") + >>> ``` """ - return split_state_dict_into_shards_factory( - state_dict, - max_shard_size=max_shard_size, + save_torch_state_dict( + state_dict=model.state_dict(), filename_pattern=filename_pattern, - get_tensor_size=get_tensor_size, - get_storage_id=get_torch_storage_id, + force_contiguous=force_contiguous, + max_shard_size=max_shard_size, + metadata=metadata, + safe_serialization=safe_serialization, + save_directory=save_directory, ) @@ -111,12 +119,18 @@ def save_torch_state_dict( state_dict: Dict[str, "torch.Tensor"], save_directory: Union[str, Path], *, - safe_serialization: bool = True, filename_pattern: Optional[str] = None, + force_contiguous: bool = True, max_shard_size: Union[int, str] = MAX_SHARD_SIZE, + metadata: Optional[Dict[str, str]] = None, + safe_serialization: bool = True, ) -> None: """ - Save a model state dictionary to the disk. + Save a model state dictionary to the disk, handling sharding and shared tensors issues. + + See also [`save_torch_model`] to directly save a PyTorch model. + + For more information about tensor sharing, check out [this guide](https://huggingface.co/docs/safetensors/torch_shared_tensors). The model state dictionary is split into shards so that each shard is smaller than a given size. The shards are saved in the `save_directory` with the given `filename_pattern`. If the model is too big to fit in a single shard, @@ -138,17 +152,25 @@ def save_torch_state_dict( The state dictionary to save. save_directory (`str` or `Path`): The directory in which the model will be saved. - safe_serialization (`bool`, *optional*): - Whether to save as safetensors, which is the default behavior. If `False`, the shards are saved as pickle. - Safe serialization is recommended for security reasons. Saving as pickle is deprecated and will be removed - in a future version. filename_pattern (`str`, *optional*): The pattern to generate the files names in which the model will be saved. Pattern must be a string that can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix` Defaults to `"model{suffix}.safetensors"` or `pytorch_model{suffix}.bin` depending on `safe_serialization` parameter. + force_contiguous (`boolean`, *optional*): + Forcing the state_dict to be saved as contiguous tensors. This has no effect on the correctness of the + model, but it could potentially change performance if the layout of the tensor was chosen specifically for + that reason. Defaults to `True`. max_shard_size (`int` or `str`, *optional*): The maximum size of each shard, in bytes. Defaults to 5GB. + metadata (`Dict[str, str]`, *optional*): + Extra information to save along with the model. Some metadata will be added for each dropped tensors. + This information will not be enough to recover the entire shared structure but might help understanding + things. + safe_serialization (`bool`, *optional*): + Whether to save as safetensors, which is the default behavior. If `False`, the shards are saved as pickle. + Safe serialization is recommended for security reasons. Saving as pickle is deprecated and will be removed + in a future version. Example: @@ -189,6 +211,12 @@ def save_torch_state_dict( "using safe serialization by installing `safetensors` with `pip install safetensors`." 
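As a side note for reviewers, a hedged sketch of how the `save_torch_state_dict` parameters documented in this hunk might be combined in a single call (the model, folder, and metadata values are illustrative, not taken from the patch):

```py
import torch
from huggingface_hub import save_torch_state_dict

model = torch.nn.Linear(8, 4)  # placeholder model for illustration

save_torch_state_dict(
    state_dict=model.state_dict(),
    save_directory="path/to/folder",
    filename_pattern="model{suffix}.safetensors",  # must contain the `suffix` keyword
    force_contiguous=True,       # make tensors contiguous before saving
    max_shard_size="2GB",        # above this size the weights are sharded and an index file is written
    metadata={"author": "me"},   # illustrative extra metadata stored alongside the weights
    safe_serialization=True,     # safetensors output (the default); False falls back to pickle
)
```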
) + # Clean state dict for safetensors + if metadata is None: + metadata = {} + if safe_serialization: + state_dict = _clean_state_dict_for_safetensors(state_dict, metadata, force_contiguous=force_contiguous) + # Split dict state_dict_split = split_torch_state_dict_into_shards( state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size @@ -205,7 +233,10 @@ def save_torch_state_dict( logger.warning(f"Error when trying to remove existing '{filename}' from folder: {e}. Continuing...") # Save each shard - safe_file_kwargs = {"metadata": {"format": "pt"}} if safe_serialization else {} + per_file_metadata = {"format": "pt"} + if not state_dict_split.is_sharded: + per_file_metadata.update(metadata) + safe_file_kwargs = {"metadata": per_file_metadata} if safe_serialization else {} for filename, tensors in state_dict_split.filename_to_tensors.items(): shard = {tensor: state_dict[tensor] for tensor in tensors} save_file_fn(shard, os.path.join(save_directory, filename), **safe_file_kwargs) @@ -214,7 +245,10 @@ def save_torch_state_dict( # Save the index (if any) if state_dict_split.is_sharded: index_path = filename_pattern.format(suffix="") + ".index.json" - index = {"metadata": state_dict_split.metadata, "weight_map": state_dict_split.tensor_to_filename} + index = { + "metadata": {**state_dict_split.metadata, **metadata}, + "weight_map": state_dict_split.tensor_to_filename, + } with open(os.path.join(save_directory, index_path), "w") as f: json.dump(index, f, indent=2) logger.info( @@ -226,6 +260,82 @@ def save_torch_state_dict( logger.info(f"Model weights successfully saved to {save_directory}!") +def split_torch_state_dict_into_shards( + state_dict: Dict[str, "torch.Tensor"], + *, + filename_pattern: str = constants.SAFETENSORS_WEIGHTS_FILE_PATTERN, + max_shard_size: Union[int, str] = MAX_SHARD_SIZE, +) -> StateDictSplit: + """ + Split a model state dictionary in shards so that each shard is smaller than a given size. + + The shards are determined by iterating through the `state_dict` in the order of its keys. There is no optimization + made to make each shard as close as possible to the maximum size passed. For example, if the limit is 10GB and we + have tensors of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB], [6+2+2GB] and not + [6+2+2GB], [6+2GB], [6GB]. + + + + + To save a model state dictionary to the disk, see [`save_torch_state_dict`]. This helper uses + `split_torch_state_dict_into_shards` under the hood. + + + + + + If one of the model's tensor is bigger than `max_shard_size`, it will end up in its own shard which will have a + size greater than `max_shard_size`. + + + + Args: + state_dict (`Dict[str, torch.Tensor]`): + The state dictionary to save. + filename_pattern (`str`, *optional*): + The pattern to generate the files names in which the model will be saved. Pattern must be a string that + can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix` + Defaults to `"model{suffix}.safetensors"`. + max_shard_size (`int` or `str`, *optional*): + The maximum size of each shard, in bytes. Defaults to 5GB. + + Returns: + [`StateDictSplit`]: A `StateDictSplit` object containing the shards and the index to retrieve them. + + Example: + ```py + >>> import json + >>> import os + >>> from safetensors.torch import save_file as safe_save_file + >>> from huggingface_hub import split_torch_state_dict_into_shards + + >>> def save_state_dict(state_dict: Dict[str, torch.Tensor], save_directory: str): + ... 
state_dict_split = split_torch_state_dict_into_shards(state_dict) + ... for filename, tensors in state_dict_split.filename_to_tensors.items(): + ... shard = {tensor: state_dict[tensor] for tensor in tensors} + ... safe_save_file( + ... shard, + ... os.path.join(save_directory, filename), + ... metadata={"format": "pt"}, + ... ) + ... if state_dict_split.is_sharded: + ... index = { + ... "metadata": state_dict_split.metadata, + ... "weight_map": state_dict_split.tensor_to_filename, + ... } + ... with open(os.path.join(save_directory, "model.safetensors.index.json"), "w") as f: + ... f.write(json.dumps(index, indent=2)) + ``` + """ + return split_state_dict_into_shards_factory( + state_dict, + max_shard_size=max_shard_size, + filename_pattern=filename_pattern, + get_storage_size=get_torch_storage_size, + get_storage_id=get_torch_storage_id, + ) + + def get_torch_storage_id(tensor: "torch.Tensor") -> Tuple["torch.device", int, int]: """ Return unique identifier to a tensor storage. @@ -248,11 +358,23 @@ def get_torch_storage_id(tensor: "torch.Tensor") -> Tuple["torch.device", int, i else: unique_id = storage_ptr(tensor) - return tensor.device, unique_id, get_storage_size(tensor) + return tensor.device, unique_id, get_torch_storage_size(tensor) -def get_tensor_size(tensor: "torch.Tensor") -> int: - return tensor.numel() * tensor.element_size() +def get_torch_storage_size(tensor: "torch.Tensor") -> int: + """ + Taken from https://github.com/huggingface/safetensors/blob/08db34094e9e59e2f9218f2df133b7b4aaff5a99/bindings/python/py_src/safetensors/torch.py#L31C1-L41C59 + """ + try: + return tensor.untyped_storage().nbytes() + except AttributeError: + # Fallback for torch==1.10 + try: + return tensor.storage().size() * _get_dtype_size(tensor.dtype) + except NotImplementedError: + # Fallback for meta storage + # On torch >=2.0 this is the tensor size + return tensor.nelement() * _get_dtype_size(tensor.dtype) @lru_cache() @@ -278,7 +400,7 @@ def is_torch_tpu_available(check_device=True): def storage_ptr(tensor: "torch.Tensor") -> int: """ - Taken from https://github.com/huggingface/safetensors/blob/08db34094e9e59e2f9218f2df133b7b4aaff5a99/bindings/python/py_src/safetensors/torch.py#L11C1-L20C21. + Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L11. """ try: return tensor.untyped_storage().data_ptr() @@ -291,20 +413,141 @@ def storage_ptr(tensor: "torch.Tensor") -> int: return 0 -def get_storage_size(tensor: "torch.Tensor") -> int: +def _clean_state_dict_for_safetensors( + state_dict: Dict[str, "torch.Tensor"], metadata: Dict[str, str], force_contiguous: bool = True +): + """Remove shared tensors from state_dict and update metadata accordingly (for reloading). + + Warning: `state_dict` and `metadata` are mutated in-place! + + Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L155. 
""" - Taken from https://github.com/huggingface/safetensors/blob/08db34094e9e59e2f9218f2df133b7b4aaff5a99/bindings/python/py_src/safetensors/torch.py#L31C1-L41C59 + to_removes = _remove_duplicate_names(state_dict) + for kept_name, to_remove_group in to_removes.items(): + for to_remove in to_remove_group: + if metadata is None: + metadata = {} + + if to_remove not in metadata: + # Do not override user data + metadata[to_remove] = kept_name + del state_dict[to_remove] + if force_contiguous: + state_dict = {k: v.contiguous() for k, v in state_dict.items()} + return state_dict + + +def _end_ptr(tensor: "torch.Tensor") -> int: """ - try: - return tensor.untyped_storage().nbytes() - except AttributeError: - # Fallback for torch==1.10 - try: - return tensor.storage().size() * _get_dtype_size(tensor.dtype) - except NotImplementedError: - # Fallback for meta storage - # On torch >=2.0 this is the tensor size - return tensor.nelement() * _get_dtype_size(tensor.dtype) + Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L23. + """ + if tensor.nelement(): + stop = tensor.view(-1)[-1].data_ptr() + _get_dtype_size(tensor.dtype) + else: + stop = tensor.data_ptr() + return stop + + +def _filter_shared_not_shared(tensors: List[Set[str]], state_dict: Dict[str, "torch.Tensor"]) -> List[Set[str]]: + """ + Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L44 + """ + filtered_tensors = [] + for shared in tensors: + if len(shared) < 2: + filtered_tensors.append(shared) + continue + + areas = [] + for name in shared: + tensor = state_dict[name] + areas.append((tensor.data_ptr(), _end_ptr(tensor), name)) + areas.sort() + + _, last_stop, last_name = areas[0] + filtered_tensors.append({last_name}) + for start, stop, name in areas[1:]: + if start >= last_stop: + filtered_tensors.append({name}) + else: + filtered_tensors[-1].add(name) + last_stop = stop + + return filtered_tensors + + +def _find_shared_tensors(state_dict: Dict[str, "torch.Tensor"]) -> List[Set[str]]: + """ + Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L69. + """ + import torch + + tensors_dict = defaultdict(set) + for k, v in state_dict.items(): + if v.device != torch.device("meta") and storage_ptr(v) != 0 and get_torch_storage_size(v) != 0: + # Need to add device as key because of multiple GPU. 
+ tensors_dict[(v.device, storage_ptr(v), get_torch_storage_size(v))].add(k) + tensors = list(sorted(tensors_dict.values())) + tensors = _filter_shared_not_shared(tensors, state_dict) + return tensors + + +def _is_complete(tensor: "torch.Tensor") -> bool: + """ + Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L80 + """ + return tensor.data_ptr() == storage_ptr(tensor) and tensor.nelement() * _get_dtype_size( + tensor.dtype + ) == get_torch_storage_size(tensor) + + +def _remove_duplicate_names( + state_dict: Dict[str, "torch.Tensor"], + *, + preferred_names: Optional[List[str]] = None, + discard_names: Optional[List[str]] = None, +) -> Dict[str, List[str]]: + """ + Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L80 + """ + if preferred_names is None: + preferred_names = [] + unique_preferred_names = set(preferred_names) + if discard_names is None: + discard_names = [] + unique_discard_names = set(discard_names) + + shareds = _find_shared_tensors(state_dict) + to_remove = defaultdict(list) + for shared in shareds: + complete_names = set([name for name in shared if _is_complete(state_dict[name])]) + if not complete_names: + raise RuntimeError( + "Error while trying to find names to remove to save state dict, but found no suitable name to keep" + f" for saving amongst: {shared}. None is covering the entire storage. Refusing to save/load the model" + " since you could be storing much more memory than needed. Please refer to" + " https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an" + " issue." + ) + + keep_name = sorted(list(complete_names))[0] + + # Mechanism to preferentially select keys to keep + # coming from the on-disk file to allow + # loading models saved with a different choice + # of keep_name + preferred = complete_names.difference(unique_discard_names) + if preferred: + keep_name = sorted(list(preferred))[0] + + if unique_preferred_names: + preferred = unique_preferred_names.intersection(complete_names) + if preferred: + keep_name = sorted(list(preferred))[0] + for name in sorted(shared): + if name != keep_name: + to_remove[keep_name].append(name) + return to_remove @lru_cache() diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 9af9256389..d954fce99b 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -1,13 +1,20 @@ import json +import struct from pathlib import Path from typing import TYPE_CHECKING, Dict, List +from unittest.mock import Mock import pytest - -from huggingface_hub.serialization import save_torch_state_dict, split_state_dict_into_shards_factory +from pytest_mock import MockerFixture + +from huggingface_hub.serialization import ( + get_tf_storage_size, + get_torch_storage_size, + save_torch_model, + save_torch_state_dict, + split_state_dict_into_shards_factory, +) from huggingface_hub.serialization._base import parse_size_to_int -from huggingface_hub.serialization._tensorflow import get_tensor_size as get_tensor_size_tensorflow -from huggingface_hub.serialization._torch import get_tensor_size as get_tensor_size_torch from .testing_utils import requires @@ -20,7 +27,7 @@ def _dummy_get_storage_id(item): return None -def _dummy_get_tensor_size(item): +def _dummy_get_storage_size(item): return sum(item) @@ -51,11 +58,28 @@ def torch_state_dict() -> Dict[str, "torch.Tensor"]: 
pytest.skip("torch is not available") +@pytest.fixture +def torch_state_dict_shared_layers() -> Dict[str, "torch.Tensor"]: + try: + import torch + + shared_layer = torch.tensor([4]) + + return { + "shared_1": shared_layer, + "unique_1": torch.tensor([10]), + "unique_2": torch.tensor([30]), + "shared_2": shared_layer, + } + except ImportError: + pytest.skip("torch is not available") + + def test_single_shard(dummy_state_dict): state_dict_split = split_state_dict_into_shards_factory( dummy_state_dict, get_storage_id=_dummy_get_storage_id, - get_tensor_size=_dummy_get_tensor_size, + get_storage_size=_dummy_get_storage_size, max_shard_size=100, # large shard size => only one shard filename_pattern="file{suffix}.dummy", ) @@ -78,7 +102,7 @@ def test_multiple_shards(dummy_state_dict): state_dict_split = split_state_dict_into_shards_factory( dummy_state_dict, get_storage_id=_dummy_get_storage_id, - get_tensor_size=_dummy_get_tensor_size, + get_storage_size=_dummy_get_storage_size, max_shard_size=10, # small shard size => multiple shards filename_pattern="file{suffix}.dummy", ) @@ -111,7 +135,7 @@ def test_tensor_same_storage(): "layer_5": [1], }, get_storage_id=lambda x: (x[0]), # dummy for test: storage id based on first element - get_tensor_size=_dummy_get_tensor_size, + get_storage_size=_dummy_get_storage_size, max_shard_size=1, filename_pattern="model{suffix}.safetensors", ) @@ -131,19 +155,19 @@ def test_tensor_same_storage(): @requires("tensorflow") -def test_get_tensor_size_tensorflow(): +def test_get_tf_storage_size(): import tensorflow as tf - assert get_tensor_size_tensorflow(tf.constant([1, 2, 3, 4, 5], dtype=tf.float64)) == 5 * 8 - assert get_tensor_size_tensorflow(tf.constant([1, 2, 3, 4, 5], dtype=tf.float16)) == 5 * 2 + assert get_tf_storage_size(tf.constant([1, 2, 3, 4, 5], dtype=tf.float64)) == 5 * 8 + assert get_tf_storage_size(tf.constant([1, 2, 3, 4, 5], dtype=tf.float16)) == 5 * 2 @requires("torch") -def test_get_tensor_size_torch(): +def test_get_torch_storage_size(): import torch - assert get_tensor_size_torch(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float64)) == 5 * 8 - assert get_tensor_size_torch(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float16)) == 5 * 2 + assert get_torch_storage_size(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float64)) == 5 * 8 + assert get_torch_storage_size(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float16)) == 5 * 2 def test_parse_size_to_int(): @@ -160,6 +184,30 @@ def test_parse_size_to_int(): parse_size_to_int("1ooKB") # not a float +def test_save_torch_model(mocker: MockerFixture, tmp_path: Path) -> None: + """Test `save_torch_model` is only a wrapper around `save_torch_state_dict`.""" + model_mock = Mock() + safe_state_dict_mock = mocker.patch("huggingface_hub.serialization._torch.save_torch_state_dict") + save_torch_model( + model_mock, + save_directory=tmp_path, + filename_pattern="my-pattern", + force_contiguous=True, + max_shard_size="3GB", + metadata={"foo": "bar"}, + safe_serialization=True, + ) + safe_state_dict_mock.assert_called_once_with( + state_dict=model_mock.state_dict.return_value, + save_directory=tmp_path, + filename_pattern="my-pattern", + force_contiguous=True, + max_shard_size="3GB", + metadata={"foo": "bar"}, + safe_serialization=True, + ) + + def test_save_torch_state_dict_not_sharded(tmp_path: Path, torch_state_dict: Dict[str, "torch.Tensor"]) -> None: """Save as safetensors without sharding.""" save_torch_state_dict(torch_state_dict, tmp_path, max_shard_size="1GB") @@ -225,6 +273,47 @@ def 
test_save_torch_state_dict_unsafe_sharded(
     }
 
 
+def test_save_torch_state_dict_shared_layers_not_sharded(
+    tmp_path: Path, torch_state_dict_shared_layers: Dict[str, "torch.Tensor"]
+) -> None:
+    from safetensors.torch import load_file
+
+    save_torch_state_dict(torch_state_dict_shared_layers, tmp_path, safe_serialization=True)
+    safetensors_file = tmp_path / "model.safetensors"
+    assert safetensors_file.is_file()
+
+    # Check shared layer not duplicated in file
+    state_dict = load_file(safetensors_file)
+    assert "shared_1" in state_dict
+    assert "shared_2" not in state_dict
+
+    # Check shared layer info in metadata
+    file_bytes = safetensors_file.read_bytes()
+    metadata_str = file_bytes[
+        8 : struct.unpack("<Q", file_bytes[:8])[0] + 8
+    ].decode("utf-8")
+    assert json.loads(metadata_str)["__metadata__"]["shared_2"] == "shared_1"
+
+
+def test_save_torch_state_dict_shared_layers_sharded(
+    tmp_path: Path, torch_state_dict_shared_layers: Dict[str, "torch.Tensor"]
+) -> None:
+    from safetensors.torch import load_file
+
+    save_torch_state_dict(torch_state_dict_shared_layers, tmp_path, max_shard_size=2, safe_serialization=True)
+    index_file = tmp_path / "model.safetensors.index.json"
+    assert index_file.is_file()
+
+    # Check shared layer info in index metadata
+    index = json.loads(index_file.read_text())
+    assert index["metadata"]["shared_2"] == "shared_1"
+
+    # Check shared layer not duplicated in files
+    for filename in index["weight_map"].values():
+        state_dict = load_file(tmp_path / filename)
+        assert "shared_2" not in state_dict
+
+
 def test_save_torch_state_dict_custom_filename(tmp_path: Path, torch_state_dict: Dict[str, "torch.Tensor"]) -> None:
     """Custom filename pattern is respected."""
     # Not sharded
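As an aside (not part of the diff): a hypothetical sketch of how the shared-tensor metadata exercised by the tests above could be used to re-tie weights when reloading a non-sharded checkpoint. The helper name `retie_shared_tensors` and the path are made up for illustration; `safe_open`, `load_file`, and the header `metadata()` accessor are standard safetensors APIs.

```py
from safetensors import safe_open
from safetensors.torch import load_file


def retie_shared_tensors(path: str):
    """Reload a saved state dict and re-alias tensors dropped as duplicates at save time."""
    state_dict = load_file(path)
    with safe_open(path, framework="pt") as f:
        metadata = f.metadata() or {}
    # The saving helper records each dropped tensor name -> the kept name; other keys
    # (e.g. "format") are skipped by the membership check below.
    for dropped, kept in metadata.items():
        if kept in state_dict and dropped not in state_dict:
            state_dict[dropped] = state_dict[kept]
    return state_dict


state_dict = retie_shared_tensors("path/to/folder/model.safetensors")
```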