
Commit 6bf994c

Merge branch 'main' of https://github.com/kohya-ss/sd-scripts into dev

2 parents: 844b4fd + 852481e

10 files changed: +474 -21 lines

.gitignore  (+3 -1)

@@ -8,4 +8,6 @@ wd14_tagger_model
 .DS_Store
 locon
 gui-user.bat
-gui-user.ps1
+gui-user.ps1
+.vscode
+wandb

README.md  (+8)

@@ -308,6 +308,14 @@ This will store a backup file with your current locally installed pip packages a
 * 2023/04/24 (v21.5.6)
     - Fix triton error
     - Fix issue with merge lora path with spaces
+    - Added support for logging to wandb. Please refer to PR #428. Thank you p1atdev!
+        - wandb installation is required. Please install it with `pip install wandb`. Log in with the `wandb login` command, or set the `--wandb_api_key` option for automatic login.
+        - Please let me know if you find any bugs, as testing is not yet complete.
+    - You can log in to wandb automatically by setting the `--wandb_api_key` option. Please be careful with handling of the API key. PR #435 Thank you Linaqruf!
+    - Improved the behavior of `--debug_dataset` on non-Windows environments. PR #429 Thank you tsukimiya!
+    - Fixed the `--face_crop_aug` option not working with the Fine tuning method.
+    - Prepared code to use any upscaler in `gen_img_diffusers.py`.
+    - Fixed logging so TensorBoard is used when `--logging_dir` is specified and `--log_with` is not.
 * 2023/04/22 (v21.5.5)
     - Update LoRA merge GUI to support SD checkpoint merge and up to 4 LoRA merging
     - Fixed `lora_interrogator.py` not working. Please refer to [PR #392](https://github.com/kohya-ss/sd-scripts/pull/392) for details. Thank you A2va and heyalexchoi!
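
For context on the new wandb options listed above, here is a minimal sketch (not part of this commit; plain wandb API with a placeholder API key and project name) of the workflow that --wandb_api_key and --log_tracker_name automate:

import wandb

# --wandb_api_key automates this login step (otherwise run `wandb login` once)
wandb.login(key="YOUR_API_KEY")  # placeholder key

# the tracker/project name; the training scripts pass their default ("finetuning" etc.)
# or the value of --log_tracker_name when initializing trackers via accelerate
run = wandb.init(project="finetuning")

run.log({"loss/current": 0.123})  # scalars logged during training end up in this run
run.finish()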

fine_tune.py  (+1 -1)

@@ -260,7 +260,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
         )

     if accelerator.is_main_process:
-        accelerator.init_trackers("finetuning")
+        accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name)

     for epoch in range(num_train_epochs):
         print(f"epoch {epoch+1}/{num_train_epochs}")

gen_img_diffusers.py  (+57 -7)

@@ -945,7 +945,7 @@ def __call__(

            # encode the init image into latents and scale the latents
            init_image = init_image.to(device=self.device, dtype=latents_dtype)
-           if init_image.size()[2:] == (height // 8, width // 8):
+           if init_image.size()[1:] == (height // 8, width // 8):
                init_latents = init_image
            else:
                if vae_batch_size >= batch_size:
@@ -1015,7 +1015,7 @@ def __call__(
                if self.control_nets:
                    if reginonal_network:
                        num_sub_and_neg_prompts = len(text_embeddings) // batch_size
-                       text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2::num_sub_and_neg_prompts]  # last subprompt
+                       text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts]  # last subprompt
                    else:
                        text_emb_last = text_embeddings
                    noise_pred = original_control_net.call_unet_and_control_net(
@@ -2318,6 +2318,22 @@ def __getattr__(self, item):
    else:
        networks = []

+    # upscalerの指定があれば取得する
+    upscaler = None
+    if args.highres_fix_upscaler:
+        print("import upscaler module:", args.highres_fix_upscaler)
+        imported_module = importlib.import_module(args.highres_fix_upscaler)
+
+        us_kwargs = {}
+        if args.highres_fix_upscaler_args:
+            for net_arg in args.highres_fix_upscaler_args.split(";"):
+                key, value = net_arg.split("=")
+                us_kwargs[key] = value
+
+        print("create upscaler")
+        upscaler = imported_module.create_upscaler(**us_kwargs)
+        upscaler.to(dtype).to(device)
+
    # ControlNetの処理
    control_nets: List[ControlNetInfo] = []
    if args.control_net_models:
@@ -2590,7 +2606,7 @@ def resize_images(imgs, size):
                np_mask = np_mask[:, :, i]
                size = np_mask.shape
            else:
-               np_mask = np.full(size, 255, dtype=np.uint8)
+               np_mask = np.full(size, 255, dtype=np.uint8)
            mask = torch.from_numpy(np_mask.astype(np.float32) / 255.0)
            network.set_region(i, i == len(networks) - 1, mask)
        mask_images = None
@@ -2639,6 +2655,8 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
        # highres_fixの処理
        if highres_fix and not highres_1st:
            # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す
+           is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling
+
            print("process 1st stage")
            batch_1st = []
            for _, base, ext in batch:
@@ -2657,12 +2675,32 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
                    ext.network_muls,
                    ext.num_sub_prompts,
                )
-               batch_1st.append(BatchData(args.highres_fix_latents_upscaling, base, ext_1st))
+               batch_1st.append(BatchData(is_1st_latent, base, ext_1st))
            images_1st = process_batch(batch_1st, True, True)

            # 2nd stageのバッチを作成して以下処理する
            print("process 2nd stage")
-           if args.highres_fix_latents_upscaling:
+           width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height
+
+           if upscaler:
+               # upscalerを使って画像を拡大する
+               lowreso_imgs = None if is_1st_latent else images_1st
+               lowreso_latents = None if not is_1st_latent else images_1st
+
+               # 戻り値はPIL.Image.Imageかtorch.Tensorのlatents
+               batch_size = len(images_1st)
+               vae_batch_size = (
+                   batch_size
+                   if args.vae_batch_size is None
+                   else (max(1, int(batch_size * args.vae_batch_size)) if args.vae_batch_size < 1 else args.vae_batch_size)
+               )
+               vae_batch_size = int(vae_batch_size)
+               images_1st = upscaler.upscale(
+                   vae, lowreso_imgs, lowreso_latents, dtype, width_2nd, height_2nd, batch_size, vae_batch_size
+               )
+
+           elif args.highres_fix_latents_upscaling:
+               # latentを拡大する
                org_dtype = images_1st.dtype
                if images_1st.dtype == torch.bfloat16:
                    images_1st = images_1st.to(torch.float)  # interpolateがbf16をサポートしていない
@@ -2671,10 +2709,12 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
                )  # , antialias=True)
                images_1st = images_1st.to(org_dtype)

+           else:
+               # 画像をLANCZOSで拡大する
+               images_1st = [image.resize((width_2nd, height_2nd), resample=PIL.Image.LANCZOS) for image in images_1st]
+
            batch_2nd = []
            for i, (bd, image) in enumerate(zip(batch, images_1st)):
-               if not args.highres_fix_latents_upscaling:
-                   image = image.resize((bd.ext.width, bd.ext.height), resample=PIL.Image.LANCZOS)  # img2imgとして設定
                bd_2nd = BatchData(False, BatchDataBase(*bd.base[0:3], bd.base.seed + 1, image, None, *bd.base[6:]), bd.ext)
                batch_2nd.append(bd_2nd)
            batch = batch_2nd
@@ -3229,6 +3269,16 @@ def setup_parser() -> argparse.ArgumentParser:
        action="store_true",
        help="use latents upscaling for highres fix / highres fixでlatentで拡大する",
    )
+    parser.add_argument(
+        "--highres_fix_upscaler", type=str, default=None, help="upscaler module for highres fix / highres fixで使うupscalerのモジュール名"
+    )
+    parser.add_argument(
+        "--highres_fix_upscaler_args",
+        type=str,
+        default=None,
+        help="additional arguments for upscaler (key=value) / upscalerへの追加の引数",
+    )
+
    parser.add_argument(
        "--negative_scale", type=float, default=None, help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する"
    )
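
Based on the calls visible in the diff above (create_upscaler(**kwargs), .to(), .support_latents(), .upscale(...)), a hypothetical upscaler module for --highres_fix_upscaler could look like the following sketch. The module name my_upscaler, the scale argument, and the plain LANCZOS resize are all illustrative assumptions, not part of the commit:

# my_upscaler.py -- hypothetical example module; interface inferred from gen_img_diffusers.py above
from typing import List, Optional

import torch
from PIL import Image


class SimpleResizeUpscaler:
    def __init__(self, scale="2"):
        # values from --highres_fix_upscaler_args arrive as strings ("key=value;key=value")
        self.scale = float(scale)

    def to(self, _dtype_or_device):
        # called as upscaler.to(dtype).to(device); nothing to move for a pure-PIL upscaler
        return self

    def support_latents(self) -> bool:
        # False -> the caller hands over decoded PIL images instead of latents
        return False

    def upscale(
        self,
        vae,
        lowreso_imgs: Optional[List[Image.Image]],
        lowreso_latents: Optional[torch.Tensor],
        dtype,
        width_2nd: int,
        height_2nd: int,
        batch_size: int,
        vae_batch_size: int,
    ):
        # return images at the 2nd-stage resolution (latents are unused in this sketch)
        return [img.resize((width_2nd, height_2nd), resample=Image.LANCZOS) for img in lowreso_imgs]


def create_upscaler(**kwargs):
    return SimpleResizeUpscaler(**kwargs)

Such a module would be used roughly as --highres_fix_upscaler my_upscaler --highres_fix_upscaler_args "scale=2", with the module importable on the Python path.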

library/train_util.py  (+59 -8)

@@ -845,9 +845,10 @@ def crop_target(self, subset: BaseSubset, image, face_cx, face_cy, face_w, face_

        # 画像サイズはsizeより大きいのでリサイズする
        face_size = max(face_w, face_h)
+       size = min(self.height, self.width)  # 短いほう
        min_scale = max(self.height / height, self.width / width)  # 画像がモデル入力サイズぴったりになる倍率(最小の倍率)
-       min_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[1])))  # 指定した顔最小サイズ
-       max_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[0])))  # 指定した顔最大サイズ
+       min_scale = min(1.0, max(min_scale, size / (face_size * subset.face_crop_aug_range[1])))  # 指定した顔最小サイズ
+       max_scale = min(1.0, max(min_scale, size / (face_size * subset.face_crop_aug_range[0])))  # 指定した顔最大サイズ
        if min_scale >= max_scale:  # range指定がmin==max
            scale = min_scale
        else:
@@ -872,7 +873,7 @@ def crop_target(self, subset: BaseSubset, image, face_cx, face_cy, face_w, face_
            else:
                # range指定があるときのみ、すこしだけランダムに(わりと適当)
                if subset.face_crop_aug_range[0] != subset.face_crop_aug_range[1]:
-                   if face_size > self.size // 10 and face_size >= 40:
+                   if face_size > size // 10 and face_size >= 40:
                        p1 = p1 + random.randint(-face_size // 20, +face_size // 20)

            p1 = max(0, min(p1, length - target_size))
@@ -1445,8 +1446,8 @@ def debug_dataset(train_dataset, show_input_ids=False):
            im = im[:, :, ::-1]  # RGB -> BGR (OpenCV)
            if os.name == "nt":  # only windows
                cv2.imshow("img", im)
-           k = cv2.waitKey()
-           cv2.destroyAllWindows()
+               k = cv2.waitKey()
+               cv2.destroyAllWindows()
            if k == 27 or k == ord("s") or k == ord("e"):
                break
            steps += 1
@@ -2067,7 +2068,26 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
        default=None,
        help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する",
    )
+    parser.add_argument(
+        "--log_with",
+        type=str,
+        default=None,
+        choices=["tensorboard", "wandb", "all"],
+        help="what logging tool(s) to use (if 'all', TensorBoard and WandB are both used) / ログ出力に使用するツール (allを指定するとTensorBoardとWandBの両方が使用される)",
+    )
    parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
+    parser.add_argument(
+        "--log_tracker_name",
+        type=str,
+        default=None,
+        help="name of tracker to use for logging, default is script-specific default name / ログ出力に使用するtrackerの名前、省略時はスクリプトごとのデフォルト名",
+    )
+    parser.add_argument(
+        "--wandb_api_key",
+        type=str,
+        default=None,
+        help="specify WandB API key to log in before starting training (optional). / WandB APIキーを指定して学習開始前にログインする(オプション)",
+    )
    parser.add_argument(
        "--noise_offset",
        type=float,
@@ -2288,7 +2308,7 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar
    args_dict = vars(args)

    # remove unnecessary keys
-   for key in ["config_file", "output_config"]:
+   for key in ["config_file", "output_config", "wandb_api_key"]:
        if key in args_dict:
            del args_dict[key]

@@ -2732,13 +2752,32 @@ def load_tokenizer(args: argparse.Namespace):

 def prepare_accelerator(args: argparse.Namespace):
    if args.logging_dir is None:
-       log_with = None
        logging_dir = None
    else:
-       log_with = "tensorboard"
        log_prefix = "" if args.log_prefix is None else args.log_prefix
        logging_dir = args.logging_dir + "/" + log_prefix + time.strftime("%Y%m%d%H%M%S", time.localtime())

+   if args.log_with is None:
+       if logging_dir is not None:
+           log_with = "tensorboard"
+       else:
+           log_with = None
+   else:
+       log_with = args.log_with
+       if log_with in ["tensorboard", "all"]:
+           if logging_dir is None:
+               raise ValueError("logging_dir is required when log_with is tensorboard / Tensorboardを使う場合、logging_dirを指定してください")
+       if log_with in ["wandb", "all"]:
+           try:
+               import wandb
+           except ImportError:
+               raise ImportError("No wandb / wandb がインストールされていないようです")
+           if logging_dir is not None:
+               os.makedirs(logging_dir, exist_ok=True)
+               os.environ["WANDB_DIR"] = logging_dir
+           if args.wandb_api_key is not None:
+               wandb.login(key=args.wandb_api_key)
+
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
@@ -3197,6 +3236,18 @@ def sample_images(

        image.save(os.path.join(save_dir, img_filename))

+       # wandb有効時のみログを送信
+       try:
+           wandb_tracker = accelerator.get_tracker("wandb")
+           try:
+               import wandb
+           except ImportError:  # 事前に一度確認するのでここはエラー出ないはず
+               raise ImportError("No wandb / wandb がインストールされていないようです")
+
+           wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
+       except:  # wandb 無効時
+           pass
+
    # clear pipeline and cache to reduce vram usage
    del pipeline
    torch.cuda.empty_cache()
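
To summarize the new logging behavior in one place, here is a condensed sketch (an editor's paraphrase of prepare_accelerator above, not code from the commit) of how --log_with, --logging_dir and --wandb_api_key interact:

def resolve_log_with(log_with, logging_dir, wandb_api_key):
    # no --log_with: keep the old behavior, TensorBoard only when --logging_dir is given
    if log_with is None:
        return "tensorboard" if logging_dir is not None else None
    # TensorBoard (alone or via "all") still requires a logging directory
    if log_with in ("tensorboard", "all") and logging_dir is None:
        raise ValueError("logging_dir is required when log_with is tensorboard")
    # wandb (alone or via "all") requires the wandb package; --wandb_api_key logs in automatically
    if log_with in ("wandb", "all"):
        import wandb  # raises ImportError if wandb is not installed
        if wandb_api_key is not None:
            wandb.login(key=wandb_api_key)
    return log_with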
