From 3a2d897bb49da571ab3b1f4b3610a980803cefdf Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 18 Sep 2024 01:45:18 +0530 Subject: [PATCH] set max_shard_size to None for pipeline save_pretrained (#9447) * update default max_shard_size * add None check to fix tests --------- Co-authored-by: YiYi Xu --- src/diffusers/pipelines/pipeline_utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index dffd49cb0ce72..ccd1c9485d0e4 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -189,7 +189,7 @@ def save_pretrained( save_directory: Union[str, os.PathLike], safe_serialization: bool = True, variant: Optional[str] = None, - max_shard_size: Union[int, str] = "10GB", + max_shard_size: Optional[Union[int, str]] = None, push_to_hub: bool = False, **kwargs, ): @@ -205,7 +205,7 @@ class implements both a save and loading method. The pipeline is easily reloaded Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. variant (`str`, *optional*): If specified, weights are saved in the format `pytorch_model..bin`. - max_shard_size (`int` or `str`, defaults to `"10GB"`): + max_shard_size (`int` or `str`, defaults to `None`): The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5GB"`). If expressed as an integer, the unit is bytes. Note that this limit will be decreased after a certain @@ -293,7 +293,8 @@ def is_saveable_module(name, value): save_kwargs["safe_serialization"] = safe_serialization if save_method_accept_variant: save_kwargs["variant"] = variant - if save_method_accept_max_shard_size: + if save_method_accept_max_shard_size and max_shard_size is not None: + # max_shard_size is expected to not be None in ModelMixin save_kwargs["max_shard_size"] = max_shard_size save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)