 from PIL import Image
 import numpy as np
 import os
-from tqdm import tqdm
-import requests
 import hashlib
 from typing import List, Union
-import shutil
 from pathlib import Path
 
+from utils import download
+
 
 def process_image(image: Image.Image) -> np.ndarray:
     """
@@ -18,38 +17,9 @@ def process_image(image: Image.Image) -> np.ndarray:
     """
 
     image = image.convert("RGB").resize((512, 512))
-    image = np.array(image).astype(np.float32) / 255
-    image = image.transpose((2, 0, 1)).reshape(1, 3, 512, 512).transpose((0, 2, 3, 1))
-    return image
-
-
-def download(url: str, save_path: str, md5: str, length: str) -> bool:
-    """
-    Download a file from url to save_path.
-    If the file already exists, check its md5.
-    If the md5 matches, return True,if the md5 doesn't match, return False.
-    :param url: the url of the file to download
-    :param save_path: the path to save the file
-    :param md5: the md5 of the file
-    :param length: the length of the file
-    :return: True if the file is downloaded successfully, False otherwise
-    """
-
-    try:
-        response = requests.get(url=url, stream=True)
-        with open(save_path, "wb") as f:
-            with tqdm.wrapattr(
-                response.raw, "read", total=length, desc="Downloading"
-            ) as r_raw:
-                shutil.copyfileobj(r_raw, f)
-        return (
-            True
-            if hashlib.md5(open(save_path, "rb").read()).hexdigest() == md5
-            else False
-        )
-    except Exception as e:
-        print(e)
-        return False
+    imagenp = np.array(image).astype(np.float32) / 255
+    imagenp = imagenp.transpose((2, 0, 1)).reshape(1, 3, 512, 512).transpose((0, 2, 3, 1))
+    return imagenp
 
 
 def download_model():
@@ -109,7 +79,7 @@ def __init__(
     ):
         """
         Initialize the DeepDanbooru class.
-        :param mode: the mode of the model, "cpu" or "gpu" or "auto"
+        :param mode: the mode of the model, "cpu", "cuda", "hip" or "auto"
         :param model_path: the path to the model file
         :param tags_path: the path to the tags file
         :param threshold: the threshold of the model
@@ -119,11 +89,13 @@ def __init__(
 
         providers = {
             "cpu": "CPUExecutionProvider",
-            "gpu": "CUDAExecutionProvider",
+            "cuda": "CUDAExecutionProvider",
+            "hip": "ROCMExecutionProvider",
             "tensorrt": "TensorrtExecutionProvider",
             "auto": (
                 "CUDAExecutionProvider"
                 if "CUDAExecutionProvider" in ort.get_available_providers()
+                else "ROCMExecutionProvider" if "ROCMExecutionProvider" in ort.get_available_providers()
                 else "CPUExecutionProvider"
             ),
         }
@@ -166,8 +138,8 @@ def __repr__(self) -> str:
         return self.__str__()
 
     def from_image_inference(self, image: Image.Image) -> dict:
-        image = process_image(image)
-        return self.predict(image)
+        imagenp = process_image(image)
+        return self.predict(imagenp)
 
     def from_ndarray_inferece(self, image: np.ndarray) -> dict:
         if image.shape != (1, 512, 512, 3):
@@ -177,49 +149,6 @@ def from_ndarray_inferece(self, image: np.ndarray) -> dict:
     def from_file_inference(self, image: str) -> dict:
         return self.from_image_inference(Image.open(image))
 
-    def from_list_inference(self, image: Union[list, tuple]) -> List[dict]:
-        if self.pin_memory:
-            image = [process_image(Image.open(i)) for i in image]
-        for i in [
-            image[i : i + self.batch_size]
-            for i in range(0, len(image), self.batch_size)
-        ]:
-            imagelist = i
-            bs = len(i)
-            _imagelist, idx, hashlist = [], [], []
-            for j in range(len(i)):
-                img = Image.open(i[j]) if not self.pin_memory else imagelist[j]
-                image_hash = hashlib.md5(np.array(img).astype(np.uint8)).hexdigest()
-                hashlist.append(image_hash)
-                if image_hash in self.cache:
-                    continue
-                if not self.pin_memory:
-                    _imagelist.append(process_image(img))
-                else:
-                    _imagelist.append(imagelist[j])
-                idx.append(j)
-
-            imagelist = _imagelist
-            if len(imagelist) != 0:
-                _image = np.vstack(imagelist)
-                results = self.inference(_image)
-                results_idx = 0
-            else:
-                results = []
-
-            for i in range(bs):
-                image_tag = {}
-                if i in idx:
-                    hash = hashlist[i]
-                    for tag, score in zip(self.tags, results[results_idx]):
-                        if score >= self.threshold:
-                            image_tag[tag] = score
-                    results_idx += 1
-                    self.cache[hash] = image_tag
-                    yield image_tag
-                else:
-                    yield self.cache[hashlist[i]]
-
     def inference(self, image):
         return self.session.run(self.output_name, {self.input_name: image})[0]
 
@@ -236,8 +165,6 @@ def __call__(self, image) -> Union[dict, List[dict]]:
             return self.from_file_inference(image)
         elif isinstance(image, np.ndarray):
             return self.from_ndarray_inferece(image)
-        elif isinstance(image, list) or isinstance(image, tuple):
-            return self.from_list_inference(image)
         elif isinstance(image, Image.Image):
             return self.from_image_inference(image)
         else:
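
Note: a minimal usage sketch of the class after this change, assuming the module is importable as `deepdanbooru_onnx` and that `example.jpg` stands in for a real image path (both are illustrative assumptions, not part of this diff):

```python
# Hypothetical usage sketch; the module name and image path are assumptions.
from deepdanbooru_onnx import DeepDanbooru

# "auto" now prefers CUDA, then falls back to ROCm, then CPU;
# "cuda" and "hip" replace the old "gpu" mode.
danbooru = DeepDanbooru(mode="auto")

# __call__ still dispatches on the input type (file path, ndarray, or PIL image);
# list/tuple batch inference was removed in this change.
tags = danbooru("example.jpg")
print(tags)
```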