@@ -4,6 +4,7 @@
 import json
 import random
 
+from pathlib import Path
 from PIL import Image
 from tqdm import tqdm
 import numpy as np
@@ -13,156 +14,185 @@
 from blip.blip import blip_decoder
 import library.train_util as train_util
 
-DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 IMAGE_SIZE = 384
 
 # 正方形でいいのか? という気がするがソースがそうなので
-IMAGE_TRANSFORM = transforms.Compose([
-    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=InterpolationMode.BICUBIC),
-    transforms.ToTensor(),
-    transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
-])
+IMAGE_TRANSFORM = transforms.Compose(
+    [
+        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=InterpolationMode.BICUBIC),
+        transforms.ToTensor(),
+        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+    ]
+)
+
 
 # 共通化したいが微妙に処理が異なる……
 class ImageLoadingTransformDataset(torch.utils.data.Dataset):
-  def __init__(self, image_paths):
-    self.images = image_paths
+    def __init__(self, image_paths):
+        self.images = image_paths
 
-  def __len__(self):
-    return len(self.images)
+    def __len__(self):
+        return len(self.images)
 
-  def __getitem__(self, idx):
-    img_path = self.images[idx]
+    def __getitem__(self, idx):
+        img_path = self.images[idx]
 
-    try:
-      image = Image.open(img_path).convert("RGB")
-      # convert to tensor temporarily so dataloader will accept it
-      tensor = IMAGE_TRANSFORM(image)
-    except Exception as e:
-      print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
-      return None
+        try:
+            image = Image.open(img_path).convert("RGB")
+            # convert to tensor temporarily so dataloader will accept it
+            tensor = IMAGE_TRANSFORM(image)
+        except Exception as e:
+            print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
+            return None
 
-    return (tensor, img_path)
+        return (tensor, img_path)
 
 
 def collate_fn_remove_corrupted(batch):
-  """Collate function that allows to remove corrupted examples in the
-  dataloader. It expects that the dataloader returns 'None' when that occurs.
-  The 'None's in the batch are removed.
-  """
-  # Filter out all the Nones (corrupted examples)
-  batch = list(filter(lambda x: x is not None, batch))
-  return batch
+    """Collate function that allows to remove corrupted examples in the
+    dataloader. It expects that the dataloader returns 'None' when that occurs.
+    The 'None's in the batch are removed.
+    """
+    # Filter out all the Nones (corrupted examples)
+    batch = list(filter(lambda x: x is not None, batch))
+    return batch
 
 
 def main(args):
-  # fix the seed for reproducibility
-  seed = args.seed  # + utils.get_rank()
-  torch.manual_seed(seed)
-  np.random.seed(seed)
-  random.seed(seed)
-
-  if not os.path.exists("blip"):
-    args.train_data_dir = os.path.abspath(args.train_data_dir)  # convert to absolute path
-
-    cwd = os.getcwd()
-    print('Current Working Directory is: ', cwd)
-    os.chdir('finetune')
-
-  print(f"load images from {args.train_data_dir}")
-  image_paths = train_util.glob_images(args.train_data_dir)
-  print(f"found {len(image_paths)} images.")
-
-  print(f"loading BLIP caption: {args.caption_weights}")
-  model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit='large', med_config="./blip/med_config.json")
-  model.eval()
-  model = model.to(DEVICE)
-  print("BLIP loaded")
-
-  # captioningする
-  def run_batch(path_imgs):
-    imgs = torch.stack([im for _, im in path_imgs]).to(DEVICE)
-
-    with torch.no_grad():
-      if args.beam_search:
-        captions = model.generate(imgs, sample=False, num_beams=args.num_beams,
-                                  max_length=args.max_length, min_length=args.min_length)
-      else:
-        captions = model.generate(imgs, sample=True, top_p=args.top_p, max_length=args.max_length, min_length=args.min_length)
-
-    for (image_path, _), caption in zip(path_imgs, captions):
-      with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
-        f.write(caption + "\n")
-        if args.debug:
-          print(image_path, caption)
-
-  # 読み込みの高速化のためにDataLoaderを使うオプション
-  if args.max_data_loader_n_workers is not None:
-    dataset = ImageLoadingTransformDataset(image_paths)
-    data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
-                                       num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
-  else:
-    data = [[(None, ip)] for ip in image_paths]
-
-  b_imgs = []
-  for data_entry in tqdm(data, smoothing=0.0):
-    for data in data_entry:
-      if data is None:
-        continue
-
-      img_tensor, image_path = data
-      if img_tensor is None:
-        try:
-          raw_image = Image.open(image_path)
-          if raw_image.mode != 'RGB':
-            raw_image = raw_image.convert("RGB")
-          img_tensor = IMAGE_TRANSFORM(raw_image)
-        except Exception as e:
-          print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
-          continue
-
-      b_imgs.append((image_path, img_tensor))
-      if len(b_imgs) >= args.batch_size:
+    # fix the seed for reproducibility
+    seed = args.seed  # + utils.get_rank()
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+
+    if not os.path.exists("blip"):
+        args.train_data_dir = os.path.abspath(args.train_data_dir)  # convert to absolute path
+
+        cwd = os.getcwd()
+        print("Current Working Directory is: ", cwd)
+        os.chdir("finetune")
+
+    print(f"load images from {args.train_data_dir}")
+    train_data_dir_path = Path(args.train_data_dir)
+    image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
+    print(f"found {len(image_paths)} images.")
+
+    print(f"loading BLIP caption: {args.caption_weights}")
+    model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit="large", med_config="./blip/med_config.json")
+    model.eval()
+    model = model.to(DEVICE)
+    print("BLIP loaded")
+
+    # captioningする
+    def run_batch(path_imgs):
+        imgs = torch.stack([im for _, im in path_imgs]).to(DEVICE)
+
+        with torch.no_grad():
+            if args.beam_search:
+                captions = model.generate(
+                    imgs, sample=False, num_beams=args.num_beams, max_length=args.max_length, min_length=args.min_length
+                )
+            else:
+                captions = model.generate(
+                    imgs, sample=True, top_p=args.top_p, max_length=args.max_length, min_length=args.min_length
+                )
+
+        for (image_path, _), caption in zip(path_imgs, captions):
+            with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
+                f.write(caption + "\n")
+                if args.debug:
+                    print(image_path, caption)
+
+    # 読み込みの高速化のためにDataLoaderを使うオプション
+    if args.max_data_loader_n_workers is not None:
+        dataset = ImageLoadingTransformDataset(image_paths)
+        data = torch.utils.data.DataLoader(
+            dataset,
+            batch_size=args.batch_size,
+            shuffle=False,
+            num_workers=args.max_data_loader_n_workers,
+            collate_fn=collate_fn_remove_corrupted,
+            drop_last=False,
+        )
+    else:
+        data = [[(None, ip)] for ip in image_paths]
+
+    b_imgs = []
+    for data_entry in tqdm(data, smoothing=0.0):
+        for data in data_entry:
+            if data is None:
+                continue
+
+            img_tensor, image_path = data
+            if img_tensor is None:
+                try:
+                    raw_image = Image.open(image_path)
+                    if raw_image.mode != "RGB":
+                        raw_image = raw_image.convert("RGB")
+                    img_tensor = IMAGE_TRANSFORM(raw_image)
+                except Exception as e:
+                    print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
+                    continue
+
+            b_imgs.append((image_path, img_tensor))
+            if len(b_imgs) >= args.batch_size:
+                run_batch(b_imgs)
+                b_imgs.clear()
+    if len(b_imgs) > 0:
         run_batch(b_imgs)
-        b_imgs.clear()
-  if len(b_imgs) > 0:
-    run_batch(b_imgs)
 
-  print("done!")
+    print("done!")
 
 
 def setup_parser() -> argparse.ArgumentParser:
-  parser = argparse.ArgumentParser()
-  parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
-  parser.add_argument("--caption_weights", type=str, default="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth",
-                      help="BLIP caption weights (model_large_caption.pth) / BLIP captionの重みファイル(model_large_caption.pth)")
-  parser.add_argument("--caption_extention", type=str, default=None,
-                      help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
-  parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
-  parser.add_argument("--beam_search", action="store_true",
-                      help="use beam search (default Nucleus sampling) / beam searchを使う(このオプション未指定時はNucleus sampling)")
-  parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
-  parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
-                      help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
-  parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数(多いと精度が上がるが時間がかかる)")
-  parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p")
-  parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長")
-  parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長")
-  parser.add_argument('--seed', default=42, type=int, help='seed for reproducibility / 再現性を確保するための乱数seed')
-  parser.add_argument("--debug", action="store_true", help="debug mode")
-
-  return parser
-
-
-if __name__ == '__main__':
-  parser = setup_parser()
-
-  args = parser.parse_args()
-
-  # スペルミスしていたオプションを復元する
-  if args.caption_extention is not None:
-    args.caption_extension = args.caption_extention
-
-  main(args)
+    parser = argparse.ArgumentParser()
+    parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
+    parser.add_argument(
+        "--caption_weights",
+        type=str,
+        default="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth",
+        help="BLIP caption weights (model_large_caption.pth) / BLIP captionの重みファイル(model_large_caption.pth)",
+    )
+    parser.add_argument(
+        "--caption_extention",
+        type=str,
+        default=None,
+        help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)",
+    )
+    parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
+    parser.add_argument(
+        "--beam_search",
+        action="store_true",
+        help="use beam search (default Nucleus sampling) / beam searchを使う(このオプション未指定時はNucleus sampling)",
+    )
+    parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
+    parser.add_argument(
+        "--max_data_loader_n_workers",
+        type=int,
+        default=None,
+        help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)",
+    )
+    parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数(多いと精度が上がるが時間がかかる)")
+    parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p")
+    parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長")
+    parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長")
+    parser.add_argument("--seed", default=42, type=int, help="seed for reproducibility / 再現性を確保するための乱数seed")
+    parser.add_argument("--debug", action="store_true", help="debug mode")
+    parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する")
+
+    return parser
+
+
+if __name__ == "__main__":
+    parser = setup_parser()
+
+    args = parser.parse_args()
+
+    # スペルミスしていたオプションを復元する
+    if args.caption_extention is not None:
+        args.caption_extension = args.caption_extention
+
+    main(args)
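
The functional change in this commit (besides the black-style reformat) is that image discovery moves from train_util.glob_images(args.train_data_dir) to train_util.glob_images_pathlib(train_data_dir_path, args.recursive), wired to the new --recursive flag, so captions can also be generated for images in nested folders. As a rough, hypothetical sketch of what such a helper might do (the real implementation lives in library/train_util.py and may differ in supported extensions and ordering):

from pathlib import Path

# assumed set of image extensions; the actual library may support more
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp"]

def glob_images_pathlib(dir_path: Path, recursive: bool):
    # rglob walks subfolders recursively, glob stays at the top level
    globber = dir_path.rglob if recursive else dir_path.glob
    image_paths = []
    for ext in IMAGE_EXTENSIONS:
        image_paths += list(globber("*" + ext))
    # de-duplicate and sort for a stable processing order
    return sorted(set(image_paths))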
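
For reference, run_batch writes one caption file per image by swapping the image extension for --caption_extension (default ".caption"), next to the source image. A tiny standalone illustration with a hypothetical path:

import os

image_path = "train_data/001.png"   # hypothetical input image
caption_extension = ".caption"      # default value of --caption_extension
caption_file = os.path.splitext(image_path)[0] + caption_extension
print(caption_file)                 # train_data/001.caption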