Skip to content

Integration of from_pretrained and from_single_file #10208

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 52 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
d8dc461
update
suzukimain Dec 13, 2024
b239d4f
update
suzukimain Dec 13, 2024
abf20e1
update
suzukimain Dec 13, 2024
67b1f40
update
suzukimain Dec 13, 2024
4132129
update
suzukimain Dec 13, 2024
ef77efb
Added pipeline mappings for loading a single file checkpoint
suzukimain Dec 13, 2024
349726f
update
suzukimain Dec 13, 2024
5e9d485
fix
suzukimain Dec 13, 2024
2475b86
update
suzukimain Dec 13, 2024
e791d62
Merge branch 'huggingface:main' into load_Method_Integration
suzukimain Dec 13, 2024
4788002
fix
suzukimain Dec 13, 2024
fdf5020
fix
suzukimain Dec 13, 2024
b9449df
fix
suzukimain Dec 13, 2024
8c25f96
Fixed duplicate checkpoint loading
suzukimain Dec 13, 2024
cb525df
Merge branch 'main' into loading_method_integration
suzukimain Dec 13, 2024
f545353
Merge branch 'main' into loading_method_integration
suzukimain Dec 13, 2024
ef96367
Merge branch 'main' into loading_method_integration
suzukimain Dec 15, 2024
891b11c
Code formatting
suzukimain Dec 15, 2024
e072f90
Merge branch 'main' of https://github.com/huggingface/diffusers into …
suzukimain Dec 17, 2024
7b22b18
fix
suzukimain Dec 17, 2024
eb09cf0
Add keys and sort alphabetically
suzukimain Dec 17, 2024
b9b62ea
Merge branch 'main' into loading_method_integration
suzukimain Dec 18, 2024
ca764d4
Merge branch 'main' into loading_method_integration
suzukimain Dec 19, 2024
0874ab6
Merge branch 'main' of https://github.com/huggingface/diffusers into …
suzukimain Dec 23, 2024
6488693
Added `flux-depth`, `flux-fill`, and `mochi-1-preview` to the pipelin…
suzukimain Dec 23, 2024
970a7ff
Merge branch 'main' of https://github.com/huggingface/diffusers into …
suzukimain Jan 3, 2025
75e510d
update
suzukimain Jan 3, 2025
118ce7f
Merge branch 'main' into loading_method_integration
suzukimain Jan 8, 2025
0e41022
update
suzukimain Jan 10, 2025
384bd7f
update
suzukimain Jan 10, 2025
56403a6
update
suzukimain Jan 10, 2025
e9a48fb
update
suzukimain Jan 10, 2025
0ec85e0
update
suzukimain Jan 10, 2025
41f960f
update
suzukimain Jan 10, 2025
8c10847
update
suzukimain Jan 10, 2025
6a79268
fix
suzukimain Jan 10, 2025
8e845c2
make quality
suzukimain Jan 10, 2025
6b252aa
update
suzukimain Jan 10, 2025
d10e3f1
update
suzukimain Jan 10, 2025
ceae059
update
suzukimain Jan 10, 2025
18fce79
update
suzukimain Jan 10, 2025
a5604da
update
suzukimain Jan 10, 2025
69ba240
update
suzukimain Jan 10, 2025
d82b9a9
update
suzukimain Jan 10, 2025
b0f18ae
make quality
suzukimain Jan 10, 2025
8e58877
Merge https://github.com/huggingface/diffusers into loading_method_in…
suzukimain Jan 10, 2025
32ce3a6
Merge branch 'main' of https://github.com/huggingface/diffusers into …
suzukimain Jan 10, 2025
bf0f979
Integrate from_pretrained with from_single_file functionality
suzukimain Jan 10, 2025
0a7aa0e
Merge branch 'main' of https://github.com/huggingface/diffusers into …
suzukimain Jan 10, 2025
6b41c16
Merge branch 'main' of https://github.com/huggingface/diffusers into …
suzukimain Mar 1, 2025
b480850
Update pipeline mapping
suzukimain Mar 1, 2025
2f7a245
Fix and make quality
suzukimain Mar 1, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 21 additions & 12 deletions src/diffusers/loaders/single_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,8 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs) -> Self:
disable_mmap ('bool', *optional*, defaults to 'False'):
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
is on a network mount or hard drive.
checkpoint (`dict`, *optional*):
The loaded state dictionary of the model.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
class). The overwritten components are passed directly to the pipelines `__init__` method. See example
Expand Down Expand Up @@ -362,6 +364,7 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs) -> Self:
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", torch.float32)
disable_mmap = kwargs.pop("disable_mmap", False)
checkpoint = kwargs.pop("checkpoint", None)

