Skip to content

[Core] Support single file from from_pretrained #6986

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 12 commits into from
Closed
1 change: 1 addition & 0 deletions src/diffusers/loaders/autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,4 +143,5 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
if torch_dtype is not None:
vae = vae.to(torch_dtype)

vae.eval()
return vae
1 change: 1 addition & 0 deletions src/diffusers/loaders/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,4 +133,5 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
if torch_dtype is not None:
controlnet = controlnet.to(torch_dtype)

controlnet.eval()
return controlnet
447 changes: 230 additions & 217 deletions src/diffusers/models/modeling_utils.py

Large diffs are not rendered by default.

587 changes: 314 additions & 273 deletions src/diffusers/pipelines/pipeline_utils.py

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion src/diffusers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from .. import __version__
from .constants import (
_ACCEPTED_SINGLE_FILE_FORMATS,
CONFIG_NAME,
DEPRECATED_REVISION_ARGS,
DIFFUSERS_DYNAMIC_MODULE_NAME,
Expand Down Expand Up @@ -83,7 +84,7 @@
is_xformers_available,
requires_backends,
)
from .loading_utils import load_image
from .loading_utils import is_single_file_checkpoint, load_image
from .logging import get_logger
from .outputs import BaseOutput
from .peft_utils import (
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(HF_HOME, "modules"))
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
_ACCEPTED_SINGLE_FILE_FORMATS = (".safetensors", ".ckpt", ".bin", ".pth", ".pt")

# Below should be `True` if the current version of `peft` and `transformers` are compatible with
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
Expand Down
18 changes: 18 additions & 0 deletions src/diffusers/utils/loading_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,28 @@
import os
from typing import Callable, Union
from urllib.parse import urlparse

import PIL.Image
import PIL.ImageOps
import requests

from ..utils.constants import _ACCEPTED_SINGLE_FILE_FORMATS


def is_single_file_checkpoint(filepath):
def is_valid_url(url):
result = urlparse(url)
if result.scheme and result.netloc:
return True

filepath = str(filepath)
if filepath.endswith(_ACCEPTED_SINGLE_FILE_FORMATS):
if is_valid_url(filepath):
return True
elif os.path.isfile(filepath):
return True
return False


def load_image(
image: Union[str, PIL.Image.Image], convert_method: Callable[[PIL.Image.Image], PIL.Image.Image] = None
Expand Down