|
2 | 2 | import os
|
3 | 3 | import json
|
4 | 4 |
|
| 5 | +from pathlib import Path |
| 6 | +from typing import List |
5 | 7 | from tqdm import tqdm
|
6 | 8 | import numpy as np
|
7 | 9 | from PIL import Image
|
@@ -41,22 +43,31 @@ def get_latents(vae, images, weight_dtype):
|
41 | 43 | return latents
|
42 | 44 |
|
43 | 45 |
|
44 |
| -def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip): |
| 46 | +def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip, recursive): |
45 | 47 | if is_full_path:
|
46 | 48 | base_name = os.path.splitext(os.path.basename(image_key))[0]
|
| 49 | + relative_path = os.path.relpath(os.path.dirname(image_key), data_dir) |
47 | 50 | else:
|
48 | 51 | base_name = image_key
|
| 52 | + relative_path = "" |
| 53 | + |
49 | 54 | if flip:
|
50 | 55 | base_name += '_flip'
|
51 |
| - return os.path.join(data_dir, base_name) |
| 56 | + |
| 57 | + if recursive and relative_path: |
| 58 | + return os.path.join(data_dir, relative_path, base_name) |
| 59 | + else: |
| 60 | + return os.path.join(data_dir, base_name) |
| 61 | + |
52 | 62 |
|
53 | 63 |
|
54 | 64 | def main(args):
|
55 | 65 | # assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります"
|
56 | 66 | if args.bucket_reso_steps % 8 > 0:
|
57 | 67 | print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります")
|
58 | 68 |
|
59 |
| - image_paths = train_util.glob_images(args.train_data_dir) |
| 69 | + train_data_dir_path = Path(args.train_data_dir) |
| 70 | + image_paths: List[str] = [str(p) for p in train_util.glob_images_pathlib(train_data_dir_path, args.recursive)] |
60 | 71 | print(f"found {len(image_paths)} images.")
|
61 | 72 |
|
62 | 73 | if os.path.exists(args.in_json):
|
@@ -99,20 +110,20 @@ def process_batch(is_last):
|
99 | 110 | f"latent shape {latents.shape}, {bucket[0][1].shape}"
|
100 | 111 |
|
101 | 112 | for (image_key, _), latent in zip(bucket, latents):
|
102 |
| - npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False) |
| 113 | + npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive) |
103 | 114 | np.savez(npz_file_name, latent)
|
104 | 115 |
|
105 | 116 | # flip
|
106 | 117 | if args.flip_aug:
|
107 | 118 | latents = get_latents(vae, [img[:, ::-1].copy() for _, img in bucket], weight_dtype) # copyがないとTensor変換できない
|
108 | 119 |
|
109 | 120 | for (image_key, _), latent in zip(bucket, latents):
|
110 |
| - npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) |
| 121 | + npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) |
111 | 122 | np.savez(npz_file_name, latent)
|
112 | 123 | else:
|
113 | 124 | # remove existing flipped npz
|
114 | 125 | for image_key, _ in bucket:
|
115 |
| - npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz" |
| 126 | + npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz" |
116 | 127 | if os.path.isfile(npz_file_name):
|
117 | 128 | print(f"remove existing flipped npz / 既存のflipされたnpzファイルを削除します: {npz_file_name}")
|
118 | 129 | os.remove(npz_file_name)
|
@@ -169,9 +180,9 @@ def process_batch(is_last):
|
169 | 180 |
|
170 | 181 | # 既に存在するファイルがあればshapeを確認して同じならskipする
|
171 | 182 | if args.skip_existing:
|
172 |
| - npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False) + ".npz"] |
| 183 | + npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive) + ".npz"] |
173 | 184 | if args.flip_aug:
|
174 |
| - npz_files.append(get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz") |
| 185 | + npz_files.append(get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz") |
175 | 186 |
|
176 | 187 | found = True
|
177 | 188 | for npz_file in npz_files:
|
@@ -256,6 +267,8 @@ def setup_parser() -> argparse.ArgumentParser:
|
256 | 267 | help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する")
|
257 | 268 | parser.add_argument("--skip_existing", action="store_true",
|
258 | 269 | help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)")
|
| 270 | + parser.add_argument("--recursive", action="store_true", |
| 271 | + help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す") |
259 | 272 |
|
260 | 273 | return parser
|
261 | 274 |
|
|
0 commit comments