Skip to content

Commit 01ebfc4

Browse files
authored
Merge pull request #400 from Linaqruf/main
Recursive support for captioning/tagging scripts
2 parents 6d5f847 + d5263d4 commit 01ebfc4

5 files changed

+228
-178
lines changed

finetune/make_captions.py

+5 -2
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
import json
55
import random
66

7+
from pathlib import Path
78
from PIL import Image
89
from tqdm import tqdm
910
import numpy as np
@@ -72,7 +73,8 @@ def main(args):
7273
os.chdir('finetune')
7374

7475
print(f"load images from {args.train_data_dir}")
75-
image_paths = train_util.glob_images(args.train_data_dir)
76+
train_data_dir_path = Path(args.train_data_dir)
77+
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
7678
print(f"found {len(image_paths)} images.")
7779

7880
print(f"loading BLIP caption: {args.caption_weights}")
@@ -152,7 +154,8 @@ def setup_parser() -> argparse.ArgumentParser:
152154
parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長")
153155
parser.add_argument('--seed', default=42, type=int, help='seed for reproducibility / 再現性を確保するための乱数seed')
154156
parser.add_argument("--debug", action="store_true", help="debug mode")
155-
157+
parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively")
158+
156159
return parser
157160

158161

finetune/make_captions_by_git.py

+5 -2
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
import os
33
import re
44

5+
from pathlib import Path
56
from PIL import Image
67
from tqdm import tqdm
78
import torch
@@ -65,7 +66,8 @@ def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs)
6566
GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
6667

6768
print(f"load images from {args.train_data_dir}")
68-
image_paths = train_util.glob_images(args.train_data_dir)
69+
train_data_dir_path = Path(args.train_data_dir)
70+
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
6971
print(f"found {len(image_paths)} images.")
7072

7173
# できればcacheに依存せず明示的にダウンロードしたい
@@ -140,7 +142,8 @@ def setup_parser() -> argparse.ArgumentParser:
140142
parser.add_argument("--remove_words", action="store_true",
141143
help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する")
142144
parser.add_argument("--debug", action="store_true", help="debug mode")
143-
145+
parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively")
146+
144147
return parser
145148

146149

finetune/prepare_buckets_latents.py

+21 -8
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,8 @@
22
import os
33
import json
44

5+
from pathlib import Path
6+
from typing import List
57
from tqdm import tqdm
68
import numpy as np
79
from PIL import Image
@@ -41,22 +43,31 @@ def get_latents(vae, images, weight_dtype):
4143
return latents
4244

4345

44-
def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip):
46+
def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip, recursive):
4547
if is_full_path:
4648
base_name = os.path.splitext(os.path.basename(image_key))[0]
49+
relative_path = os.path.relpath(os.path.dirname(image_key), data_dir)
4750
else:
4851
base_name = image_key
52+
relative_path = ""
53+
4954
if flip:
5055
base_name += '_flip'
51-
return os.path.join(data_dir, base_name)
56+
57+
if recursive and relative_path:
58+
return os.path.join(data_dir, relative_path, base_name)
59+
else:
60+
return os.path.join(data_dir, base_name)
61+
5262

5363

5464
def main(args):
5565
# assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります"
5666
if args.bucket_reso_steps % 8 > 0:
5767
print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります")
5868

59-
image_paths = train_util.glob_images(args.train_data_dir)
69+
train_data_dir_path = Path(args.train_data_dir)
70+
image_paths: List[str] = [str(p) for p in train_util.glob_images_pathlib(train_data_dir_path, args.recursive)]
6071
print(f"found {len(image_paths)} images.")
6172

6273
if os.path.exists(args.in_json):
@@ -99,20 +110,20 @@ def process_batch(is_last):
99110
f"latent shape {latents.shape}, {bucket[0][1].shape}"
100111

101112
for (image_key, _), latent in zip(bucket, latents):
102-
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False)
113+
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive)
103114
np.savez(npz_file_name, latent)
104115

105116
# flip
106117
if args.flip_aug:
107118
latents = get_latents(vae, [img[:, ::-1].copy() for _, img in bucket], weight_dtype) # copyがないとTensor変換できない
108119

109120
for (image_key, _), latent in zip(bucket, latents):
110-
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True)
121+
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive)
111122
np.savez(npz_file_name, latent)
112123
else:
113124
# remove existing flipped npz
114125
for image_key, _ in bucket:
115-
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz"
126+
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz"
116127
if os.path.isfile(npz_file_name):
117128
print(f"remove existing flipped npz / 既存のflipされたnpzファイルを削除します: {npz_file_name}")
118129
os.remove(npz_file_name)
@@ -169,9 +180,9 @@ def process_batch(is_last):
169180

170181
# 既に存在するファイルがあればshapeを確認して同じならskipする
171182
if args.skip_existing:
172-
npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False) + ".npz"]
183+
npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive) + ".npz"]
173184
if args.flip_aug:
174-
npz_files.append(get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz")
185+
npz_files.append(get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz")
175186

176187
found = True
177188
for npz_file in npz_files:
@@ -256,6 +267,8 @@ def setup_parser() -> argparse.ArgumentParser:
256267
help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する")
257268
parser.add_argument("--skip_existing", action="store_true",
258269
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)")
270+
parser.add_argument("--recursive", action="store_true",
271+
help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
259272

260273
return parser
261274

0 commit comments

Comments
 (0)