test_voodoo.py
import time

import PIL.Image

from nataili.inference.compvis.img2img import img2img
from nataili.inference.compvis.txt2img import txt2img
from nataili.model_manager import ModelManager
from nataili.util.cache import torch_gc
from nataili.util.logger import logger
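
# Local image used for the img2img test; assumes ./01.png exists next to this script.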
init_image = "./01.png"

mm = ModelManager()
mm.init()

logger.debug("Available dependencies:")
for dependency in mm.available_dependencies:
    logger.debug(dependency)

logger.debug("Available models:")
for model in mm.available_models:
    logger.debug(model)

models_to_load = [
    # 'stable_diffusion',
    # 'waifu_diffusion',
    "trinart",
    # 'GFPGAN', 'RealESRGAN_x4plus', 'RealESRGAN_x4plus_anime_6B',
    # 'BLIP', 'ViT-L/14', 'ViT-g-14', 'ViT-H-14'
]
logger.init(f"{models_to_load}", status="Loading")
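

# test() first loads the safety_checker the normal way, then loads each model in
# models_to_load with use_voodoo=True, runs quick txt2img, NSFW-filter and
# img2img smoke tests on the diffusion models, and unloads each model afterwards.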
@logger.catch
def test():
    tic = time.time()
    model = "safety_checker"
    logger.init(f"Model: {model}", status="Loading")
    success = mm.load_model(model)
    toc = time.time()
    logger.init_ok(f"Loading {model}: Took {toc-tic} seconds", status=success)

    for model in models_to_load:
        torch_gc()
        tic = time.time()
        logger.init(f"Model: {model}", status="Loading")
        success = mm.load_model(model, use_voodoo=True)
        toc = time.time()
        logger.init_ok(f"Loading {model}: Took {toc-tic} seconds", status=success)
        torch_gc()
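
        # For the diffusion models, run a short end-to-end inference pass.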
        if model in ["stable_diffusion", "waifu_diffusion", "trinart"]:
            logger.debug(f"Running inference on {model}")

            logger.info('Testing txt2img with prompt "colossal corgi"')
            t2i = txt2img(
                mm.loaded_models[model]["model"],
                mm.loaded_models[model]["device"],
                "test_output",
                use_voodoo=True,
            )
            t2i.generate("colossal corgi")
            torch_gc()

            logger.info('Testing nsfw filter with prompt "boobs"')
            t2i = txt2img(
                mm.loaded_models[model]["model"],
                mm.loaded_models[model]["device"],
                "test_output",
                filter_nsfw=True,
                safety_checker=mm.loaded_models["safety_checker"]["model"],
                use_voodoo=True,
            )
            t2i.generate("boobs")
            torch_gc()

            logger.info('Testing img2img with prompt "cute anime girl"')
            i2i = img2img(
                mm.loaded_models[model]["model"],
                mm.loaded_models[model]["device"],
                "test_output",
                use_voodoo=True,
            )
            init_img = PIL.Image.open(init_image)
            i2i.generate("cute anime girl", init_img)
            torch_gc()

        logger.init_ok(f"Model {model}", status="Unloading")
        mm.unload_model(model)
        torch_gc()
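
    # Interactive loop: manually load additional models by name; type "exit" to quit.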
    while True:
        print("Enter model name to load:")
        print(mm.available_models)
        model = input()
        if model == "exit":
            break
        print(f"Loading {model}")
        success = mm.load_model(model)
        print(f"Loading {model} successful: {success}")
        print("")


if __name__ == "__main__":
    test()
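
# To run this test directly (assuming nataili and its model weights are set up
# and ./01.png is present):
#     python test_voodoo.py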