Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/kohya-ss/sd-scripts into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
bmaltais committed Apr 24, 2023
2 parents 844b4fd + 852481e commit 6bf994c
Show file tree
Hide file tree
Showing 10 changed files with 474 additions and 21 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,6 @@ wd14_tagger_model
.DS_Store
locon
gui-user.bat
gui-user.ps1
gui-user.ps1
.vscode
wandb
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,14 @@ This will store a backup file with your current locally installed pip packages a
* 2023/04/24 (v21.5.6)
- Fix triton error
- Fix issue with merge lora path with spaces
- Added support for logging to wandb. Please refer to PR #428. Thank you p1atdev!
- wandb installation is required. Please install it with pip install wandb. Login to wandb with wandb login command, or set --wandb_api_key option for automatic login.
- Please let me know if you find any bugs as the test is not complete.
- You can automatically login to wandb by setting the --wandb_api_key option. Please be careful with the handling of API Key. PR #435 Thank you Linaqruf!
- Improved the behavior of --debug_dataset on non-Windows environments. PR #429 Thank you tsukimiya!
- Fixed --face_crop_aug option not working in Fine tuning method.
- Prepared code to use any upscaler in gen_img_diffusers.py.
- Fixed to log to TensorBoard when --logging_dir is specified and --log_with is not specified.
* 2023/04/22 (v21.5.5)
- Update LoRA merge GUI to support SD checkpoint merge and up to 4 LoRA merging
- Fixed `lora_interrogator.py` not working. Please refer to [PR #392](https://github.com/kohya-ss/sd-scripts/pull/392) for details. Thank you A2va and heyalexchoi!
Expand Down
2 changes: 1 addition & 1 deletion fine_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
)

if accelerator.is_main_process:
accelerator.init_trackers("finetuning")
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name)

