Skip to content

Commit

Permalink
Merge pull request #1655 from kohya-ss/sdxl-ctrl-net
Browse files Browse the repository at this point in the history
ControlNet training for SDXL
  • Loading branch information
kohya-ss authored Oct 11, 2024
2 parents 83e3048 + 035c4a8 commit 43bfeea
Show file tree
Hide file tree
Showing 19 changed files with 1,568 additions and 232 deletions.
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,16 @@ The command to install PyTorch is as follows:

### Recent Updates

Oct 11, 2024:
- ControlNet training for SDXL has been implemented in this branch. Please use `sdxl_train_control_net.py`.
- For details on defining the dataset, see [here](docs/train_lllite_README.md#creating-a-dataset-configuration-file).
- The learning rate for the copy part of the U-Net is specified by `--learning_rate`. The learning rate for the added modules in ControlNet is specified by `--control_net_lr`. The optimal value is still unknown, but try around U-Net `1e-5` and ControlNet `1e-4`.
- If you want to generate sample images, specify the control image as `--cn path/to/control/image`.
- The trained weights are automatically converted and saved in Diffusers format. It should be available in ComfyUI.
- Weighting of prompts (captions) during training in SDXL is now supported (e.g., `(some text)`, `[some text]`, `(some text:1.4)`, etc.). The function is enabled by specifying `--weighted_captions`.
- The default is `False`. It is same as before, and the parentheses are used as normal text.
- If `--weighted_captions` is specified, please use `\` to escape the parentheses in the prompt. For example, `\(some text:1.4\)`.

Oct 6, 2024:
- In FLUX.1 LoRA training and fine-tuning, the specified weight file (*.safetensors) is automatically determined to be dev or schnell. This allows schnell models to be loaded correctly. Note that LoRA training with schnell models and fine-tuning with schnell models are unverified.
- FLUX.1 LoRA training and fine-tuning can now load weights in Diffusers format in addition to BFL format (a single *.safetensors file). Please specify the parent directory of `transformer` or `diffusion_pytorch_model-00001-of-00003.safetensors` with the full path. However, Diffusers format CLIP/T5XXL is not supported. Saving is supported only in BFL format.
Expand Down
2 changes: 1 addition & 1 deletion docs/train_lllite_README.md
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ for img_file in img_files:

### Creating a dataset configuration file

You can use the command line arguments of `sdxl_train_control_net_lllite.py` to specify the conditioning image directory. However, if you want to use a `.toml` file, specify the conditioning image directory in `conditioning_data_dir`.
You can use the command line argument `--conditioning_data_dir` of `sdxl_train_control_net_lllite.py` to specify the conditioning image directory. However, if you want to use a `.toml` file, specify the conditioning image directory in `conditioning_data_dir`.

```toml
[general]
Expand Down
17 changes: 6 additions & 11 deletions fine_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,22 +366,17 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
with torch.set_grad_enabled(args.train_text_encoder):
# Get the text embedding for conditioning
if args.weighted_captions:
# TODO move to strategy_sd.py
encoder_hidden_states = get_weighted_text_embeddings(
tokenize_strategy.tokenizer,
text_encoder,
batch["captions"],
accelerator.device,
args.max_token_length // 75 if args.max_token_length else 1,
clip_skip=args.clip_skip,
)
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
encoder_hidden_states = text_encoding_strategy.encode_tokens_with_weights(
tokenize_strategy, [text_encoder], input_ids_list, weights_list
)[0]
else:
input_ids = batch["input_ids_list"][0].to(accelerator.device)
encoder_hidden_states = text_encoding_strategy.encode_tokens(
tokenize_strategy, [text_encoder], [input_ids]
)[0]
if args.full_fp16:
encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
if args.full_fp16:
encoder_hidden_states = encoder_hidden_states.to(weight_dtype)

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
Expand Down
89 changes: 74 additions & 15 deletions gen_img.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@
)
from einops import rearrange
from tqdm import tqdm
from torchvision import transforms
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, CLIPImageProcessor
from accelerate import init_empty_weights
import PIL
from PIL import Image
from PIL.PngImagePlugin import PngInfo
Expand All @@ -58,6 +58,7 @@
from tools.original_control_net import ControlNetInfo
from library.original_unet import UNet2DConditionModel, InferUNet2DConditionModel
from library.sdxl_original_unet import InferSdxlUNet2DConditionModel
from library.sdxl_original_control_net import SdxlControlNet
from library.original_unet import FlashAttentionFunction
from networks.control_net_lllite import ControlNetLLLite
from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL
Expand Down Expand Up @@ -352,8 +353,8 @@ def __init__(
self.token_replacements_list.append({})

# ControlNet
self.control_nets: List[ControlNetInfo] = [] # only for SD 1.5
self.control_net_lllites: List[ControlNetLLLite] = []
self.control_nets: List[Union[ControlNetInfo, Tuple[SdxlControlNet, float]]] = []
self.control_net_lllites: List[Tuple[ControlNetLLLite, float]] = []
self.control_net_enabled = True # control_netsが空ならTrueでもFalseでもControlNetは動作しない

self.gradual_latent: GradualLatent = None
Expand Down Expand Up @@ -542,7 +543,7 @@ def __call__(
else:
text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings])

