Commit 422debd

DanbooruTagger: add wd tagger support
1 parent 2a6908c commit 422debd

File tree: 7 files changed (+305, -109 lines)

DanbooruTagger/DanbooruTagger.py

Lines changed: 65 additions & 22 deletions
@@ -1,5 +1,5 @@
-import warnings
 from deepdanbooru_onnx import DeepDanbooru
+from wd_onnx import Wd
 from PIL import Image
 import argparse
 import cv2
@@ -13,13 +13,20 @@
 
 
 def find_image_files(path: str) -> list[str]:
-    paths = list()
-    for root, dirs, files in os.walk(path):
-        for filename in files:
-            name, extension = os.path.splitext(filename)
-            if extension.lower() in image_ext_ocv:
-                paths.append(os.path.join(root, filename))
-    return paths
+    if os.path.isdir(path):
+        paths = list()
+        for root, dirs, files in os.walk(path):
+            for filename in files:
+                name, extension = os.path.splitext(filename)
+                if extension.lower() in image_ext_ocv:
+                    paths.append(os.path.join(root, filename))
+        return paths
+    else:
+        name, extension = os.path.splitext(path)
+        if extension.lower() in image_ext_ocv:
+            return [path]
+        else:
+            return []
 
 
 def image_loader(paths: list[str]):
@@ -35,14 +42,28 @@ def image_loader(paths: list[str]):
         yield image_pil, path
 
 
-def pipeline(queue: Queue, image_paths: list[str], device: int):
-    danbooru = DeepDanbooru()
+def danbooru_pipeline(queue: Queue, image_paths: list[str], device: int, cpu: bool):
+    danbooru = DeepDanbooru("cpu" if cpu else "auto")
 
     for path in image_paths:
        imageprompt = ""
        tags = danbooru(path)
        for tag in tags:
            imageprompt = imageprompt + ", " + tag
+        imageprompt = imageprompt[2:]
+
+        queue.put({"file_name": path, "text": imageprompt})
+
+
+def wd_pipeline(queue: Queue, image_paths: list[str], device: int, cpu: bool):
+    wd = Wd("cpu" if cpu else "auto", threshold=0.3)
+
+    for path in image_paths:
+        imageprompt = ""
+        tags = wd(path)
+        for tag in tags:
+            imageprompt = imageprompt + ", " + tag
+        imageprompt = imageprompt[2:]
 
        queue.put({"file_name": path, "text": imageprompt})
 
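Both pipelines build the prompt string by prepending ", " to every tag and then slicing the leading separator off with imageprompt[2:]. A standalone sketch of that pattern (the tag values are illustrative), with the equivalent str.join() shown for comparison:

    # Pattern used by danbooru_pipeline/wd_pipeline above: accumulate
    # ", " + tag, then drop the leading ", " with a slice.
    tags = ["1girl", "solo", "long_hair"]  # illustrative tag list
    imageprompt = ""
    for tag in tags:
        imageprompt = imageprompt + ", " + tag
    imageprompt = imageprompt[2:]

    # Same result, no slice needed:
    assert imageprompt == ", ".join(tags)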

@@ -57,49 +78,71 @@ def split_list(input_list, count):
 def save_meta(meta_file, meta, reldir, common_description):
     meta["file_name"] = os.path.relpath(meta["file_name"], reldir)
     if common_description is not None:
-        meta["text"] = common_description + meta["text"]
+        meta["text"] = common_description + ", " + meta["text"]
     meta_file.write(json.dumps(meta) + '\n')
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser("A script to tag images via DeepDanbooru")
-    parser.add_argument('--batch', '-b', default=4, type=int, help="Batch size to use for inference")
     parser.add_argument('--common_description', '-c', help="An optional description that will be preended to the ai generated one")
-    parser.add_argument('--image_dir', '-i', help="A directory containg the images to tag")
+    parser.add_argument('--image_dir', '-i', help="A directory containg the images to tag or a singular image to tag")
+    parser.add_argument('--wd', '-w', action="store_true", help="use wd tagger instead of DeepDanbooru")
+    parser.add_argument('--cpu', action="store_true", help="force cpu usge instead of gpu")
     args = parser.parse_args()
 
-    nparalell = 2
-
     image_paths = find_image_files(args.image_dir)
+
+    if len(image_paths) == 0:
+        print("Unable to find any images at {args.image_dir}")
+        exit(1)
+
+    nparalell = 4 if len(image_paths) > 4 else len(image_paths)
     image_path_chunks = list(split_list(image_paths, nparalell))
 
-    print(f"Will use {nparalell} processies to create tags")
+    print(f"Will use {nparalell} processies to create tags for {len(image_paths)} images")
 
     queue = Queue()
+    pipe = danbooru_pipeline if not args.wd else wd_pipeline
     processies = list()
     for i in range(0, nparalell):
-        processies.append(Process(target=pipeline, args=(queue, image_path_chunks[i], i)))
+        processies.append(Process(target=pipe, args=(queue, image_path_chunks[i], i, args.cpu)))
         processies[-1].start()
 
     progress = tqdm(desc="Generateing tags", total=len(image_paths))
     exit = False
-    with open(os.path.join(args.image_dir, "metadata.jsonl"), mode='w') as output_file:
+
+    if len(image_paths) > 1:
+        with open(os.path.join(args.image_dir, "metadata.jsonl"), mode='w') as output_file:
+            while not exit:
+                if not queue.empty():
+                    meta = queue.get()
+                    save_meta(output_file, meta, args.image_dir, args.common_description)
+                    progress.update()
+                exit = True
+                for process in processies:
+                    if process.is_alive():
+                        exit = False
+                        break
+
+            while not queue.empty():
+                meta = queue.get()
+                save_meta(output_file, meta, args.image_dir, args.common_description)
+                progress.update()
+    else:
         while not exit:
             if not queue.empty():
                 meta = queue.get()
-                save_meta(output_file, meta, args.image_dir, args.common_description)
+                print(meta)
                 progress.update()
             exit = True
             for process in processies:
                 if process.is_alive():
                     exit = False
                     break
-
         while not queue.empty():
             meta = queue.get()
-            save_meta(output_file, meta, args.image_dir, args.common_description)
+            print(meta)
             progress.update()
 
     for process in processies:
         process.join()
-
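The reworked main block picks a tagger at runtime (pipe = danbooru_pipeline if not args.wd else wd_pipeline), caps the worker count at min(4, number of images), writes metadata.jsonl when more than one image was found, and just prints each result dict for a single image. Worth noting: the "Unable to find any images" message is a plain string, not an f-string, so {args.image_dir} prints literally. The polling/drain structure reduces to the following self-contained sketch (worker names and payloads are illustrative, not from the commit):

    # Workers push result dicts into a Queue; the parent polls until every
    # worker has exited, then performs a final drain so items enqueued
    # between the last empty() check and worker exit are not lost.
    from multiprocessing import Process, Queue

    def worker(queue: Queue, paths: list[str]) -> None:
        for path in paths:  # stands in for danbooru_pipeline/wd_pipeline
            queue.put({"file_name": path, "text": "tag_a, tag_b"})

    if __name__ == "__main__":
        queue = Queue()
        chunks = [["a.png"], ["b.png", "c.png"]]  # illustrative path chunks
        workers = [Process(target=worker, args=(queue, chunk)) for chunk in chunks]
        for w in workers:
            w.start()

        finished = False
        while not finished:
            while not queue.empty():
                print(queue.get())
            finished = all(not w.is_alive() for w in workers)

        while not queue.empty():  # final drain
            print(queue.get())
        for w in workers:
            w.join()

A typical invocation under the new flags would be along the lines of python DanbooruTagger.py -i ./dataset --wd --cpu -c "anime artwork" (path and description are placeholders).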

DanbooruTagger/deepdanbooru_onnx/deepdanbooru_onnx.py renamed to DanbooruTagger/deepdanbooru_onnx.py

Lines changed: 11 additions & 84 deletions
@@ -2,13 +2,12 @@
 from PIL import Image
 import numpy as np
 import os
-from tqdm import tqdm
-import requests
 import hashlib
 from typing import List, Union
-import shutil
 from pathlib import Path
 
+from utils import download
+
 
 def process_image(image: Image.Image) -> np.ndarray:
     """
@@ -18,38 +17,9 @@ def process_image(image: Image.Image) -> np.ndarray:
     """
 
     image = image.convert("RGB").resize((512, 512))
-    image = np.array(image).astype(np.float32) / 255
-    image = image.transpose((2, 0, 1)).reshape(1, 3, 512, 512).transpose((0, 2, 3, 1))
-    return image
-
-
-def download(url: str, save_path: str, md5: str, length: str) -> bool:
-    """
-    Download a file from url to save_path.
-    If the file already exists, check its md5.
-    If the md5 matches, return True,if the md5 doesn't match, return False.
-    :param url: the url of the file to download
-    :param save_path: the path to save the file
-    :param md5: the md5 of the file
-    :param length: the length of the file
-    :return: True if the file is downloaded successfully, False otherwise
-    """
-
-    try:
-        response = requests.get(url=url, stream=True)
-        with open(save_path, "wb") as f:
-            with tqdm.wrapattr(
-                response.raw, "read", total=length, desc="Downloading"
-            ) as r_raw:
-                shutil.copyfileobj(r_raw, f)
-        return (
-            True
-            if hashlib.md5(open(save_path, "rb").read()).hexdigest() == md5
-            else False
-        )
-    except Exception as e:
-        print(e)
-        return False
+    imagenp = np.array(image).astype(np.float32) / 255
+    imagenp = imagenp.transpose((2, 0, 1)).reshape(1, 3, 512, 512).transpose((0, 2, 3, 1))
+    return imagenp
 
 
 def download_model():
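The preprocessing itself is unchanged apart from the rename to imagenp: normalize to [0, 1] float32, then a transpose/reshape round trip that lands back in NHWC layout. A standalone check (a minimal sketch; the input size is arbitrary):

    import numpy as np
    from PIL import Image

    image = Image.new("RGB", (123, 77))  # illustrative input
    image = image.convert("RGB").resize((512, 512))
    imagenp = np.array(image).astype(np.float32) / 255
    imagenp = imagenp.transpose((2, 0, 1)).reshape(1, 3, 512, 512).transpose((0, 2, 3, 1))
    assert imagenp.shape == (1, 512, 512, 3)  # NHWC, batch of one
    assert imagenp.dtype == np.float32
    # Net effect: equivalent to just adding a batch axis (arr[None])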
@@ -109,7 +79,7 @@ def __init__(
     ):
         """
         Initialize the DeepDanbooru class.
-        :param mode: the mode of the model, "cpu" or "gpu" or "auto"
+        :param mode: the mode of the model, "cpu", "cuda", "hip" or "auto"
         :param model_path: the path to the model file
         :param tags_path: the path to the tags file
         :param threshold: the threshold of the model
@@ -119,11 +89,13 @@ def __init__(
 
         providers = {
             "cpu": "CPUExecutionProvider",
-            "gpu": "CUDAExecutionProvider",
+            "cuda": "CUDAExecutionProvider",
+            "hip": "ROCMExecutionProvider",
             "tensorrt": "TensorrtExecutionProvider",
             "auto": (
                 "CUDAExecutionProvider"
                 if "CUDAExecutionProvider" in ort.get_available_providers()
+                else "ROCMExecutionProvider" if "ROCMExecutionProvider" in ort.get_available_providers()
                 else "CPUExecutionProvider"
             ),
         }
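With this change, "auto" prefers CUDA, then ROCm, then CPU, and the explicit "cuda"/"hip" keys replace the old "gpu". The selection logic in isolation (a hedged sketch; pick_provider is an illustrative name, while get_available_providers() is the real onnxruntime call):

    import onnxruntime as ort

    def pick_provider(mode: str = "auto") -> str:
        available = ort.get_available_providers()
        if mode == "auto":
            # Same fallback order as the "auto" entry above: CUDA, ROCm, CPU.
            for candidate in ("CUDAExecutionProvider", "ROCMExecutionProvider"):
                if candidate in available:
                    return candidate
            return "CPUExecutionProvider"
        return {
            "cpu": "CPUExecutionProvider",
            "cuda": "CUDAExecutionProvider",
            "hip": "ROCMExecutionProvider",
            "tensorrt": "TensorrtExecutionProvider",
        }[mode]

    print(pick_provider())  # "CPUExecutionProvider" on a machine with no GPU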
@@ -166,8 +138,8 @@ def __repr__(self) -> str:
         return self.__str__()
 
     def from_image_inference(self, image: Image.Image) -> dict:
-        image = process_image(image)
-        return self.predict(image)
+        imagenp = process_image(image)
+        return self.predict(imagenp)
 
     def from_ndarray_inferece(self, image: np.ndarray) -> dict:
         if image.shape != (1, 512, 512, 3):
@@ -177,49 +149,6 @@ def from_ndarray_inferece(self, image: np.ndarray) -> dict:
     def from_file_inference(self, image: str) -> dict:
         return self.from_image_inference(Image.open(image))
 
-    def from_list_inference(self, image: Union[list, tuple]) -> List[dict]:
-        if self.pin_memory:
-            image = [process_image(Image.open(i)) for i in image]
-        for i in [
-            image[i : i + self.batch_size]
-            for i in range(0, len(image), self.batch_size)
-        ]:
-            imagelist = i
-            bs = len(i)
-            _imagelist, idx, hashlist = [], [], []
-            for j in range(len(i)):
-                img = Image.open(i[j]) if not self.pin_memory else imagelist[j]
-                image_hash = hashlib.md5(np.array(img).astype(np.uint8)).hexdigest()
-                hashlist.append(image_hash)
-                if image_hash in self.cache:
-                    continue
-                if not self.pin_memory:
-                    _imagelist.append(process_image(img))
-                else:
-                    _imagelist.append(imagelist[j])
-                idx.append(j)
-
-            imagelist = _imagelist
-            if len(imagelist) != 0:
-                _image = np.vstack(imagelist)
-                results = self.inference(_image)
-                results_idx = 0
-            else:
-                results = []
-
-            for i in range(bs):
-                image_tag = {}
-                if i in idx:
-                    hash = hashlist[i]
-                    for tag, score in zip(self.tags, results[results_idx]):
-                        if score >= self.threshold:
-                            image_tag[tag] = score
-                    results_idx += 1
-                    self.cache[hash] = image_tag
-                    yield image_tag
-                else:
-                    yield self.cache[hashlist[i]]
-
     def inference(self, image):
         return self.session.run(self.output_name, {self.input_name: image})[0]
 
@@ -236,8 +165,6 @@ def __call__(self, image) -> Union[dict, List[dict]]:
             return self.from_file_inference(image)
         elif isinstance(image, np.ndarray):
             return self.from_ndarray_inferece(image)
-        elif isinstance(image, list) or isinstance(image, tuple):
-            return self.from_list_inference(image)
         elif isinstance(image, Image.Image):
            return self.from_image_inference(image)
         else:

DanbooruTagger/deepdanbooru_onnx/__init__.py

Lines changed: 0 additions & 3 deletions
This file was deleted.
Two binary files changed (contents not shown).

DanbooruTagger/utils.py

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+import requests
+import shutil
+import hashlib
+from tqdm import tqdm
+
+
+def download(url: str, save_path: str, md5: str, length: str) -> bool:
+    """
+    Download a file from url to save_path.
+    If the file already exists, check its md5.
+    If the md5 matches, return True,if the md5 doesn't match, return False.
+    :param url: the url of the file to download
+    :param save_path: the path to save the file
+    :param md5: the md5 of the file
+    :param length: the length of the file
+    :return: True if the file is downloaded successfully, False otherwise
+    """
+
+    try:
+        response = requests.get(url=url, stream=True)
+        with open(save_path, "wb") as f:
+            with tqdm.wrapattr(
+                response.raw, "read", total=length, desc="Downloading"
+            ) as r_raw:
+                shutil.copyfileobj(r_raw, f)
+        return (
+            True
+            if hashlib.md5(open(save_path, "rb").read()).hexdigest() == md5
+            else False
+        )
+    except Exception as e:
+        print(e)
+        return False
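The helper is the code removed verbatim from deepdanbooru_onnx.py. Note that the length parameter is annotated str but ends up as tqdm's total, so passing the byte count as a number is what makes the progress bar meaningful. A hedged usage sketch (URL, checksum and size are placeholders, not a real model artifact):

    from utils import download

    ok = download(
        url="https://example.com/model.onnx",
        save_path="model.onnx",
        md5="0123456789abcdef0123456789abcdef",
        length=123456789,  # total size in bytes, despite the str annotation
    )
    print("downloaded and verified" if ok else "failed or md5 mismatch")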