for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
Expand Down
64 changes: 57 additions & 7 deletions gen_img_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,7 +945,7 @@ def __call__(

# encode the init image into latents and scale the latents
init_image = init_image.to(device=self.device, dtype=latents_dtype)
if init_image.size()[2:] == (height // 8, width // 8):
if init_image.size()[1:] == (height // 8, width // 8):
init_latents = init_image
else:
if vae_batch_size >= batch_size:
Expand Down Expand Up @@ -1015,7 +1015,7 @@ def __call__(
if self.control_nets:
if reginonal_network:
num_sub_and_neg_prompts = len(text_embeddings) // batch_size
text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2::num_sub_and_neg_prompts] # last subprompt
text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts] # last subprompt
else:
text_emb_last = text_embeddings
noise_pred = original_control_net.call_unet_and_control_net(
Expand Down Expand Up @@ -2318,6 +2318,22 @@ def __getattr__(self, item):
else:
networks = []

# upscalerの指定があれば取得する
upscaler = None
if args.highres_fix_upscaler:
print("import upscaler module:", args.highres_fix_upscaler)
imported_module = importlib.import_module(args.highres_fix_upscaler)

us_kwargs = {}
if args.highres_fix_upscaler_args:
for net_arg in args.highres_fix_upscaler_args.split(";"):
key, value = net_arg.split("=")
us_kwargs[key] = value

print("create upscaler")
upscaler = imported_module.create_upscaler(**us_kwargs)
upscaler.to(dtype).to(device)

# ControlNetの処理
control_nets: List[ControlNetInfo] = []
if args.control_net_models:
Expand Down Expand Up @@ -2590,7 +2606,7 @@ def resize_images(imgs, size):
np_mask = np_mask[:, :, i]
size = np_mask.shape
else:
np_mask = np.full(size, 255, dtype=np.uint8)
np_mask = np.full(size, 255, dtype=np.uint8)
mask = torch.from_numpy(np_mask.astype(np.float32) / 255.0)
network.set_region(i, i == len(networks) - 1, mask)
mask_images = None
Expand Down Expand Up @@ -2639,6 +2655,8 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
# highres_fixの処理
if highres_fix and not highres_1st:
# 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す
is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling

print("process 1st stage")
batch_1st = []
for _, base, ext in batch:
Expand All @@ -2657,12 +2675,32 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
ext.network_muls,
ext.num_sub_prompts,
)
batch_1st.append(BatchData(args.highres_fix_latents_upscaling, base, ext_1st))
batch_1st.append(BatchData(is_1st_latent, base, ext_1st))
images_1st = process_batch(batch_1st, True, True)

# 2nd stageのバッチを作成して以下処理する
print("process 2nd stage")
if args.highres_fix_latents_upscaling:
width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height

if upscaler:
# upscalerを使って画像を拡大する
lowreso_imgs = None if is_1st_latent else images_1st
lowreso_latents = None if not is_1st_latent else images_1st

# 戻り値はPIL.Image.Imageかtorch.Tensorのlatents
batch_size = len(images_1st)
vae_batch_size = (
batch_size
if args.vae_batch_size is None
else (max(1, int(batch_size * args.vae_batch_size)) if args.vae_batch_size < 1 else args.vae_batch_size)
)
vae_batch_size = int(vae_batch_size)
images_1st = upscaler.upscale(
vae, lowreso_imgs, lowreso_latents, dtype, width_2nd, height_2nd, batch_size, vae_batch_size
)

elif args.highres_fix_latents_upscaling:
# latentを拡大する
org_dtype = images_1st.dtype
if images_1st.dtype == torch.bfloat16:
images_1st = images_1st.to(torch.float) # interpolateがbf16をサポートしていない
Expand All @@ -2671,10 +2709,12 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
) # , antialias=True)
images_1st = images_1st.to(org_dtype)

else:
# 画像をLANCZOSで拡大する
images_1st = [image.resize((width_2nd, height_2nd), resample=PIL.Image.LANCZOS) for image in images_1st]

batch_2nd = []
for i, (bd, image) in enumerate(zip(batch, images_1st)):
if not args.highres_fix_latents_upscaling:
image = image.resize((bd.ext.width, bd.ext.height), resample=PIL.Image.LANCZOS) # img2imgとして設定
bd_2nd = BatchData(False, BatchDataBase(*bd.base[0:3], bd.base.seed + 1, image, None, *bd.base[6:]), bd.ext)
batch_2nd.append(bd_2nd)
batch = batch_2nd
Expand Down Expand Up @@ -3229,6 +3269,16 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="use latents upscaling for highres fix / highres fixでlatentで拡大する",
)
parser.add_argument(
"--highres_fix_upscaler", type=str, default=None, help="upscaler module for highres fix / highres fixで使うupscalerのモジュール名"
)
parser.add_argument(
"--highres_fix_upscaler_args",
type=str,
default=None,
help="additional argmuments for upscaler (key=value) / upscalerへの追加の引数",
)

