Commit

support for controlnet in sample output
ddPn08 committed Jun 1, 2023
1 parent 5f71c48 commit 397cbe1
Showing 5 changed files with 162 additions and 31 deletions.
2 changes: 1 addition & 1 deletion library/custom_train_functions.py
@@ -14,7 +14,7 @@ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
all_snr = (alpha / sigma) ** 2
snr = torch.stack([all_snr[t] for t in timesteps])
gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float() # from paper
snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device) # from paper
loss = loss * snr_weight
return loss

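For context, the snr_weight change above only moves the per-timestep min-SNR weight, w_t = min(gamma / SNR(t), 1), onto the same device as the loss before the multiplication. A minimal sketch of the failure it avoids (the tensors below are placeholders, not the training code):

import torch

loss = torch.rand(4, device="cuda" if torch.cuda.is_available() else "cpu")
snr_weight = torch.rand(4)  # stands in for min(gamma / snr, 1), which may live on the CPU

# When loss is on CUDA and snr_weight is on CPU, multiplying them directly raises a
# device-mismatch error; moving the weight to loss.device first keeps the product
# valid on any device.
loss = loss * snr_weight.to(loss.device)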
87 changes: 84 additions & 3 deletions library/lpw_stable_diffusion.py
@@ -6,7 +6,7 @@
from typing import Callable, List, Optional, Union

import numpy as np
import PIL
import PIL.Image
import torch
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
@@ -426,6 +426,59 @@ def preprocess_mask(mask, scale_factor=8):
return mask


def prepare_controlnet_image(
image: PIL.Image.Image,
width: int,
height: int,
batch_size: int,
num_images_per_prompt: int,
device: torch.device,
dtype: torch.dtype,
do_classifier_free_guidance: bool = False,
guess_mode: bool = False,
):
if not isinstance(image, torch.Tensor):
if isinstance(image, PIL.Image.Image):
image = [image]

if isinstance(image[0], PIL.Image.Image):
images = []

for image_ in image:
image_ = image_.convert("RGB")
image_ = image_.resize(
(width, height), resample=PIL_INTERPOLATION["lanczos"]
)
image_ = np.array(image_)
image_ = image_[None, :]
images.append(image_)

image = images

image = np.concatenate(image, axis=0)
image = np.array(image).astype(np.float32) / 255.0
image = image.transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, dim=0)

image_batch_size = image.shape[0]

if image_batch_size == 1:
repeat_by = batch_size
else:
# image batch size is the same as prompt batch size
repeat_by = num_images_per_prompt

image = image.repeat_interleave(repeat_by, dim=0)

image = image.to(device=device, dtype=dtype)

if do_classifier_free_guidance and not guess_mode:
image = torch.cat([image] * 2)

return image

class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
r"""
Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
@@ -464,10 +517,10 @@ def __init__(
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: SchedulerMixin,
clip_skip: int,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
clip_skip: int = 1,
):
super().__init__(
vae=vae,
@@ -707,6 +760,8 @@ def __call__(
max_embeddings_multiples: Optional[int] = 3,
output_type: Optional[str] = "pil",
return_dict: bool = True,
controlnet=None,
controlnet_image=None,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
is_cancelled_callback: Optional[Callable[[], bool]] = None,
callback_steps: int = 1,
@@ -767,6 +822,11 @@
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
controlnet (`diffusers.ControlNetModel`, *optional*):
A ControlNet model to be used for inference. If not provided, ControlNet is disabled.
controlnet_image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*):
`Image`, or tensor representing an image batch, to be used as the conditioning image for the
ControlNet inference.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
@@ -785,6 +845,9 @@
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
if controlnet is not None and controlnet_image is None:
raise ValueError("controlnet_image must be provided if controlnet is not None.")

# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
@@ -824,6 +887,10 @@ def __call__(
else:
mask = None

if controlnet_image is not None:
controlnet_image = prepare_controlnet_image(
    controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False
)

# 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)
@@ -851,8 +918,22 @@
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

unet_additional_args = {}
if controlnet is not None:
down_block_res_samples, mid_block_res_sample = controlnet(
latent_model_input,
t,
encoder_hidden_states=text_embeddings,
controlnet_cond=controlnet_image,
conditioning_scale=1.0,
guess_mode=False,
return_dict=False,
)
unet_additional_args["down_block_additional_residuals"] = down_block_res_samples
unet_additional_args["mid_block_additional_residual"] = mid_block_res_sample

# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, **unet_additional_args).sample

# perform guidance
if do_classifier_free_guidance:
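Taken together, the new controlnet and controlnet_image arguments let a ControlNet steer sampling by injecting its down-block and mid-block residuals into the UNet call, as in the loop above. A minimal usage sketch, assuming pipeline is an already-constructed StableDiffusionLongPromptWeightingPipeline on the GPU and using a hypothetical model id and image path:

import torch
from PIL import Image
from diffusers import ControlNetModel

# Placeholder model id and conditioning image path; substitute your own.
controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16
).to("cuda")
cond_image = Image.open("control.png").convert("RGB").resize((512, 512))

image = pipeline(
    prompt="a dog running on the beach",
    negative_prompt="low quality, blurry",
    width=512,
    height=512,
    num_inference_steps=20,
    guidance_scale=7.5,
    controlnet=controlnet,        # enables the residual injection
    controlnet_image=cond_image,  # preprocessed by prepare_controlnet_image
).images[0]
image.save("controlnet_sample.png")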
28 changes: 25 additions & 3 deletions library/model_util.py
@@ -731,8 +731,7 @@ def convert_unet_state_dict_to_sd(v2, unet_state_dict):

return new_state_dict


def convert_controlnet_state_dict_to_sd(controlnet_state_dict):
def controlnet_conversion_map():
unet_conversion_map = [
("time_embed.0.weight", "time_embedding.linear_1.weight"),
("time_embed.0.bias", "time_embedding.linear_1.bias"),
@@ -792,6 +791,12 @@ def convert_controlnet_state_dict_to_sd(controlnet_state_dict):
sd_prefix = f"zero_convs.{i}.0."
unet_conversion_map_layer.append((sd_prefix, hf_prefix))

return unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer


def convert_controlnet_state_dict_to_sd(controlnet_state_dict):
unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map()

mapping = {k: k for k in controlnet_state_dict.keys()}
for sd_name, diffusers_name in unet_conversion_map:
mapping[diffusers_name] = sd_name
@@ -807,6 +812,23 @@ def convert_controlnet_state_dict_to_sd(controlnet_state_dict):
new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()}
return new_state_dict

def convert_controlnet_state_dict_to_diffusers(controlnet_state_dict):
unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map()

mapping = {k: k for k in controlnet_state_dict.keys()}
for sd_name, diffusers_name in unet_conversion_map:
mapping[sd_name] = diffusers_name
for k, v in mapping.items():
for sd_part, diffusers_part in unet_conversion_map_layer:
v = v.replace(sd_part, diffusers_part)
mapping[k] = v
for k, v in mapping.items():
if "resnets" in v:
for sd_part, diffusers_part in unet_conversion_map_resnet:
v = v.replace(sd_part, diffusers_part)
mapping[k] = v
new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()}
return new_state_dict

# ================#
# VAE Conversion #
@@ -928,7 +950,7 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):


# TODO the behavior when dtype is specified seems questionable, so verify it; it is also unconfirmed whether the text_encoder can be created with the specified dtype
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None, unet_use_linear_projection_in_v2=False):
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None, unet_use_linear_projection_in_v2=True):
_, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device)

# Convert the UNet2DConditionModel model.
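The new convert_controlnet_state_dict_to_diffusers mirrors convert_controlnet_state_dict_to_sd with the key mapping applied in the opposite direction, both built from the shared controlnet_conversion_map. A small sanity-check sketch, assuming sd_state_dict is a ControlNet state dict already loaded in Stable Diffusion key layout:

from library import model_util

# sd_state_dict: dict of parameter name -> tensor, using SD/ldm key names.
diffusers_sd = model_util.convert_controlnet_state_dict_to_diffusers(sd_state_dict)
roundtrip_sd = model_util.convert_controlnet_state_dict_to_sd(diffusers_sd)

# If the two mappings are consistent, converting there and back preserves the key set.
assert set(roundtrip_sd.keys()) == set(sd_state_dict.keys())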
29 changes: 21 additions & 8 deletions library/train_util.py
@@ -1670,7 +1670,6 @@ def __getitem__(self, index):
cond_img = self.trim_and_resize_if_required(subset, cond_img, image_info.bucket_reso, image_info.resized_size)
cond_img = self.conditioning_image_transforms(cond_img)
conditioning_images.append(cond_img)
conditioning_images = torch.stack(conditioning_images)

example = {}
example["loss_weights"] = torch.FloatTensor(loss_weights)
@@ -1695,7 +1694,7 @@ def __getitem__(self, index):
if self.debug_dataset:
example["image_keys"] = bucket[image_index : image_index + self.batch_size]

example["conditioning_images"] = conditioning_images.to(memory_format=torch.contiguous_format).float()
example["conditioning_images"] = torch.stack(conditioning_images).to(memory_format=torch.contiguous_format).float()

return example

@@ -3329,13 +3328,13 @@ def prepare_dtype(args: argparse.Namespace):
return weight_dtype, save_dtype


def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"):
def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", unet_use_linear_projection_in_v2=False):
name_or_path = args.pretrained_model_name_or_path
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
if load_stable_diffusion_format:
print(f"load StableDiffusion checkpoint: {name_or_path}")
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path, device)
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path, device, unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2)
else:
# Diffusers model is loaded to CPU
print(f"load Diffusers pretrained models: {name_or_path}")
@@ -3363,14 +3362,14 @@ def transform_if_model_is_DDP(text_encoder, unet, network=None):
return (model.module if type(model) == DDP else model for model in [text_encoder, unet, network] if model is not None)


def load_target_model(args, weight_dtype, accelerator):
def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False):
# load models for each process
for pi in range(accelerator.state.num_processes):
if pi == accelerator.state.local_process_index:
print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")

text_encoder, vae, unet, load_stable_diffusion_format = _load_target_model(
args, weight_dtype, accelerator.device if args.lowram else "cpu"
args, weight_dtype, accelerator.device if args.lowram else "cpu", unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2
)

# work on low-ram device
@@ -3684,7 +3683,7 @@ def save_sd_model_on_train_end(


def sample_images(
accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None
accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None, controlnet=None
):
"""
A modified version of StableDiffusionLongPromptWeightingPipeline is used, so clip skip and prompt weighting are supported
@@ -3774,10 +3773,10 @@ def sample_images(
unet=unet,
tokenizer=tokenizer,
scheduler=scheduler,
clip_skip=args.clip_skip,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
clip_skip=args.clip_skip,
)
pipeline.to(device)

@@ -3800,6 +3799,7 @@
height = prompt.get("height", 512)
scale = prompt.get("scale", 7.5)
seed = prompt.get("seed")
controlnet_image = prompt.get("controlnet_image")
prompt = prompt.get("prompt")
else:
# prompt = prompt.strip()
@@ -3814,6 +3814,7 @@
width = height = 512
scale = 7.5
seed = None
controlnet_image = None
for parg in prompt_args:
try:
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
@@ -3846,6 +3847,12 @@
negative_prompt = m.group(1)
continue

m = re.match(r"cn (.+)", parg, re.IGNORECASE)
if m:  # controlnet image
controlnet_image = m.group(1)
continue


except ValueError as ex:
print(f"Exception in parsing / 解析エラー: {parg}")
print(ex)
@@ -3859,6 +3866,10 @@
if negative_prompt is not None:
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])

if controlnet_image is not None:
controlnet_image = Image.open(controlnet_image).convert("RGB")
controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS)

height = max(64, height - height % 8) # round to divisible by 8
width = max(64, width - width % 8) # round to divisible by 8
print(f"prompt: {prompt}")
@@ -3874,6 +3885,8 @@
num_inference_steps=sample_steps,
guidance_scale=scale,
negative_prompt=negative_prompt,
controlnet=controlnet,
controlnet_image=controlnet_image,
).images[0]

ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
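With the parsing above, a per-prompt ControlNet conditioning image can be requested from the sample-prompt file via the new cn option; the controlnet model itself is supplied by whichever training script calls sample_images. A hypothetical prompt line, assuming the existing --w/--h/--n option syntax:

a dog running on the beach --n low quality, blurry --w 512 --h 512 --cn /path/to/control.png

The image path given after --cn is opened, converted to RGB, and resized to the requested width and height before being passed to the pipeline as controlnet_image.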