Improve single loading file #4041

Merged: 5 commits, Jul 11, 2023
src/diffusers/loaders.py (2 changes: 1 addition & 1 deletion)

@@ -1389,7 +1389,7 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        extract_ema = kwargs.pop("extract_ema", False)
-        image_size = kwargs.pop("image_size", 512)
+        image_size = kwargs.pop("image_size", None)
        scheduler_type = kwargs.pop("scheduler_type", "pndm")
        num_in_channels = kwargs.pop("num_in_channels", None)
        upcast_attention = kwargs.pop("upcast_attention", None)
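The switch from a hard-coded 512 to None is what lets download_from_original_stable_diffusion_ckpt pick a per-model default further down (1024 for SDXL checkpoints). A minimal sketch of that idea, not the PR's literal code:

    def infer_image_size(image_size, model_type):
        # An explicit user-supplied value always wins.
        if image_size is not None:
            return image_size
        # Otherwise fall back to a per-architecture default.
        return 1024 if model_type in ("SDXL", "SDXL-Refiner") else 512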
src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py (116 changes: 84 additions & 32 deletions)

@@ -24,6 +24,7 @@
    AutoFeatureExtractor,
    BertTokenizerFast,
    CLIPImageProcessor,
+    CLIPTextConfig,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
@@ -48,7 +49,7 @@
    PNDMScheduler,
    UnCLIPScheduler,
)
-from ...utils import is_omegaconf_available, is_safetensors_available, logging
+from ...utils import is_accelerate_available, is_omegaconf_available, is_safetensors_available, logging
from ...utils.import_utils import BACKENDS_MAPPING
from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
from ..paint_by_example import PaintByExampleImageEncoder
@@ -57,6 +58,10 @@
from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer


+if is_accelerate_available():
+    from accelerate import init_empty_weights
+    from accelerate.utils import set_module_tensor_to_device
+
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


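These two helpers carry most of this PR. init_empty_weights() constructs a module on PyTorch's meta device, so no memory is allocated and no random initialization runs; set_module_tensor_to_device then materializes each tensor directly from checkpoint values. A self-contained sketch of the pattern, using a stand-in nn.Linear rather than a real diffusers model:

    import torch
    from accelerate import init_empty_weights
    from accelerate.utils import set_module_tensor_to_device

    with init_empty_weights():
        model = torch.nn.Linear(4, 4)  # stand-in for a large model
    assert model.weight.is_meta  # no storage has been allocated yet

    state_dict = {"weight": torch.ones(4, 4), "bias": torch.zeros(4)}
    for name, tensor in state_dict.items():
        # Swap the meta tensor for a real one holding the checkpoint value.
        set_module_tensor_to_device(model, name, "cpu", value=tensor)

This skips the wasted work of initializing weights that are immediately overwritten, which is where the loading speedup comes from.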
@@ -770,11 +775,12 @@ def _copy_layers(hf_layers, pt_layers):


def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None):
-    text_model = (
-        CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)
-        if text_encoder is None
-        else text_encoder
-    )
+    if text_encoder is None:
+        config_name = "openai/clip-vit-large-patch14"
+        config = CLIPTextConfig.from_pretrained(config_name)
+
+        with init_empty_weights():
+            text_model = CLIPTextModel(config)

    keys = list(checkpoint.keys())

@@ -787,7 +793,8 @@ def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder
        if key.startswith(prefix):
            text_model_dict[key[len(prefix + ".") :]] = checkpoint[key]

-    text_model.load_state_dict(text_model_dict)
+    for param_name, param in text_model_dict.items():
+        set_module_tensor_to_device(text_model, param_name, "cpu", value=param)

    return text_model

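The per-tensor loop is not just style: a module created under init_empty_weights() holds meta tensors, and load_state_dict copies into existing tensors, which fails for meta tensors. A tiny illustration, same imports as above:

    import torch
    from accelerate import init_empty_weights
    from accelerate.utils import set_module_tensor_to_device

    with init_empty_weights():
        m = torch.nn.Linear(2, 2)

    sd = {"weight": torch.eye(2), "bias": torch.zeros(2)}
    # m.load_state_dict(sd)  # fails: cannot copy into meta tensors
    for name, tensor in sd.items():
        set_module_tensor_to_device(m, name, "cpu", value=tensor)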
@@ -884,14 +891,26 @@ def convert_paint_by_example_checkpoint(checkpoint):
    return model


-def convert_open_clip_checkpoint(checkpoint, prefix="cond_stage_model.model."):
+def convert_open_clip_checkpoint(
+    checkpoint, config_name, prefix="cond_stage_model.model.", has_projection=False, **config_kwargs
+):
    # text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
-    text_model = CLIPTextModelWithProjection.from_pretrained(
-        "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280
-    )
+    # text_model = CLIPTextModelWithProjection.from_pretrained(
+    #     "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280
+    # )
+    config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs)
+
+    with init_empty_weights():

[Inline review thread on `with init_empty_weights():`]

Reviewer: This seems to break downstream if you don't have accelerate, and we probably need an `if is_accelerate_available():` check.

Contributor Author: Great catch! I think we should actually just force the user here to install accelerate, since this method only exists for PyTorch anyway and there is no harm in installing accelerate for PyTorch.

Reviewer: Any chance we can avoid forcing an accelerate install? We ship nightly HF libraries (transformers, diffusers) and then export out via torch-mlir to SHARK, so we don't use PyTorch/Accelerate and would like to avoid adding the dependency to our shipping binaries if possible.

Contributor Author: Fixed in: #4132
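For readers following the thread: one way to honor that request is to branch on availability and keep the old eager path as a fallback. This is only a sketch of the direction, not the code merged here; the actual fix landed in #4132:

    if is_accelerate_available():
        with init_empty_weights():
            text_model = CLIPTextModelWithProjection(config) if has_projection else CLIPTextModel(config)
        # ...later, fill weights with set_module_tensor_to_device(...)
    else:
        text_model = CLIPTextModelWithProjection(config) if has_projection else CLIPTextModel(config)
        # ...later, fill weights with text_model.load_state_dict(text_model_dict)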

+        text_model = CLIPTextModelWithProjection(config) if has_projection else CLIPTextModel(config)

    keys = list(checkpoint.keys())

+    keys_to_ignore = []
+    if config_name == "stabilityai/stable-diffusion-2" and config.num_hidden_layers == 23:
+        # make sure to remove all keys > 22
+        keys_to_ignore += [k for k in keys if k.startswith("cond_stage_model.model.transformer.resblocks.23")]
+        keys_to_ignore += ["cond_stage_model.model.text_projection"]

    text_model_dict = {}

    if prefix + "text_projection" in checkpoint:
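Background on that filter, flagged as commentary rather than PR text: Stable Diffusion 2 conditions on the penultimate CLIP layer, so the Hub config declares one fewer hidden layer than the OpenCLIP checkpoint contains, and the checkpoint's final resblock plus its text projection have no destination in the diffusers model:

    from transformers import CLIPTextConfig

    config = CLIPTextConfig.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
    print(config.num_hidden_layers)  # 23, versus 24 resblocks in the OpenCLIP checkpoint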
@@ -902,8 +921,8 @@ def convert_open_clip_checkpoint(checkpoint, prefix="cond_stage_model.model."):
text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")

for key in keys:
# if "resblocks.23" in key: # Diffusers drops the final layer and only uses the penultimate layer
# continue
if key in keys_to_ignore:
continue
if key[len(prefix) :] in textenc_conversion_map:
if key.endswith("text_projection"):
value = checkpoint[key].T
@@ -931,7 +950,8 @@ def convert_open_clip_checkpoint(checkpoint, prefix="cond_stage_model.model."):

            text_model_dict[new_key] = checkpoint[key]

-    text_model.load_state_dict(text_model_dict)
+    for param_name, param in text_model_dict.items():
+        set_module_tensor_to_device(text_model, param_name, "cpu", value=param)

    return text_model

@@ -1061,7 +1081,7 @@ def convert_controlnet_checkpoint(
def download_from_original_stable_diffusion_ckpt(
    checkpoint_path: str,
    original_config_file: str = None,
-    image_size: int = 512,
+    image_size: Optional[int] = None,
    prediction_type: str = None,
    model_type: str = None,
    extract_ema: bool = False,
@@ -1144,6 +1164,7 @@ def download_from_original_stable_diffusion_ckpt(
        LDMTextToImagePipeline,
        PaintByExamplePipeline,
        StableDiffusionControlNetPipeline,
+        StableDiffusionInpaintPipeline,
        StableDiffusionPipeline,
        StableDiffusionXLImg2ImgPipeline,
        StableDiffusionXLPipeline,
@@ -1166,12 +1187,9 @@
        if not is_safetensors_available():
            raise ValueError(BACKENDS_MAPPING["safetensors"][1])

-        from safetensors import safe_open
+        from safetensors.torch import load_file as safe_load

-        checkpoint = {}
-        with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
-            for key in f.keys():
-                checkpoint[key] = f.get_tensor(key)
+        checkpoint = safe_load(checkpoint_path, device="cpu")
    else:
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
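load_file is behaviorally equivalent to the removed safe_open loop; both produce a plain dict mapping tensor names to CPU tensors. For illustration, with a hypothetical local path:

    from safetensors.torch import load_file

    checkpoint = load_file("v1-5-pruned-emaonly.safetensors", device="cpu")
    print(len(checkpoint))  # number of tensors in the file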
@@ -1183,7 +1201,7 @@
if "global_step" in checkpoint:
global_step = checkpoint["global_step"]
else:
logger.warning("global_step key not found in model")
logger.debug("global_step key not found in model")
global_step = None

    # NOTE: this while loop isn't great but this controlnet checkpoint has one additional
@@ -1230,9 +1248,15 @@
model_type = "SDXL"
else:
model_type = "SDXL-Refiner"
if image_size is None:
image_size = 1024

if num_in_channels is not None:
original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
if num_in_channels is None and pipeline_class == StableDiffusionInpaintPipeline:
num_in_channels = 9
elif num_in_channels is None:
num_in_channels = 4

original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels

    if (
        "parameterization" in original_config["model"]["params"]
@@ -1263,7 +1287,6 @@
    num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000

    if model_type in ["SDXL", "SDXL-Refiner"]:
-        image_size = 1024
        scheduler_dict = {
            "beta_schedule": "scaled_linear",
            "beta_start": 0.00085,
@@ -1279,7 +1302,6 @@
        }
        scheduler = EulerDiscreteScheduler.from_config(scheduler_dict)
        scheduler_type = "euler"
-        vae_path = "stabilityai/sdxl-vae"
    else:
        beta_start = getattr(original_config.model.params, "linear_start", None) or 0.02
        beta_end = getattr(original_config.model.params, "linear_end", None) or 0.085
@@ -1318,25 +1340,45 @@
    # Convert the UNet2DConditionModel model.
    unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
    unet_config["upcast_attention"] = upcast_attention
-    unet = UNet2DConditionModel(**unet_config)
+    with init_empty_weights():
+        unet = UNet2DConditionModel(**unet_config)

[Inline review comment on `with init_empty_weights():`]

Reviewer: We probably need to qualify this with an `if is_accelerate_available():` check for environments that don't have accelerate.

    converted_unet_checkpoint = convert_ldm_unet_checkpoint(
        checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema
    )
-    unet.load_state_dict(converted_unet_checkpoint)
+
+    for param_name, param in converted_unet_checkpoint.items():
+        set_module_tensor_to_device(unet, param_name, "cpu", value=param)

    # Convert the VAE model.
    if vae_path is None:
        vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
        converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)

-        vae = AutoencoderKL(**vae_config)
-        vae.load_state_dict(converted_vae_checkpoint)
+        if (
+            "model" in original_config
+            and "params" in original_config.model
+            and "scale_factor" in original_config.model.params
+        ):
+            vae_scaling_factor = original_config.model.params.scale_factor
+        else:
+            vae_scaling_factor = 0.18215  # default SD scaling factor
+
+        vae_config["scaling_factor"] = vae_scaling_factor
+
+        with init_empty_weights():
+            vae = AutoencoderKL(**vae_config)
+
+        for param_name, param in converted_vae_checkpoint.items():
+            set_module_tensor_to_device(vae, param_name, "cpu", value=param)
    else:
        vae = AutoencoderKL.from_pretrained(vae_path)

if model_type == "FrozenOpenCLIPEmbedder":
text_model = convert_open_clip_checkpoint(checkpoint)
config_name = "stabilityai/stable-diffusion-2"
config_kwargs = {"subfolder": "text_encoder"}

text_model = convert_open_clip_checkpoint(checkpoint, config_name, **config_kwargs)
tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")

        if stable_unclip is None:
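Why the scaling factor now travels with the VAE config (background note): the factor rescales VAE latents to roughly unit variance before they reach the UNet, and SD 1.x/2.x LDM configs store it as model.params.scale_factor = 0.18215, hence the fallback above. A usage sketch, assuming `vae` is a loaded AutoencoderKL and `image` a normalized tensor:

    latents = vae.encode(image).latent_dist.sample() * vae.config.scaling_factor
    reconstruction = vae.decode(latents / vae.config.scaling_factor).sample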
@@ -1469,7 +1511,12 @@ def download_from_original_stable_diffusion_ckpt(
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!")
text_encoder_2 = convert_open_clip_checkpoint(checkpoint, prefix="conditioner.embedders.1.model.")

config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
config_kwargs = {"projection_dim": 1280}
text_encoder_2 = convert_open_clip_checkpoint(
checkpoint, config_name, prefix="conditioner.embedders.1.model.", has_projection=True, **config_kwargs
)

        pipe = StableDiffusionXLPipeline(
            vae=vae,
@@ -1485,7 +1532,12 @@ def download_from_original_stable_diffusion_ckpt(
        tokenizer = None
        text_encoder = None
        tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!")
-        text_encoder_2 = convert_open_clip_checkpoint(checkpoint, prefix="conditioner.embedders.0.model.")
+
+        config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
+        config_kwargs = {"projection_dim": 1280}
+        text_encoder_2 = convert_open_clip_checkpoint(
+            checkpoint, config_name, prefix="conditioner.embedders.0.model.", has_projection=True, **config_kwargs
+        )

        pipe = StableDiffusionXLImg2ImgPipeline(
            vae=vae,
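The two prefixes encode the SDXL checkpoint layout that the branches above rely on: the base model stores its small CLIP encoder under conditioner.embedders.0 and the big OpenCLIP encoder under conditioner.embedders.1, while the refiner ships only the OpenCLIP encoder under conditioner.embedders.0. A quick probe, assuming `checkpoint` is an already-loaded state dict:

    # Refiner checkpoints have no embedders.1 keys; base checkpoints have both.
    for p in ("conditioner.embedders.0.", "conditioner.embedders.1."):
        print(p, sum(k.startswith(p) for k in checkpoint))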
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

@@ -24,7 +24,7 @@

from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
-from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor
@@ -153,7 +153,9 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
    return mask, masked_image


-class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+class StableDiffusionInpaintPipeline(
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
    r"""
    Pipeline for text-guided image inpainting using Stable Diffusion.
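Adding FromSingleFileMixin is the user-facing payoff of the PR: the inpainting pipeline can now be instantiated straight from an original-format checkpoint. A usage sketch (the URL is the one exercised in the tests below):

    from diffusers import StableDiffusionInpaintPipeline

    pipe = StableDiffusionInpaintPipeline.from_single_file(
        "https://huggingface.co/runwayml/stable-diffusion-inpainting/blob/main/sd-v1-5-inpainting.ckpt"
    )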
tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py (39 changes: 39 additions & 0 deletions)

@@ -20,17 +20,20 @@

import numpy as np
import torch
+from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

from diffusers import (
    AutoencoderKL,
+    DDIMScheduler,
    DPMSolverMultistepScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
    StableDiffusionInpaintPipeline,
    UNet2DConditionModel,
)
+from diffusers.models.attention_processor import AttnProcessor
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import prepare_mask_and_masked_image
from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device
from diffusers.utils.testing_utils import (
@@ -512,6 +515,42 @@ def test_stable_diffusion_simple_inpaint_ddim(self):

        assert np.abs(expected_slice - image_slice).max() < 6e-4

+    def test_download_local(self):
+        filename = hf_hub_download("runwayml/stable-diffusion-inpainting", filename="sd-v1-5-inpainting.ckpt")
+
+        pipe = StableDiffusionInpaintPipeline.from_single_file(filename, torch_dtype=torch.float16)
+        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+        pipe.to("cuda")
+
+        inputs = self.get_inputs(torch_device)
+        inputs["num_inference_steps"] = 1
+        image_out = pipe(**inputs).images[0]
+
+        assert image_out.shape == (512, 512, 3)
+
+    def test_download_ckpt_diff_format_is_same(self):
+        ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-inpainting/blob/main/sd-v1-5-inpainting.ckpt"
+
+        pipe = StableDiffusionInpaintPipeline.from_single_file(ckpt_path)
+        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+        pipe.unet.set_attn_processor(AttnProcessor())
+        pipe.to("cuda")
+
+        inputs = self.get_inputs(torch_device)
+        inputs["num_inference_steps"] = 5
+        image_ckpt = pipe(**inputs).images[0]
+
+        pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
+        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+        pipe.unet.set_attn_processor(AttnProcessor())
+        pipe.to("cuda")
+
+        inputs = self.get_inputs(torch_device)
+        inputs["num_inference_steps"] = 5
+        image = pipe(**inputs).images[0]
+
+        assert np.max(np.abs(image - image_ckpt)) < 1e-4

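A note on the equivalence test's design (commentary, not part of the diff): both pipelines pin the attention implementation before comparing outputs, since differing default processors, e.g. a memory-efficient kernel on one side only, would add small numerical drift and break the 1e-4 bound:

    from diffusers.models.attention_processor import AttnProcessor

    pipe.unet.set_attn_processor(AttnProcessor())  # force the vanilla attention path on both pipelines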

@nightly
@require_torch_gpu