
Commit 57d8483

add GIT captioning, refactoring, DataLoader
1 parent 8c3a52e

9 files changed: +481 −146 lines

.gitignore
+2 −1

@@ -3,4 +3,5 @@ __pycache__
 wd14_tagger_model
 venv
 *.egg-info
-build
+build
+.vscode

finetune/make_captions.py
+76 −25

@@ -11,18 +11,59 @@
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
 from blip.blip import blip_decoder
-# from Salesforce_BLIP.models.blip import blip_decoder
+import library.train_util as train_util

 DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


+IMAGE_SIZE = 384
+
+# is a square really right? it feels questionable, but the original source does it this way
+IMAGE_TRANSFORM = transforms.Compose([
+    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=InterpolationMode.BICUBIC),
+    transforms.ToTensor(),
+    transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+])
+
+# would like to share this across scripts, but the processing differs slightly...
+class ImageLoadingTransformDataset(torch.utils.data.Dataset):
+    def __init__(self, image_paths):
+        self.images = image_paths
+
+    def __len__(self):
+        return len(self.images)
+
+    def __getitem__(self, idx):
+        img_path = self.images[idx]
+
+        try:
+            image = Image.open(img_path).convert("RGB")
+            # convert to tensor temporarily so dataloader will accept it
+            tensor = IMAGE_TRANSFORM(image)
+        except Exception as e:
+            print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
+            return None
+
+        return (tensor, img_path)
+
+
+def collate_fn_remove_corrupted(batch):
+    """Collate function that allows to remove corrupted examples in the
+    dataloader. It expects that the dataloader returns 'None' when that occurs.
+    The 'None's in the batch are removed.
+    """
+    # Filter out all the Nones (corrupted examples)
+    batch = list(filter(lambda x: x is not None, batch))
+    return batch
+

 def main(args):
     # fix the seed for reproducibility
-    seed = args.seed # + utils.get_rank()
+    seed = args.seed  # + utils.get_rank()
     torch.manual_seed(seed)
     np.random.seed(seed)
     random.seed(seed)
-
+
     if not os.path.exists("blip"):
         args.train_data_dir = os.path.abspath(args.train_data_dir)  # convert to absolute path

@@ -31,24 +72,15 @@ def main(args):
         os.chdir('finetune')

     print(f"load images from {args.train_data_dir}")
-    image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.jpeg")) + \
-        glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
+    image_paths = train_util.glob_images(args.train_data_dir)
     print(f"found {len(image_paths)} images.")

     print(f"loading BLIP caption: {args.caption_weights}")
-    image_size = 384
-    model = blip_decoder(pretrained=args.caption_weights, image_size=image_size, vit='large', med_config="./blip/med_config.json")
+    model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit='large', med_config="./blip/med_config.json")
     model.eval()
     model = model.to(DEVICE)
     print("BLIP loaded")

-    # is a square really right? it feels questionable, but the original source does it this way
-    transform = transforms.Compose([
-        transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
-        transforms.ToTensor(),
-        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
-    ])
-
     # run captioning
     def run_batch(path_imgs):
         imgs = torch.stack([im for _, im in path_imgs]).to(DEVICE)

@@ -66,18 +98,35 @@ def run_batch(path_imgs):
             if args.debug:
                 print(image_path, caption)

+    # option: use a DataLoader to speed up image loading
+    if args.max_data_loader_n_workers is not None:
+        dataset = ImageLoadingTransformDataset(image_paths)
+        data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
+                                           num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
+    else:
+        data = [[(None, ip)] for ip in image_paths]
+
     b_imgs = []
-    for image_path in tqdm(image_paths, smoothing=0.0):
-        raw_image = Image.open(image_path)
-        if raw_image.mode != "RGB":
-            print(f"convert image mode {raw_image.mode} to RGB: {image_path}")
-            raw_image = raw_image.convert("RGB")
-
-        image = transform(raw_image)
-        b_imgs.append((image_path, image))
-        if len(b_imgs) >= args.batch_size:
-            run_batch(b_imgs)
-            b_imgs.clear()
+    for data_entry in tqdm(data, smoothing=0.0):
+        for data in data_entry:
+            if data is None:
+                continue
+
+            img_tensor, image_path = data
+            if img_tensor is None:
+                try:
+                    raw_image = Image.open(image_path)
+                    if raw_image.mode != 'RGB':
+                        raw_image = raw_image.convert("RGB")
+                    img_tensor = IMAGE_TRANSFORM(raw_image)
+                except Exception as e:
+                    print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
+                    continue
+
+            b_imgs.append((image_path, img_tensor))
+            if len(b_imgs) >= args.batch_size:
+                run_batch(b_imgs)
+                b_imgs.clear()
     if len(b_imgs) > 0:
         run_batch(b_imgs)

@@ -95,6 +144,8 @@ def run_batch(path_imgs):
     parser.add_argument("--beam_search", action="store_true",
                         help="use beam search (default Nucleus sampling) / beam searchを使う(このオプション未指定時はNucleus sampling)")
     parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
+    parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
+                        help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
     parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数(多いと精度が上がるが時間がかかる)")
     parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p")
     parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長")

finetune/make_captions_by_git.py
+136 −0

@@ -0,0 +1,136 @@
+import argparse
+import os
+import re
+
+from PIL import Image
+from tqdm import tqdm
+import torch
+from transformers import AutoProcessor, AutoModelForCausalLM
+from transformers.generation.utils import GenerationMixin
+
+import library.train_util as train_util
+
+
+DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+PATTERN_REPLACE = [re.compile(r'with the (words?|letters?) (" ?[^"]*"|\w+)( on (the)? ?\w+)?'),
+                   re.compile(r'that says (" ?[^"]*"|\w+)')]
+
+
+# strip "with the word xxxx" phrases, which are overwhelmingly false positives
+def remove_words(captions, debug):
+    removed_caps = []
+    for caption in captions:
+        cap = caption
+        for pat in PATTERN_REPLACE:
+            cap = pat.sub("", cap)
+        if debug and cap != caption:
+            print(caption)
+            print(cap)
+        removed_caps.append(cap)
+    return removed_caps
+
+
+def collate_fn_remove_corrupted(batch):
+    """Collate function that allows to remove corrupted examples in the
+    dataloader. It expects that the dataloader returns 'None' when that occurs.
+    The 'None's in the batch are removed.
+    """
+    # Filter out all the Nones (corrupted examples)
+    batch = list(filter(lambda x: x is not None, batch))
+    return batch
+
+
+def main(args):
+    # patch GIT so it also works with a batch size larger than 1: for transformers 4.26.0
+    org_prepare_input_ids_for_generation = GenerationMixin._prepare_input_ids_for_generation
+    curr_batch_size = [args.batch_size]  # a list so it can be swapped when the last batch is smaller than batch_size
+
+    # input_ids must have exactly batch_size rows; the batch size is not visible from
+    # inside this function, so it is passed in from outside
+    # patching anywhere higher up the call chain would be much more painful
+    def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs):
+        input_ids = org_prepare_input_ids_for_generation(self, bos_token_id, encoder_outputs)
+        if input_ids.size()[0] != curr_batch_size[0]:
+            input_ids = input_ids.repeat(curr_batch_size[0], 1)
+        return input_ids
+    GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
+
+    print(f"load images from {args.train_data_dir}")
+    image_paths = train_util.glob_images(args.train_data_dir)
+    print(f"found {len(image_paths)} images.")
+
+    # ideally the model would be downloaded explicitly instead of relying on the cache
+    print(f"loading GIT: {args.model_id}")
+    git_processor = AutoProcessor.from_pretrained(args.model_id)
+    git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE)
+    print("GIT loaded")
+
+    # run captioning
+    def run_batch(path_imgs):
+        imgs = [im for _, im in path_imgs]
+
+        curr_batch_size[0] = len(path_imgs)
+        inputs = git_processor(images=imgs, return_tensors="pt").to(DEVICE)  # images are in PIL format
+        generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length)
+        captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True)
+
+        if args.remove_words:
+            captions = remove_words(captions, args.debug)
+
+        for (image_path, _), caption in zip(path_imgs, captions):
+            with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
+                f.write(caption + "\n")
+            if args.debug:
+                print(image_path, caption)
+
+    # option: use a DataLoader to speed up image loading
+    if args.max_data_loader_n_workers is not None:
+        dataset = train_util.ImageLoadingDataset(image_paths)
+        data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
+                                           num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
+    else:
+        data = [[(None, ip)] for ip in image_paths]
+
+    b_imgs = []
+    for data_entry in tqdm(data, smoothing=0.0):
+        for data in data_entry:
+            if data is None:
+                continue
+
+            image, image_path = data
+            if image is None:
+                try:
+                    image = Image.open(image_path)
+                    if image.mode != 'RGB':
+                        image = image.convert("RGB")
+                except Exception as e:
+                    print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
+                    continue
+
+            b_imgs.append((image_path, image))
+            if len(b_imgs) >= args.batch_size:
+                run_batch(b_imgs)
+                b_imgs.clear()
+
+    if len(b_imgs) > 0:
+        run_batch(b_imgs)
+
+    print("done!")
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
+    parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
+    parser.add_argument("--model_id", type=str, default="microsoft/git-large-textcaps",
+                        help="model id for GIT in Hugging Face / 使用するGITのHugging FaceのモデルID")
+    parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
+    parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
+                        help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
+    parser.add_argument("--max_length", type=int, default=50, help="max length of caption / captionの最大長")
+    parser.add_argument("--remove_words", action="store_true",
+                        help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する")
+    parser.add_argument("--debug", action="store_true", help="debug mode")
+
+    args = parser.parse_args()
+    main(args)
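
The --remove_words cleanup is a pure regex pass over the decoded captions, so it can be sanity-checked in isolation. A short sketch using the same two patterns (the sample captions below are invented, not actual model output):

import re

PATTERN_REPLACE = [re.compile(r'with the (words?|letters?) (" ?[^"]*"|\w+)( on (the)? ?\w+)?'),
                   re.compile(r'that says (" ?[^"]*"|\w+)')]

def remove_words(captions):
    removed = []
    for caption in captions:
        cap = caption
        for pat in PATTERN_REPLACE:
            cap = pat.sub("", cap)  # apply each pattern to the running result
        removed.append(cap)
    return removed

print(remove_words(['a sign with the words "open late"',
                    'a t-shirt that says hello']))
# -> ['a sign ', 'a t-shirt ']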

finetune/merge_captions_to_metadata.py
+17 −19

@@ -1,39 +1,35 @@
-# This script is licensed under Apache License 2.0
-# (c) 2022 Kohya S. @kohya_ss
-
 import argparse
-import glob
-import os
 import json
-
+from pathlib import Path
+from typing import List
 from tqdm import tqdm
+import library.train_util as train_util


 def main(args):
-    image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.jpeg")) + \
-        glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
+    assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
+
+    train_data_dir_path = Path(args.train_data_dir)
+    image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
     print(f"found {len(image_paths)} images.")

-    if args.in_json is None and os.path.isfile(args.out_json):
+    if args.in_json is None and Path(args.out_json).is_file():
         args.in_json = args.out_json

     if args.in_json is not None:
         print(f"loading existing metadata: {args.in_json}")
-        with open(args.in_json, "rt", encoding='utf-8') as f:
-            metadata = json.load(f)
+        metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
         print("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます")
     else:
         print("new metadata will be created / 新しいメタデータファイルが作成されます")
         metadata = {}

     print("merge caption texts to metadata json.")
     for image_path in tqdm(image_paths):
-        caption_path = os.path.splitext(image_path)[0] + args.caption_extension
-        with open(caption_path, "rt", encoding='utf-8') as f:
-            lines = f.readlines()
-        caption = lines[0].strip() if len(lines) > 0 else ""
+        caption_path = image_path.with_suffix(args.caption_extension)
+        caption = caption_path.read_text(encoding='utf-8').strip()

-        image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
+        image_key = str(image_path) if args.full_path else image_path.stem
         if image_key not in metadata:
             metadata[image_key] = {}

@@ -43,21 +39,23 @@ def main(args):

     # write out the metadata and we're done
     print(f"writing metadata: {args.out_json}")
-    with open(args.out_json, "wt", encoding='utf-8') as f:
-        json.dump(metadata, f, indent=2)
+    Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
     print("done!")


 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
     parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
-    parser.add_argument("--in_json", type=str, help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)")
+    parser.add_argument("--in_json", type=str,
+                        help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)")
     parser.add_argument("--caption_extention", type=str, default=None,
                         help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
     parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子")
     parser.add_argument("--full_path", action="store_true",
                         help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
+    parser.add_argument("--recursive", action="store_true",
+                        help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
     parser.add_argument("--debug", action="store_true", help="debug mode")

     args = parser.parse_args()
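
train_util.glob_images_pathlib is called here but its implementation is not part of this diff. A plausible sketch, inferred from the os.path/glob code it replaces (the real function in library/train_util.py may differ in details such as the extension list or ordering):

from pathlib import Path
from typing import List

IMAGE_EXTENSIONS = [".jpg", ".jpeg", ".png", ".webp"]  # assumed from the old glob calls

def glob_images_pathlib(dir_path: Path, recursive: bool) -> List[Path]:
    # rglob descends into child folders, matching the new --recursive option
    glob = dir_path.rglob if recursive else dir_path.glob
    image_paths: List[Path] = []
    for ext in IMAGE_EXTENSIONS:
        image_paths += list(glob("*" + ext))
    return sorted(set(image_paths))  # deterministic order, no duplicates

Sorting in the sketch keeps repeated runs deterministic; whether the actual helper sorts is not shown in this commit.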
