from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from blip.blip import blip_decoder
- # from Salesforce_BLIP.models.blip import blip_decoder
+ import library.train_util as train_util

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


+ IMAGE_SIZE = 384
+
+ # A square resize seems questionable, but the original source does it this way
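+ # (the Normalize mean/std below are the CLIP statistics that BLIP's ViT expects)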
+ IMAGE_TRANSFORM = transforms.Compose([
+     transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=InterpolationMode.BICUBIC),
+     transforms.ToTensor(),
+     transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+ ])
+
+ # Would be nice to share this across scripts, but the processing differs slightly...
+ class ImageLoadingTransformDataset(torch.utils.data.Dataset):
+     def __init__(self, image_paths):
+         self.images = image_paths
+
+     def __len__(self):
+         return len(self.images)
+
+     def __getitem__(self, idx):
+         img_path = self.images[idx]
+
+         try:
+             image = Image.open(img_path).convert("RGB")
+             # convert to tensor temporarily so dataloader will accept it
+             tensor = IMAGE_TRANSFORM(image)
+         except Exception as e:
+             print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
+             return None
+
+         return (tensor, img_path)
+
+
+ def collate_fn_remove_corrupted(batch):
+     """Collate function that removes corrupted examples from the batch.
+     It expects the dataloader to return 'None' when that occurs; the
+     'None's in the batch are removed here.
+     """
+     # Filter out all the Nones (corrupted examples)
+     batch = list(filter(lambda x: x is not None, batch))
+     return batch
+
+
def main(args):
    # fix the seed for reproducibility
-     seed = args.seed  # + utils.get_rank()
+     seed = args.seed  # + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
-
+
    if not os.path.exists("blip"):
        args.train_data_dir = os.path.abspath(args.train_data_dir)  # convert to absolute path

@@ -31,24 +72,15 @@ def main(args):
        os.chdir('finetune')

    print(f"load images from {args.train_data_dir}")
-     image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.jpeg")) + \
-         glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
+     image_paths = train_util.glob_images(args.train_data_dir)
    print(f"found {len(image_paths)} images.")

    print(f"loading BLIP caption: {args.caption_weights}")
-     image_size = 384
-     model = blip_decoder(pretrained=args.caption_weights, image_size=image_size, vit='large', med_config="./blip/med_config.json")
+     model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit='large', med_config="./blip/med_config.json")
    model.eval()
    model = model.to(DEVICE)
    print("BLIP loaded")

-     # A square resize seems questionable, but the original source does it this way
-     transform = transforms.Compose([
-         transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
-         transforms.ToTensor(),
-         transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
-     ])
-
    # run captioning
    def run_batch(path_imgs):
        imgs = torch.stack([im for _, im in path_imgs]).to(DEVICE)
@@ -66,18 +98,35 @@ def run_batch(path_imgs):
                if args.debug:
                    print(image_path, caption)

+     # option to use a DataLoader to speed up image loading
+     if args.max_data_loader_n_workers is not None:
+         dataset = ImageLoadingTransformDataset(image_paths)
+         data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
+                                            num_workers=args.max_data_loader_n_workers,
+                                            collate_fn=collate_fn_remove_corrupted, drop_last=False)
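+         # the collate_fn returns a plain list of (tensor, path) tuples (no stacking),
+         # which is why each batch below is iterated item by item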
+     else:
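+         # no DataLoader: wrap each path as a one-item pseudo-batch with no pre-loaded
+         # tensor, so the loop below handles both code paths uniformly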
+         data = [[(None, ip)] for ip in image_paths]
+
    b_imgs = []
-     for image_path in tqdm(image_paths, smoothing=0.0):
-         raw_image = Image.open(image_path)
-         if raw_image.mode != "RGB":
-             print(f"convert image mode {raw_image.mode} to RGB: {image_path}")
-             raw_image = raw_image.convert("RGB")
-
-         image = transform(raw_image)
-         b_imgs.append((image_path, image))
-         if len(b_imgs) >= args.batch_size:
-             run_batch(b_imgs)
-             b_imgs.clear()
+     for data_entry in tqdm(data, smoothing=0.0):
+         for item in data_entry:
+             if item is None:
+                 continue
+
+             img_tensor, image_path = item
+             if img_tensor is None:
+                 # tensor not pre-loaded by a DataLoader; load and transform it here
+                 try:
+                     raw_image = Image.open(image_path)
+                     if raw_image.mode != 'RGB':
+                         raw_image = raw_image.convert("RGB")
+                     img_tensor = IMAGE_TRANSFORM(raw_image)
+                 except Exception as e:
+                     print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
+                     continue
+
+             b_imgs.append((image_path, img_tensor))
+             if len(b_imgs) >= args.batch_size:
+                 run_batch(b_imgs)
+                 b_imgs.clear()
    if len(b_imgs) > 0:
        run_batch(b_imgs)
@@ -95,6 +144,8 @@ def run_batch(path_imgs):
    parser.add_argument("--beam_search", action="store_true",
                        help="use beam search (default Nucleus sampling) / beam searchを使う(このオプション未指定時はNucleus sampling)")
    parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
+     parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
+                         help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
    parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数(多いと精度が上がるが時間がかかる)")
    parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p")
    parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長")