
Commit 334589a

Merge pull request #424 from kohya-ss/dev
recursive support for finetune scripts
2 parents: 6d5f847 + 43ef635

6 files changed (+953 −715 lines)

README.md (+5)

@@ -127,6 +127,11 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
 
 ## Change History
 
+### 17 Apr. 2023, 2023/4/17:
+
+- Added the `--recursive` option to each script in the `finetune` folder to process folders recursively. Please refer to [PR #400](https://github.com/kohya-ss/sd-scripts/pull/400/) for details. Thanks to Linaqruf!
+- `finetune`フォルダ内の各スクリプトに再起的にフォルダを処理するオプション`--recursive`を追加しました。詳細は [PR #400](https://github.com/kohya-ss/sd-scripts/pull/400/) を参照してください。Linaqruf 氏に感謝します。
+
 ### Naming of LoRA
 
 The LoRA supported by `train_network.py` has been named to avoid confusion. The documentation has been updated. The following are the names of LoRA types in this repository.

finetune/make_captions.py (+160 −130)
@@ -4,6 +4,7 @@
 import json
 import random
 
+from pathlib import Path
 from PIL import Image
 from tqdm import tqdm
 import numpy as np
@@ -13,156 +14,185 @@
 from blip.blip import blip_decoder
 import library.train_util as train_util
 
-DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 IMAGE_SIZE = 384
 
 # 正方形でいいのか? という気がするがソースがそうなので
-IMAGE_TRANSFORM = transforms.Compose([
-    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=InterpolationMode.BICUBIC),
-    transforms.ToTensor(),
-    transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
-])
+IMAGE_TRANSFORM = transforms.Compose(
+    [
+        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=InterpolationMode.BICUBIC),
+        transforms.ToTensor(),
+        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+    ]
+)
+
 
 # 共通化したいが微妙に処理が異なる……
 class ImageLoadingTransformDataset(torch.utils.data.Dataset):
-  def __init__(self, image_paths):
-    self.images = image_paths
+    def __init__(self, image_paths):
+        self.images = image_paths
 
-  def __len__(self):
-    return len(self.images)
+    def __len__(self):
+        return len(self.images)
 
-  def __getitem__(self, idx):
-    img_path = self.images[idx]
+    def __getitem__(self, idx):
+        img_path = self.images[idx]
 
-    try:
-      image = Image.open(img_path).convert("RGB")
-      # convert to tensor temporarily so dataloader will accept it
-      tensor = IMAGE_TRANSFORM(image)
-    except Exception as e:
-      print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
-      return None
+        try:
+            image = Image.open(img_path).convert("RGB")
+            # convert to tensor temporarily so dataloader will accept it
+            tensor = IMAGE_TRANSFORM(image)
+        except Exception as e:
+            print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
+            return None
 
-    return (tensor, img_path)
+        return (tensor, img_path)
 
 
 def collate_fn_remove_corrupted(batch):
-  """Collate function that allows to remove corrupted examples in the
-  dataloader. It expects that the dataloader returns 'None' when that occurs.
-  The 'None's in the batch are removed.
-  """
-  # Filter out all the Nones (corrupted examples)
-  batch = list(filter(lambda x: x is not None, batch))
-  return batch
+    """Collate function that allows to remove corrupted examples in the
+    dataloader. It expects that the dataloader returns 'None' when that occurs.
+    The 'None's in the batch are removed.
+    """
+    # Filter out all the Nones (corrupted examples)
+    batch = list(filter(lambda x: x is not None, batch))
+    return batch
 
 
 def main(args):
-  # fix the seed for reproducibility
-  seed = args.seed # + utils.get_rank()
-  torch.manual_seed(seed)
-  np.random.seed(seed)
-  random.seed(seed)
-
-  if not os.path.exists("blip"):
-    args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path
-
-    cwd = os.getcwd()
-    print('Current Working Directory is: ', cwd)
-    os.chdir('finetune')
-
-  print(f"load images from {args.train_data_dir}")
-  image_paths = train_util.glob_images(args.train_data_dir)
-  print(f"found {len(image_paths)} images.")
-
-  print(f"loading BLIP caption: {args.caption_weights}")
-  model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit='large', med_config="./blip/med_config.json")
-  model.eval()
-  model = model.to(DEVICE)
-  print("BLIP loaded")
-
-  # captioningする
-  def run_batch(path_imgs):
-    imgs = torch.stack([im for _, im in path_imgs]).to(DEVICE)
-
-    with torch.no_grad():
-      if args.beam_search:
-        captions = model.generate(imgs, sample=False, num_beams=args.num_beams,
-                                  max_length=args.max_length, min_length=args.min_length)
-      else:
-        captions = model.generate(imgs, sample=True, top_p=args.top_p, max_length=args.max_length, min_length=args.min_length)
-
-    for (image_path, _), caption in zip(path_imgs, captions):
-      with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
-        f.write(caption + "\n")
-        if args.debug:
-          print(image_path, caption)
-
-  # 読み込みの高速化のためにDataLoaderを使うオプション
-  if args.max_data_loader_n_workers is not None:
-    dataset = ImageLoadingTransformDataset(image_paths)
-    data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
-                                       num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
-  else:
-    data = [[(None, ip)] for ip in image_paths]
-
-  b_imgs = []
-  for data_entry in tqdm(data, smoothing=0.0):
-    for data in data_entry:
-      if data is None:
-        continue
-
-      img_tensor, image_path = data
-      if img_tensor is None:
-        try:
-          raw_image = Image.open(image_path)
-          if raw_image.mode != 'RGB':
-            raw_image = raw_image.convert("RGB")
-          img_tensor = IMAGE_TRANSFORM(raw_image)
-        except Exception as e:
-          print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
-          continue
-
-      b_imgs.append((image_path, img_tensor))
-      if len(b_imgs) >= args.batch_size:
+    # fix the seed for reproducibility
+    seed = args.seed  # + utils.get_rank()
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+
+    if not os.path.exists("blip"):
+        args.train_data_dir = os.path.abspath(args.train_data_dir)  # convert to absolute path
+
+        cwd = os.getcwd()
+        print("Current Working Directory is: ", cwd)
+        os.chdir("finetune")
+
+    print(f"load images from {args.train_data_dir}")
+    train_data_dir_path = Path(args.train_data_dir)
+    image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
+    print(f"found {len(image_paths)} images.")
+
+    print(f"loading BLIP caption: {args.caption_weights}")
+    model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit="large", med_config="./blip/med_config.json")
+    model.eval()
+    model = model.to(DEVICE)
+    print("BLIP loaded")
+
+    # captioningする
+    def run_batch(path_imgs):
+        imgs = torch.stack([im for _, im in path_imgs]).to(DEVICE)
+
+        with torch.no_grad():
+            if args.beam_search:
+                captions = model.generate(
+                    imgs, sample=False, num_beams=args.num_beams, max_length=args.max_length, min_length=args.min_length
+                )
+            else:
+                captions = model.generate(
+                    imgs, sample=True, top_p=args.top_p, max_length=args.max_length, min_length=args.min_length
+                )
+
+        for (image_path, _), caption in zip(path_imgs, captions):
+            with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
+                f.write(caption + "\n")
+                if args.debug:
+                    print(image_path, caption)
+
+    # 読み込みの高速化のためにDataLoaderを使うオプション
+    if args.max_data_loader_n_workers is not None:
+        dataset = ImageLoadingTransformDataset(image_paths)
+        data = torch.utils.data.DataLoader(
+            dataset,
+            batch_size=args.batch_size,
+            shuffle=False,
+            num_workers=args.max_data_loader_n_workers,
+            collate_fn=collate_fn_remove_corrupted,
+            drop_last=False,
+        )
+    else:
+        data = [[(None, ip)] for ip in image_paths]
+
+    b_imgs = []
+    for data_entry in tqdm(data, smoothing=0.0):
+        for data in data_entry:
+            if data is None:
+                continue
+
+            img_tensor, image_path = data
+            if img_tensor is None:
+                try:
+                    raw_image = Image.open(image_path)
+                    if raw_image.mode != "RGB":
+                        raw_image = raw_image.convert("RGB")
+                    img_tensor = IMAGE_TRANSFORM(raw_image)
+                except Exception as e:
+                    print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
+                    continue
+
+            b_imgs.append((image_path, img_tensor))
+            if len(b_imgs) >= args.batch_size:
+                run_batch(b_imgs)
+                b_imgs.clear()
+    if len(b_imgs) > 0:
         run_batch(b_imgs)
-        b_imgs.clear()
-  if len(b_imgs) > 0:
-    run_batch(b_imgs)
 
-  print("done!")
+    print("done!")
 
 
 def setup_parser() -> argparse.ArgumentParser:
-  parser = argparse.ArgumentParser()
-  parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
-  parser.add_argument("--caption_weights", type=str, default="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth",
-                      help="BLIP caption weights (model_large_caption.pth) / BLIP captionの重みファイル(model_large_caption.pth)")
-  parser.add_argument("--caption_extention", type=str, default=None,
-                      help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
-  parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
-  parser.add_argument("--beam_search", action="store_true",
-                      help="use beam search (default Nucleus sampling) / beam searchを使う(このオプション未指定時はNucleus sampling)")
-  parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
-  parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
-                      help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
-  parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数(多いと精度が上がるが時間がかかる)")
-  parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p")
-  parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長")
-  parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長")
-  parser.add_argument('--seed', default=42, type=int, help='seed for reproducibility / 再現性を確保するための乱数seed')
-  parser.add_argument("--debug", action="store_true", help="debug mode")
-
-  return parser
-
-
-if __name__ == '__main__':
-  parser = setup_parser()
-
-  args = parser.parse_args()
-
-  # スペルミスしていたオプションを復元する
-  if args.caption_extention is not None:
-    args.caption_extension = args.caption_extention
-
-  main(args)
+    parser = argparse.ArgumentParser()
+    parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
+    parser.add_argument(
+        "--caption_weights",
+        type=str,
+        default="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth",
+        help="BLIP caption weights (model_large_caption.pth) / BLIP captionの重みファイル(model_large_caption.pth)",
+    )
+    parser.add_argument(
+        "--caption_extention",
+        type=str,
+        default=None,
+        help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)",
+    )
+    parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
+    parser.add_argument(
+        "--beam_search",
+        action="store_true",
+        help="use beam search (default Nucleus sampling) / beam searchを使う(このオプション未指定時はNucleus sampling)",
+    )
+    parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
+    parser.add_argument(
+        "--max_data_loader_n_workers",
+        type=int,
+        default=None,
+        help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)",
+    )
+    parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数(多いと精度が上がるが時間がかかる)")
+    parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p")
+    parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長")
+    parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長")
+    parser.add_argument("--seed", default=42, type=int, help="seed for reproducibility / 再現性を確保するための乱数seed")
+    parser.add_argument("--debug", action="store_true", help="debug mode")
+    parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する")
+
+    return parser


+if __name__ == "__main__":
+    parser = setup_parser()
+
+    args = parser.parse_args()
+
+    # スペルミスしていたオプションを復元する
+    if args.caption_extention is not None:
+        args.caption_extension = args.caption_extention
+
+    main(args)
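The recursive behaviour in the diff above comes from `train_util.glob_images_pathlib(train_data_dir_path, args.recursive)`, which lives in `library/train_util.py` and is not part of this excerpt. For orientation only, here is a minimal sketch of what such a pathlib-based globbing helper might look like; the function body and the extension list are assumptions, not the repository's actual implementation.

```python
from pathlib import Path
from typing import List

# Hypothetical sketch -- not the actual library/train_util.py implementation.
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp"]  # assumed extension list


def glob_images_pathlib(dir_path: Path, recursive: bool) -> List[Path]:
    image_paths: List[Path] = []
    for ext in IMAGE_EXTENSIONS:
        if recursive:
            # rglob walks dir_path and every subfolder below it
            image_paths += list(dir_path.rglob("*" + ext))
        else:
            # glob only matches files directly inside dir_path
            image_paths += list(dir_path.glob("*" + ext))
    # de-duplicate and sort so caption files are written in a stable order
    return sorted(set(image_paths))
```

With `--recursive`, `make_captions.py` would then pick up images in nested dataset folders and write the `.caption` files next to them, which matches the behaviour described in the README entry above.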
