|
| 1 | +import argparse |
| 2 | +import csv |
| 3 | +import glob |
| 4 | +import os |
| 5 | + |
| 6 | +from PIL import Image |
| 7 | +import cv2 |
| 8 | +from tqdm import tqdm |
| 9 | +import numpy as np |
| 10 | +from tensorflow.keras.models import load_model |
| 11 | +from huggingface_hub import hf_hub_download |
| 12 | +import torch |
| 13 | + |
| 14 | +# import library.train_util as train_util |
| 15 | + |
| 16 | +# from wd14 tagger |
| 17 | +IMAGE_SIZE = 448 |
| 18 | +IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"] |
| 19 | + |
| 20 | +# wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2/ wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2 |
| 21 | +DEFAULT_WD14_TAGGER_REPO = 'SmilingWolf/wd-v1-4-convnext-tagger-v2' |
| 22 | +FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"] |
| 23 | +SUB_DIR = "variables" |
| 24 | +SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"] |
| 25 | +CSV_FILE = FILES[-1] |
| 26 | + |
| 27 | +def glob_images(directory, base="*"): |
| 28 | + img_paths = [] |
| 29 | + for ext in IMAGE_EXTENSIONS: |
| 30 | + if base == "*": |
| 31 | + img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext))) |
| 32 | + else: |
| 33 | + img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext)))) |
| 34 | + img_paths = list(set(img_paths)) # 重複を排除 |
| 35 | + img_paths.sort() |
| 36 | + return img_paths |
| 37 | + |
| 38 | +def preprocess_image(image): |
| 39 | + image = np.array(image) |
| 40 | + image = image[:, :, ::-1] # RGB->BGR |
| 41 | + |
| 42 | + # pad to square |
| 43 | + size = max(image.shape[0:2]) |
| 44 | + pad_x = size - image.shape[1] |
| 45 | + pad_y = size - image.shape[0] |
| 46 | + pad_l = pad_x // 2 |
| 47 | + pad_t = pad_y // 2 |
| 48 | + image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode='constant', constant_values=255) |
| 49 | + |
| 50 | + interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4 |
| 51 | + image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp) |
| 52 | + |
| 53 | + image = image.astype(np.float32) |
| 54 | + return image |
| 55 | + |
| 56 | + |
| 57 | +class ImageLoadingPrepDataset(torch.utils.data.Dataset): |
| 58 | + def __init__(self, image_paths): |
| 59 | + self.images = image_paths |
| 60 | + |
| 61 | + def __len__(self): |
| 62 | + return len(self.images) |
| 63 | + |
| 64 | + def __getitem__(self, idx): |
| 65 | + img_path = self.images[idx] |
| 66 | + |
| 67 | + try: |
| 68 | + image = Image.open(img_path).convert("RGB") |
| 69 | + image = preprocess_image(image) |
| 70 | + tensor = torch.tensor(image) |
| 71 | + except Exception as e: |
| 72 | + print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") |
| 73 | + return None |
| 74 | + |
| 75 | + return (tensor, img_path) |
| 76 | + |
| 77 | + |
| 78 | +def collate_fn_remove_corrupted(batch): |
| 79 | + """Collate function that allows to remove corrupted examples in the |
| 80 | + dataloader. It expects that the dataloader returns 'None' when that occurs. |
| 81 | + The 'None's in the batch are removed. |
| 82 | + """ |
| 83 | + # Filter out all the Nones (corrupted examples) |
| 84 | + batch = list(filter(lambda x: x is not None, batch)) |
| 85 | + return batch |
| 86 | + |
| 87 | + |
| 88 | +def main(args): |
| 89 | + # hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする |
| 90 | + # depreacatedの警告が出るけどなくなったらその時 |
| 91 | + # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22 |
| 92 | + if not os.path.exists(args.model_dir) or args.force_download: |
| 93 | + print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}") |
| 94 | + for file in FILES: |
| 95 | + hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file) |
| 96 | + for file in SUB_DIR_FILES: |
| 97 | + hf_hub_download(args.repo_id, file, subfolder=SUB_DIR, cache_dir=os.path.join( |
| 98 | + args.model_dir, SUB_DIR), force_download=True, force_filename=file) |
| 99 | + else: |
| 100 | + print("using existing wd14 tagger model") |
| 101 | + |
| 102 | + # 画像を読み込む |
| 103 | + image_paths = glob_images(args.train_data_dir) |
| 104 | + print(f"found {len(image_paths)} images.") |
| 105 | + |
| 106 | + print("loading model and labels") |
| 107 | + model = load_model(args.model_dir) |
| 108 | + |
| 109 | + # label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv") |
| 110 | + # 依存ライブラリを増やしたくないので自力で読むよ |
| 111 | + with open(os.path.join(args.model_dir, CSV_FILE), "r", encoding="utf-8") as f: |
| 112 | + reader = csv.reader(f) |
| 113 | + l = [row for row in reader] |
| 114 | + header = l[0] # tag_id,name,category,count |
| 115 | + rows = l[1:] |
| 116 | + assert header[0] == 'tag_id' and header[1] == 'name' and header[2] == 'category', f"unexpected csv format: {header}" |
| 117 | + |
| 118 | + tags = [row[1] for row in rows[1:] if row[2] == '0'] # categoryが0、つまり通常のタグのみ |
| 119 | + |
| 120 | + # 推論する |
| 121 | + def run_batch(path_imgs): |
| 122 | + imgs = np.array([im for _, im in path_imgs]) |
| 123 | + |
| 124 | + probs = model(imgs, training=False) |
| 125 | + probs = probs.numpy() |
| 126 | + |
| 127 | + for (image_path, _), prob in zip(path_imgs, probs): |
| 128 | + # 最初の4つはratingなので無視する |
| 129 | + # # First 4 labels are actually ratings: pick one with argmax |
| 130 | + # ratings_names = label_names[:4] |
| 131 | + # rating_index = ratings_names["probs"].argmax() |
| 132 | + # found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]] |
| 133 | + |
| 134 | + # それ以降はタグなのでconfidenceがthresholdより高いものを追加する |
| 135 | + # Everything else is tags: pick any where prediction confidence > threshold |
| 136 | + tag_text = "" |
| 137 | + for i, p in enumerate(prob[4:]): # numpyとか使うのが良いけど、まあそれほど数も多くないのでループで |
| 138 | + if p >= args.thresh and i < len(tags): |
| 139 | + tag_text += ", " + tags[i] |
| 140 | + |
| 141 | + if len(tag_text) > 0: |
| 142 | + tag_text = tag_text[2:] # 最初の ", " を消す |
| 143 | + |
| 144 | + with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f: |
| 145 | + f.write(tag_text + '\n') |
| 146 | + if args.debug: |
| 147 | + print(image_path, tag_text) |
| 148 | + |
| 149 | + # 読み込みの高速化のためにDataLoaderを使うオプション |
| 150 | + if args.max_data_loader_n_workers is not None: |
| 151 | + dataset = ImageLoadingPrepDataset(image_paths) |
| 152 | + data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, |
| 153 | + num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False) |
| 154 | + else: |
| 155 | + data = [[(None, ip)] for ip in image_paths] |
| 156 | + |
| 157 | + b_imgs = [] |
| 158 | + for data_entry in tqdm(data, smoothing=0.0): |
| 159 | + for data in data_entry: |
| 160 | + if data is None: |
| 161 | + continue |
| 162 | + |
| 163 | + image, image_path = data |
| 164 | + if image is not None: |
| 165 | + image = image.detach().numpy() |
| 166 | + else: |
| 167 | + try: |
| 168 | + image = Image.open(image_path) |
| 169 | + if image.mode != 'RGB': |
| 170 | + image = image.convert("RGB") |
| 171 | + image = preprocess_image(image) |
| 172 | + except Exception as e: |
| 173 | + print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") |
| 174 | + continue |
| 175 | + b_imgs.append((image_path, image)) |
| 176 | + |
| 177 | + if len(b_imgs) >= args.batch_size: |
| 178 | + run_batch(b_imgs) |
| 179 | + b_imgs.clear() |
| 180 | + |
| 181 | + if len(b_imgs) > 0: |
| 182 | + run_batch(b_imgs) |
| 183 | + |
| 184 | + print("done!") |
| 185 | + |
| 186 | + |
| 187 | +def setup_parser() -> argparse.ArgumentParser: |
| 188 | + parser = argparse.ArgumentParser() |
| 189 | + parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") |
| 190 | + parser.add_argument("--repo_id", type=str, default=DEFAULT_WD14_TAGGER_REPO, |
| 191 | + help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID") |
| 192 | + parser.add_argument("--model_dir", type=str, default="wd14_tagger_model", |
| 193 | + help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ") |
| 194 | + parser.add_argument("--force_download", action='store_true', |
| 195 | + help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします") |
| 196 | + parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値") |
| 197 | + parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") |
| 198 | + parser.add_argument("--max_data_loader_n_workers", type=int, default=None, |
| 199 | + help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)") |
| 200 | + parser.add_argument("--caption_extention", type=str, default=None, |
| 201 | + help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)") |
| 202 | + parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子") |
| 203 | + parser.add_argument("--debug", action="store_true", help="debug mode") |
| 204 | + |
| 205 | + return parser |
| 206 | + |
| 207 | + |
| 208 | +if __name__ == '__main__': |
| 209 | + parser = setup_parser() |
| 210 | + |
| 211 | + args = parser.parse_args() |
| 212 | + |
| 213 | + # スペルミスしていたオプションを復元する |
| 214 | + if args.caption_extention is not None: |
| 215 | + args.caption_extension = args.caption_extention |
| 216 | + |
| 217 | + main(args) |
0 commit comments