Skip to content

Commit

Permalink
Handle shared layers in save_torch_state_dict + add `save_torch_mod…
Browse files Browse the repository at this point in the history
…el` (#2373)

* Handle shared layers in save_torch_state_dict + save_torch_model + some helpers

* fix pytest rerun

* more reruns
  • Loading branch information
Wauplin authored Jul 11, 2024
1 parent d624b46 commit dfe72d0
Show file tree
Hide file tree
Showing 9 changed files with 450 additions and 97 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/python-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ jobs:
working-directory: ./src # For code coverage to work
run: |
source ../.venv/bin/activate
PYTEST="python -m pytest --cov=./huggingface_hub --cov-report=xml:../coverage.xml --vcr-record=none --reruns 5 --reruns-delay 1 --only-rerun '(OSError|Timeout|HTTPError.*502|HTTPError.*504||not less than or equal to 0.01)'"
PYTEST="python -m pytest --cov=./huggingface_hub --cov-report=xml:../coverage.xml --vcr-record=none --reruns 8 --reruns-delay 2 --only-rerun '(OSError|Timeout|HTTPError.*502|HTTPError.*504||not less than or equal to 0.01)'"
case "${{ matrix.test_name }}" in
Expand Down Expand Up @@ -172,7 +172,7 @@ jobs:
working-directory: ./src # For code coverage to work
run: |
..\.venv\Scripts\activate
python -m pytest -n 4 --cov=./huggingface_hub --cov-report=xml:../coverage.xml --vcr-record=none --reruns 5 --reruns-delay 1 --only-rerun '(OSError|Timeout|HTTPError.*502|HTTPError.*504|not less than or equal to 0.01)' ../tests
python -m pytest -n 4 --cov=./huggingface_hub --cov-report=xml:../coverage.xml --vcr-record=none --reruns 8 --reruns-delay 2 --only-rerun '(OSError|Timeout|HTTPError.*502|HTTPError.*504|not less than or equal to 0.01)' ../tests
# Upload code coverage
- name: Upload coverage reports to Codecov with GitHub Action
Expand Down
12 changes: 10 additions & 2 deletions docs/source/en/package_reference/serialization.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@ rendered properly in your Markdown viewer.

## Save torch state dict

The main helper of the `serialization` module takes a state dictionary as input (e.g. a mapping between layer names and related tensors), splits it into several shards while creating a proper index in the process and save everything to disk. At the moment, only `torch` tensors are supported. Under the hood, it delegates the logic to split the state dictionary to [`split_torch_state_dict_into_shards`].
The main helper of the `serialization` module takes a torch `nn.Module` as input and saves it to disk. It handles the logic to save shared tensors (see [safetensors explanation](https://huggingface.co/docs/safetensors/torch_shared_tensors)) as well as logic to split the state dictionary into shards, using [`split_torch_state_dict_into_shards`] under the hood. At the moment, only `torch` framework is supported.

If you want to save a state dictionary (e.g. a mapping between layer names and related tensors) instead of a `nn.Module`, you can use [`save_torch_state_dict`] which provides the same features. This is useful for example if you want to apply custom logic to the state dict before saving it.

[[autodoc]] huggingface_hub.save_torch_model

[[autodoc]] huggingface_hub.save_torch_state_dict

Expand All @@ -34,4 +38,8 @@ This is the underlying factory from which each framework-specific helper is deri

### get_torch_storage_id

[[autodoc]] huggingface_hub.get_torch_storage_id
[[autodoc]] huggingface_hub.get_torch_storage_id

### get_torch_storage_size

[[autodoc]] huggingface_hub.get_torch_storage_size
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,14 @@ def get_version() -> str:
+ [
"jedi",
"Jinja2",
"pytest>=8.1.1",
"pytest>=8.1.1,<8.2.2", # at least until 8.2.3 is released with https://github.com/pytest-dev/pytest/pull/12436
"pytest-cov",
"pytest-env",
"pytest-xdist",
"pytest-vcr", # to mock Inference
"pytest-asyncio", # for AsyncInferenceClient
"pytest-rerunfailures", # to rerun flaky tests in CI
"pytest-mock",
"urllib3<2.0", # VCR.py broken with urllib3 2.0 (see https://urllib3.readthedocs.io/en/stable/v2-migration-guide.html)
"soundfile",
"Pillow",
Expand Down
6 changes: 6 additions & 0 deletions src/huggingface_hub/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,10 @@
],
"serialization": [
"StateDictSplit",
"get_tf_storage_size",
"get_torch_storage_id",
"get_torch_storage_size",
"save_torch_model",
"save_torch_state_dict",
"split_state_dict_into_shards_factory",
"split_tf_state_dict_into_shards",
Expand Down Expand Up @@ -911,7 +914,10 @@ def __dir__():
from .repository import Repository # noqa: F401
from .serialization import (
StateDictSplit, # noqa: F401
get_tf_storage_size, # noqa: F401
get_torch_storage_id, # noqa: F401
get_torch_storage_size, # noqa: F401
save_torch_model, # noqa: F401
save_torch_state_dict, # noqa: F401
split_state_dict_into_shards_factory, # noqa: F401
split_tf_state_dict_into_shards, # noqa: F401
Expand Down
10 changes: 8 additions & 2 deletions src/huggingface_hub/serialization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,11 @@
"""Contains helpers to serialize tensors."""

from ._base import StateDictSplit, split_state_dict_into_shards_factory
from ._tensorflow import split_tf_state_dict_into_shards
from ._torch import get_torch_storage_id, save_torch_state_dict, split_torch_state_dict_into_shards
from ._tensorflow import get_tf_storage_size, split_tf_state_dict_into_shards
from ._torch import (
get_torch_storage_id,
get_torch_storage_size,
save_torch_model,
save_torch_state_dict,
split_torch_state_dict_into_shards,
)
8 changes: 4 additions & 4 deletions src/huggingface_hub/serialization/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __post_init__(self):
def split_state_dict_into_shards_factory(
state_dict: Dict[str, TensorT],
*,
get_tensor_size: TensorSizeFn_T,
get_storage_size: TensorSizeFn_T,
filename_pattern: str,
get_storage_id: StorageIDFn_T = lambda tensor: None,
max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
Expand All @@ -72,8 +72,8 @@ def split_state_dict_into_shards_factory(
Args:
state_dict (`Dict[str, Tensor]`):
The state dictionary to save.
get_tensor_size (`Callable[[Tensor], int]`):
A function that returns the size of a tensor in bytes.
get_storage_size (`Callable[[Tensor], int]`):
A function that returns the size of a tensor when saved on disk in bytes.
get_storage_id (`Callable[[Tensor], Optional[Any]]`, *optional*):
A function that returns a unique identifier to a tensor storage. Multiple different tensors can share the
same underlying storage. This identifier is guaranteed to be unique and constant for this tensor's storage
Expand Down Expand Up @@ -117,7 +117,7 @@ def split_state_dict_into_shards_factory(
storage_id_to_tensors[storage_id] = [key]

# Compute tensor size
tensor_size = get_tensor_size(tensor)
tensor_size = get_storage_size(tensor)

# If this tensor is bigger than the maximal size, we put it in its own shard
if tensor_size > max_shard_size:
Expand Down
4 changes: 2 additions & 2 deletions src/huggingface_hub/serialization/_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,11 @@ def split_tf_state_dict_into_shards(
state_dict,
max_shard_size=max_shard_size,
filename_pattern=filename_pattern,
get_tensor_size=get_tensor_size,
get_storage_size=get_tf_storage_size,
)


def get_tensor_size(tensor: "tf.Tensor") -> int:
def get_tf_storage_size(tensor: "tf.Tensor") -> int:
# Return `math.ceil` since dtype byte size can be a float (e.g., 0.125 for tf.bool).
# Better to overestimate than underestimate.
return math.ceil(tensor.numpy().size * _dtype_byte_size_tf(tensor.dtype))
Expand Down
Loading

0 comments on commit dfe72d0

Please sign in to comment.