is_legacy_loading = False

Expand All @@ -386,18 +389,19 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs) -> Self:

from ..pipelines.pipeline_utils import _get_pipeline_class

pipeline_class = _get_pipeline_class(cls, config=None)

checkpoint = load_single_file_checkpoint(
pretrained_model_link_or_path,
force_download=force_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
disable_mmap=disable_mmap,
)
pipeline_class = _get_pipeline_class(cls, class_name=cls.__name__, config=None)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you don't pass the class_name when loading a single file checkpoint in DiffusionPipeline, it will result in a TypeError.


if checkpoint is None:
checkpoint = load_single_file_checkpoint(
pretrained_model_link_or_path,
force_download=force_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
disable_mmap=disable_mmap,
)

if config is None:
config = fetch_diffusers_config(checkpoint)
Expand Down Expand Up @@ -480,6 +484,11 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs) -> Self:
init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
init_kwargs = {**init_kwargs, **passed_pipe_kwargs}

if len(unused_kwargs) > 0:
logger.warning(
f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored."
)

from diffusers import pipelines

# remove `null` components
Expand Down
80 changes: 80 additions & 0 deletions src/diffusers/loaders/single_file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2857,3 +2857,83 @@ def convert_lumina_attn_to_diffusers(tensor, diffusers_key):
converted_state_dict[diffusers_key] = checkpoint.pop(key)

return converted_state_dict


def get_keyword_types(keyword):
r"""
Determine the type and loading method for a given keyword.

Parameters:
keyword (`str`):
The input keyword to classify.

Returns:
`dict`: A dictionary containing the model format, loading method,
and various types and extra types flags.
"""

# Initialize the status dictionary with default values
status = {
"checkpoint_format": None,
"loading_method": None,
"type": {
"other": False,
"hf_url": False,
"hf_repo": False,
"civitai_url": False,
"local": False,
},
"extra_type": {
"url": False,
"missing_model_index": None,
},
}

# Check if the keyword is an HTTP or HTTPS URL
status["extra_type"]["url"] = bool(re.search(r"^(https?)://", keyword))

# Check if the keyword is a file
if os.path.isfile(keyword):
status["type"]["local"] = True
status["checkpoint_format"] = "single_file"
status["loading_method"] = "from_single_file"

# Check if the keyword is a directory
elif os.path.isdir(keyword):
status["type"]["local"] = True
status["checkpoint_format"] = "diffusers"
status["loading_method"] = "from_pretrained"
if not os.path.exists(os.path.join(keyword, "model_index.json")):
status["extra_type"]["missing_model_index"] = True

# Check if the keyword is a Civitai URL
elif keyword.startswith("https://civitai.com/"):
status["type"]["civitai_url"] = True
status["checkpoint_format"] = "single_file"
status["loading_method"] = None

# Check if the keyword starts with any valid URL prefixes
elif any(keyword.startswith(prefix) for prefix in VALID_URL_PREFIXES):
repo_id, weights_name = _extract_repo_id_and_weights_name(keyword)
if weights_name:
status["type"]["hf_url"] = True
status["checkpoint_format"] = "single_file"
status["loading_method"] = "from_single_file"
else:
status["type"]["hf_repo"] = True
status["checkpoint_format"] = "diffusers"
status["loading_method"] = "from_pretrained"

# Check if the keyword matches a Hugging Face repository format
elif re.match(r"^[^/]+/[^/]+$", keyword):
status["type"]["hf_repo"] = True
status["checkpoint_format"] = "diffusers"
status["loading_method"] = "from_pretrained"

# If none of the above apply
else:
status["type"]["other"] = True
status["checkpoint_format"] = None
status["loading_method"] = None

return status
Loading