Skip to content

Commit d37aa6e

Browse files
committed
v21.3.9
1 parent 1e24e29 commit d37aa6e

6 files changed

+228
-4
lines changed

.gitignore

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@ wd14_tagger_model
88
.DS_Store
99
locon
1010
gui-user.bat
11-
gui-user.ps1
11+
gui-user.ps1
12+
library/__init__.py

README.md

+2
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,8 @@ This will store your a backup file with your current locally installed pip packa
192192

193193
## Change History
194194

195+
* 2023/04/01 (v21.3.9)
196+
- Update how setup is done on Windows by introducing a setup.bat script. This will make it easier to install/re-install on Windows if needed. Many thanks to @missionfloyd for his PR: https://github.com/bmaltais/kohya_ss/pull/496
195197
* 2023/03/30 (v21.3.8)
196198
- Fix issue with LyCORIS version not being found: https://github.com/bmaltais/kohya_ss/issues/481
197199
* 2023/03/29 (v21.3.7)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
import argparse
2+
import csv
3+
import glob
4+
import os
5+
6+
from PIL import Image
7+
import cv2
8+
from tqdm import tqdm
9+
import numpy as np
10+
from tensorflow.keras.models import load_model
11+
from huggingface_hub import hf_hub_download
12+
import torch
13+
14+
# import library.train_util as train_util
15+
16+
# from wd14 tagger
17+
IMAGE_SIZE = 448
18+
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"]
19+
20+
# wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2/ wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2
21+
DEFAULT_WD14_TAGGER_REPO = 'SmilingWolf/wd-v1-4-convnext-tagger-v2'
22+
FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"]
23+
SUB_DIR = "variables"
24+
SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
25+
CSV_FILE = FILES[-1]
26+
27+
def glob_images(directory, base="*"):
28+
img_paths = []
29+
for ext in IMAGE_EXTENSIONS:
30+
if base == "*":
31+
img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
32+
else:
33+
img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
34+
img_paths = list(set(img_paths)) # 重複を排除
35+
img_paths.sort()
36+
return img_paths
37+
38+
def preprocess_image(image):
39+
image = np.array(image)
40+
image = image[:, :, ::-1] # RGB->BGR
41+
42+
# pad to square
43+
size = max(image.shape[0:2])
44+
pad_x = size - image.shape[1]
45+
pad_y = size - image.shape[0]
46+
pad_l = pad_x // 2
47+
pad_t = pad_y // 2
48+
image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode='constant', constant_values=255)
49+
50+
interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
51+
image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
52+
53+
image = image.astype(np.float32)
54+
return image
55+
56+
57+
class ImageLoadingPrepDataset(torch.utils.data.Dataset):
58+
def __init__(self, image_paths):
59+
self.images = image_paths
60+
61+
def __len__(self):
62+
return len(self.images)
63+
64+
def __getitem__(self, idx):
65+
img_path = self.images[idx]
66+
67+
try:
68+
image = Image.open(img_path).convert("RGB")
69+
image = preprocess_image(image)
70+
tensor = torch.tensor(image)
71+
except Exception as e:
72+
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
73+
return None
74+
75+
return (tensor, img_path)
76+
77+
78+
def collate_fn_remove_corrupted(batch):
79+
"""Collate function that allows to remove corrupted examples in the
80+
dataloader. It expects that the dataloader returns 'None' when that occurs.
81+
The 'None's in the batch are removed.
82+
"""
83+
# Filter out all the Nones (corrupted examples)
84+
batch = list(filter(lambda x: x is not None, batch))
85+
return batch
86+
87+
88+
def main(args):
89+
# hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする
90+
# depreacatedの警告が出るけどなくなったらその時
91+
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
92+
if not os.path.exists(args.model_dir) or args.force_download:
93+
print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
94+
for file in FILES:
95+
hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
96+
for file in SUB_DIR_FILES:
97+
hf_hub_download(args.repo_id, file, subfolder=SUB_DIR, cache_dir=os.path.join(
98+
args.model_dir, SUB_DIR), force_download=True, force_filename=file)
99+
else:
100+
print("using existing wd14 tagger model")
101+
102+
# 画像を読み込む
103+
image_paths = glob_images(args.train_data_dir)
104+
print(f"found {len(image_paths)} images.")
105+
106+
print("loading model and labels")
107+
model = load_model(args.model_dir)
108+
109+
# label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
110+
# 依存ライブラリを増やしたくないので自力で読むよ
111+
with open(os.path.join(args.model_dir, CSV_FILE), "r", encoding="utf-8") as f:
112+
reader = csv.reader(f)
113+
l = [row for row in reader]
114+
header = l[0] # tag_id,name,category,count
115+
rows = l[1:]
116+
assert header[0] == 'tag_id' and header[1] == 'name' and header[2] == 'category', f"unexpected csv format: {header}"
117+
118+
tags = [row[1] for row in rows[1:] if row[2] == '0'] # categoryが0、つまり通常のタグのみ
119+
120+
# 推論する
121+
def run_batch(path_imgs):
122+
imgs = np.array([im for _, im in path_imgs])
123+
124+
probs = model(imgs, training=False)
125+
probs = probs.numpy()
126+
127+
for (image_path, _), prob in zip(path_imgs, probs):
128+
# 最初の4つはratingなので無視する
129+
# # First 4 labels are actually ratings: pick one with argmax
130+
# ratings_names = label_names[:4]
131+
# rating_index = ratings_names["probs"].argmax()
132+
# found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]]
133+
134+
# それ以降はタグなのでconfidenceがthresholdより高いものを追加する
135+
# Everything else is tags: pick any where prediction confidence > threshold
136+
tag_text = ""
137+
for i, p in enumerate(prob[4:]): # numpyとか使うのが良いけど、まあそれほど数も多くないのでループで
138+
if p >= args.thresh and i < len(tags):
139+
tag_text += ", " + tags[i]
140+
141+
if len(tag_text) > 0:
142+
tag_text = tag_text[2:] # 最初の ", " を消す
143+
144+
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
145+
f.write(tag_text + '\n')
146+
if args.debug:
147+
print(image_path, tag_text)
148+
149+
# 読み込みの高速化のためにDataLoaderを使うオプション
150+
if args.max_data_loader_n_workers is not None:
151+
dataset = ImageLoadingPrepDataset(image_paths)
152+
data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
153+
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
154+
else:
155+
data = [[(None, ip)] for ip in image_paths]
156+
157+
b_imgs = []
158+
for data_entry in tqdm(data, smoothing=0.0):
159+
for data in data_entry:
160+
if data is None:
161+
continue
162+
163+
image, image_path = data
164+
if image is not None:
165+
image = image.detach().numpy()
166+
else:
167+
try:
168+
image = Image.open(image_path)
169+
if image.mode != 'RGB':
170+
image = image.convert("RGB")
171+
image = preprocess_image(image)
172+
except Exception as e:
173+
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
174+
continue
175+
b_imgs.append((image_path, image))
176+
177+
if len(b_imgs) >= args.batch_size:
178+
run_batch(b_imgs)
179+
b_imgs.clear()
180+
181+
if len(b_imgs) > 0:
182+
run_batch(b_imgs)
183+
184+
print("done!")
185+
186+
187+
def setup_parser() -> argparse.ArgumentParser:
188+
parser = argparse.ArgumentParser()
189+
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
190+
parser.add_argument("--repo_id", type=str, default=DEFAULT_WD14_TAGGER_REPO,
191+
help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID")
192+
parser.add_argument("--model_dir", type=str, default="wd14_tagger_model",
193+
help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ")
194+
parser.add_argument("--force_download", action='store_true',
195+
help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします")
196+
parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値")
197+
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
198+
parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
199+
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
200+
parser.add_argument("--caption_extention", type=str, default=None,
201+
help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
202+
parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子")
203+
parser.add_argument("--debug", action="store_true", help="debug mode")
204+
205+
return parser
206+
207+
208+
if __name__ == '__main__':
209+
parser = setup_parser()
210+
211+
args = parser.parse_args()
212+
213+
# スペルミスしていたオプションを復元する
214+
if args.caption_extention is not None:
215+
args.caption_extension = args.caption_extention
216+
217+
main(args)

library/__init__.py

Whitespace-only changes.

library/wd14_caption_gui.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def caption_images(
3333
return
3434

3535
print(f'Captioning files in {train_data_dir}...')
36-
run_cmd = f'accelerate launch "./finetune/tag_images_by_wd14_tagger.py"'
36+
run_cmd = f'accelerate launch "./finetune/tag_images_by_wd14_tagger_bmaltais.py"'
3737
run_cmd += f' --batch_size="{int(batch_size)}"'
3838
run_cmd += f' --thresh="{thresh}"'
3939
run_cmd += f' --caption_extension="{caption_extension}"'

setup.bat

+6-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
@echo off
2-
python -m venv venv
2+
IF NOT EXIST venv (
3+
python -m venv venv
4+
) ELSE (
5+
echo venv folder already exists, skipping creation...
6+
)
37
call .\venv\Scripts\activate.bat
48

59
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
@@ -10,4 +14,4 @@ copy /y .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
1014
copy /y .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
1115
copy /y .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py
1216

13-
accelerate config
17+
accelerate config

0 commit comments

Comments
 (0)