Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

P+: Extended Textual Conditioning in Text-to-Image Generation #327

Merged
merged 3 commits into from
Mar 30, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Previous commit
disabled sampling (for now)
  • Loading branch information
Jakaline-dev committed Mar 29, 2023
commit 24e3d4b4642673cebd552f1a9cbd3d99eac969a2
17 changes: 7 additions & 10 deletions gen_img_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,29 +781,26 @@ def __call__(
text_embeddings_concat = []
for layer in ['IN01', 'IN02', 'IN04', 'IN05', 'IN07', 'IN08', 'MID', 'OUT03', 'OUT04', 'OUT05', 'OUT06', 'OUT07', 'OUT08', 'OUT09', 'OUT10', 'OUT11']:
text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings(
pipe=self,
prompt=prompt,
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
max_embeddings_multiples=max_embeddings_multiples,
clip_skip=self.clip_skip,
layer=layer,
**kwargs,
pipe=self,
prompt=prompt,
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
max_embeddings_multiples=max_embeddings_multiples,
clip_skip=self.clip_skip,
layer=layer,
**kwargs,
)
if do_classifier_free_guidance:
if negative_scale is None:
text_embeddings_concat.append(torch.cat([uncond_embeddings, text_embeddings]))
else:
text_embeddings_concat.append(torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]))


text_embeddings = torch.stack(text_embeddings_concat)
else:
if do_classifier_free_guidance:
if negative_scale is None:
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
else:
text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings])

text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings(
pipe=self,
prompt=prompt,
Expand Down
46 changes: 29 additions & 17 deletions train_textual_inversion_XTI.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import math
import os
import toml
from multiprocessing import Value

from tqdm import tqdm
import torch
Expand All @@ -17,7 +18,8 @@
ConfigSanitizer,
BlueprintGenerator,
)

import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI

imagenet_templates_small = [
Expand Down Expand Up @@ -73,10 +75,6 @@
]


def collate_fn(examples):
    """Return the first element of ``examples``.

    This function is passed as ``collate_fn`` to a loader that is also
    configured with ``batch_size=1``, so it unwraps the one-element list
    instead of merging samples.
    NOTE(review): presumably the dataset already yields fully formed
    batches -- confirm against the dataset implementation.

    Args:
        examples: An indexable sequence of samples; only index 0 is read.

    Returns:
        ``examples[0]``, unchanged.

    Raises:
        IndexError: If ``examples`` is an empty sequence.
    """
    return examples[0]


def train(args):
if args.output_name is None:
args.output_name = args.token_string
Expand Down Expand Up @@ -195,6 +193,10 @@ def train(args):
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
train_dataset_group.enable_XTI(XTI_layers, token_strings=token_strings)
current_epoch = Value('i',0)
current_step = Value('i',0)
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collater = train_util.collater_class(current_epoch,current_step, ds_for_collater)

# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
if use_template:
Expand All @@ -207,14 +209,14 @@ def train(args):
train_dataset_group.add_replacement("", captions)

if args.num_vectors_per_token > 1:
prompt_replacement = [args.token_string, replace_to]
prompt_replacement = (args.token_string, replace_to)
else:
prompt_replacement = None
else:
if args.num_vectors_per_token > 1:
replace_to = " ".join(token_strings)
train_dataset_group.add_replacement(args.token_string, replace_to)
prompt_replacement = [args.token_string, replace_to]
prompt_replacement = (args.token_string, replace_to)
else:
prompt_replacement = None

Expand Down Expand Up @@ -264,16 +266,19 @@ def train(args):
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collate_fn,
collate_fn=collater,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)

# 学習ステップ数を計算する
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")

# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)

# lr schedulerを用意する
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)

Expand Down Expand Up @@ -345,12 +350,14 @@ def train(args):

for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
train_dataset_group.set_current_epoch(epoch + 1)
current_epoch.value = epoch+1

text_encoder.train()

loss_total = 0

for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(text_encoder):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
Expand Down Expand Up @@ -391,6 +398,9 @@ def train(args):

loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])

if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)

loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights
Expand All @@ -416,10 +426,10 @@ def train(args):
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1

train_util.sample_images(
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
)
# TODO: fix sample_images
# train_util.sample_images(
# accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
# )

current_loss = loss.detach().item()
if args.logging_dir is not None:
Expand Down Expand Up @@ -466,9 +476,10 @@ def remove_old_func(old_epoch_no):
if saving and args.save_state:
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)

train_util.sample_images(
accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
)
# TODO: fix sample_images
# train_util.sample_images(
# accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
# )

# end of epoch

Expand Down Expand Up @@ -543,6 +554,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_training_arguments(parser, True)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)

parser.add_argument(
"--save_model_as",
Expand Down