if self.control_net_lllites:
if self.control_net_lllites or (self.control_nets and self.is_sdxl):
# ControlNetのhintにguide imageを流用する。ControlNetの場合はControlNet側で行う
if isinstance(clip_guide_images, PIL.Image.Image):
clip_guide_images = [clip_guide_images]
Expand Down Expand Up @@ -731,7 +732,12 @@ def __call__(
num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1

if self.control_nets:
guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images)
if not self.is_sdxl:
guided_hints = original_control_net.get_guided_hints(
self.control_nets, num_latent_input, batch_size, clip_guide_images
)
else:
clip_guide_images = clip_guide_images * 0.5 + 0.5 # [-1, 1] => [0, 1]
each_control_net_enabled = [self.control_net_enabled] * len(self.control_nets)

if self.control_net_lllites:
Expand Down Expand Up @@ -793,7 +799,7 @@ def __call__(
latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

# disable ControlNet-LLLite if ratio is set. ControlNet is disabled in ControlNetInfo
# disable ControlNet-LLLite or SDXL ControlNet if ratio is set. ControlNet is disabled in ControlNetInfo
if self.control_net_lllites:
for j, ((control_net, ratio), enabled) in enumerate(zip(self.control_net_lllites, each_control_net_enabled)):
if not enabled or ratio >= 1.0:
Expand All @@ -802,9 +808,16 @@ def __call__(
logger.info(f"ControlNetLLLite {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})")
control_net.set_cond_image(None)
each_control_net_enabled[j] = False
if self.control_nets and self.is_sdxl:
for j, ((control_net, ratio), enabled) in enumerate(zip(self.control_nets, each_control_net_enabled)):
if not enabled or ratio >= 1.0:
continue
if ratio < i / len(timesteps):
logger.info(f"ControlNet {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})")
each_control_net_enabled[j] = False

# predict the noise residual
if self.control_nets and self.control_net_enabled:
if self.control_nets and self.control_net_enabled and not self.is_sdxl:
if regional_network:
num_sub_and_neg_prompts = len(text_embeddings) // batch_size
text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts] # last subprompt
Expand All @@ -823,6 +836,31 @@ def __call__(
text_embeddings,
text_emb_last,
).sample
elif self.control_nets:
input_resi_add_list = []
mid_add_list = []
for (control_net, _), enbld in zip(self.control_nets, each_control_net_enabled):
if not enbld:
continue
input_resi_add, mid_add = control_net(
latent_model_input, t, text_embeddings, vector_embeddings, clip_guide_images
)
input_resi_add_list.append(input_resi_add)
mid_add_list.append(mid_add)
if len(input_resi_add_list) == 0:
noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings)
else:
if len(input_resi_add_list) > 1:
# get mean of input_resi_add_list and mid_add_list
input_resi_add_mean = []
for i in range(len(input_resi_add_list[0])):
input_resi_add_mean.append(
torch.mean(torch.stack([input_resi_add_list[j][i] for j in range(len(input_resi_add_list))], dim=0))
)
input_resi_add = input_resi_add_mean
mid_add = torch.mean(torch.stack(mid_add_list), dim=0)

noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings, input_resi_add, mid_add)
elif self.is_sdxl:
noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings)
else:
Expand Down Expand Up @@ -1827,16 +1865,37 @@ def __getattr__(self, item):
upscaler.to(dtype).to(device)

# ControlNetの処理
control_nets: List[ControlNetInfo] = []
control_nets: List[Union[ControlNetInfo, Tuple[SdxlControlNet, float]]] = []
if args.control_net_models:
for i, model in enumerate(args.control_net_models):
prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i]
weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i]
ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i]
if not is_sdxl:
for i, model in enumerate(args.control_net_models):
prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i]
weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i]
ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i]

ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model)
prep = original_control_net.load_preprocess(prep_type)
control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio))
else:
for i, model_file in enumerate(args.control_net_models):
multiplier = (
1.0
if not args.control_net_multipliers or len(args.control_net_multipliers) <= i
else args.control_net_multipliers[i]
)
ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i]

logger.info(f"loading SDXL ControlNet: {model_file}")
from safetensors.torch import load_file

state_dict = load_file(model_file)

ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model)
prep = original_control_net.load_preprocess(prep_type)
control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio))
logger.info(f"Initializing SDXL ControlNet with multiplier: {multiplier}")
with init_empty_weights():
control_net = SdxlControlNet(multiplier=multiplier)
control_net.load_state_dict(state_dict)
control_net.to(dtype).to(device)
control_nets.append((control_net, ratio))

control_net_lllites: List[Tuple[ControlNetLLLite, float]] = []
if args.control_net_lllite_models:
Expand Down
Loading

0 comments on commit 43bfeea

Please sign in to comment.