parser.add_argument(
"--negative_scale", type=float, default=None, help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する"
)
Expand Down
67 changes: 59 additions & 8 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -845,9 +845,10 @@ def crop_target(self, subset: BaseSubset, image, face_cx, face_cy, face_w, face_

# 画像サイズはsizeより大きいのでリサイズする
face_size = max(face_w, face_h)
size = min(self.height, self.width) # 短いほう
min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率)
min_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[1]))) # 指定した顔最小サイズ
max_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[0]))) # 指定した顔最大サイズ
min_scale = min(1.0, max(min_scale, size / (face_size * subset.face_crop_aug_range[1]))) # 指定した顔最小サイズ
max_scale = min(1.0, max(min_scale, size / (face_size * subset.face_crop_aug_range[0]))) # 指定した顔最大サイズ
if min_scale >= max_scale: # range指定がmin==max
scale = min_scale
else:
Expand All @@ -872,7 +873,7 @@ def crop_target(self, subset: BaseSubset, image, face_cx, face_cy, face_w, face_
else:
# range指定があるときのみ、すこしだけランダムに(わりと適当)
if subset.face_crop_aug_range[0] != subset.face_crop_aug_range[1]:
if face_size > self.size // 10 and face_size >= 40:
if face_size > size // 10 and face_size >= 40:
p1 = p1 + random.randint(-face_size // 20, +face_size // 20)

p1 = max(0, min(p1, length - target_size))
Expand Down Expand Up @@ -1445,8 +1446,8 @@ def debug_dataset(train_dataset, show_input_ids=False):
im = im[:, :, ::-1] # RGB -> BGR (OpenCV)
if os.name == "nt": # only windows
cv2.imshow("img", im)
k = cv2.waitKey()
cv2.destroyAllWindows()
k = cv2.waitKey()
cv2.destroyAllWindows()
if k == 27 or k == ord("s") or k == ord("e"):
break
steps += 1
Expand Down Expand Up @@ -2067,7 +2068,26 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
default=None,
help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する",
)
parser.add_argument(
"--log_with",
type=str,
default=None,
choices=["tensorboard", "wandb", "all"],
help="what logging tool(s) to use (if 'all', TensorBoard and WandB are both used) / ログ出力に使用するツール (allを指定するとTensorBoardとWandBの両方が使用される)",
)
parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
parser.add_argument(
"--log_tracker_name",
type=str,
default=None,
help="name of tracker to use for logging, default is script-specific default name / ログ出力に使用するtrackerの名前、省略時はスクリプトごとのデフォルト名",
)
parser.add_argument(
"--wandb_api_key",
type=str,
default=None,
help="specify WandB API key to log in before starting training (optional). / WandB APIキーを指定して学習開始前にログインする(オプション)",
)
parser.add_argument(
"--noise_offset",
type=float,
Expand Down Expand Up @@ -2288,7 +2308,7 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar
args_dict = vars(args)

# remove unnecessary keys
for key in ["config_file", "output_config"]:
for key in ["config_file", "output_config", "wandb_api_key"]:
if key in args_dict:
del args_dict[key]

Expand Down Expand Up @@ -2732,13 +2752,32 @@ def load_tokenizer(args: argparse.Namespace):

def prepare_accelerator(args: argparse.Namespace):
if args.logging_dir is None:
log_with = None
logging_dir = None
else:
log_with = "tensorboard"
log_prefix = "" if args.log_prefix is None else args.log_prefix
logging_dir = args.logging_dir + "/" + log_prefix + time.strftime("%Y%m%d%H%M%S", time.localtime())

if args.log_with is None:
if logging_dir is not None:
log_with = "tensorboard"
else:
log_with = None
else:
log_with = args.log_with
if log_with in ["tensorboard", "all"]:
if logging_dir is None:
raise ValueError("logging_dir is required when log_with is tensorboard / Tensorboardを使う場合、logging_dirを指定してください")
if log_with in ["wandb", "all"]:
try:
import wandb
except ImportError:
raise ImportError("No wandb / wandb がインストールされていないようです")
if logging_dir is not None:
os.makedirs(logging_dir, exist_ok=True)
os.environ["WANDB_DIR"] = logging_dir
if args.wandb_api_key is not None:
wandb.login(key=args.wandb_api_key)

accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
Expand Down Expand Up @@ -3197,6 +3236,18 @@ def sample_images(

image.save(os.path.join(save_dir, img_filename))

# wandb有効時のみログを送信
try:
wandb_tracker = accelerator.get_tracker("wandb")
try:
import wandb
except ImportError: # 事前に一度確認するのでここはエラー出ないはず
raise ImportError("No wandb / wandb がインストールされていないようです")

wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
except: # wandb 無効時
pass

# clear pipeline and cache to reduce vram usage
del pipeline
torch.cuda.empty_cache()
Expand Down
Loading

0 comments on commit 6bf994c

Please sign in to comment.