-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathstable_diffusion_tpu_req.py
132 lines (116 loc) · 4.29 KB
/
stable_diffusion_tpu_req.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import argparse
from concurrent import futures
import functools
from io import BytesIO
import numpy as np
from PIL import Image
import requests
from tqdm import tqdm
_PROMPTS = [
"Labrador in the style of Hokusai",
"Painting of a squirrel skating in New York",
"HAL-9000 in the style of Van Gogh",
"Times Square under water, with fish and a dolphin swimming around",
"Ancient Roman fresco showing a man working on his laptop",
"Armchair in the shape of an avocado",
"Clown astronaut in space, with Earth in the background",
"A cat sitting on a windowsill",
"A dog playing fetch in a park",
"A city skyline at night",
"A field of flowers in bloom",
"A tropical beach with palm trees",
"A snowy mountain range",
"A waterfall cascading into a pool",
"A forest at sunset",
"A desert landscape with cacti",
"A volcano erupting",
"A lightning storm in the distance",
"A rainbow over a rainbow",
"A unicorn grazing in a meadow",
"A dragon flying through the sky",
"A mermaid swimming in the ocean",
"A robot walking down the street",
"A UFO landing in a field",
"A portal to another dimension",
"A time traveler from the future",
"A talking cat",
"A bowl of fruit on a table",
"A group of friends laughing",
"A family sitting down for dinner",
"A couple kissing in the rain",
"A child playing with a toy",
"A musician playing an instrument",
"A painter painting a picture",
"A writer writing a book",
"A scientist conducting an experiment",
"A construction worker building a house",
"A doctor operating on a patient",
"A teacher teaching a class",
"A police officer arresting a suspect",
"A firefighter putting out a fire",
"A soldier fighting in a war",
"A farmer working in a field",
"A pilot flying a plane",
"An astronaut in space",
"A unicorn eating a rainbow"
]
def send_request_and_receive_image(prompt: str, url: str) -> BytesIO:
"""Sends a single prompt request and returns the Image."""
try:
inputs = "%20".join(prompt.split(" "))
resp = requests.get(f"{url}?prompt={inputs}")
resp.raise_for_status()
return BytesIO(resp.content)
except requests.RequestException as e:
print(f"An error occurred while sending the request: {e}")
def image_grid(imgs, rows, cols):
w, h = imgs[0].size
grid = Image.new("RGB", size=(cols * w, rows * h))
for i, img in enumerate(imgs):
grid.paste(img, box=(i % cols * w, i // cols * h))
return grid
def send_requests(num_requests: int, batch_size: int, save_pictures: bool,
url: str = "http://localhost:8000/imagine"):
"""Sends a list of requests and processes the responses."""
print("num_requests: ", num_requests)
print("batch_size: ", batch_size)
print("url: ", url)
print("save_pictures: ", save_pictures)
prompts = _PROMPTS
if num_requests > len(_PROMPTS):
# Repeat until larger than num_requests
prompts = _PROMPTS * int(np.ceil(num_requests / len(_PROMPTS)))
prompts = np.random.choice(
prompts, num_requests, replace=False)
with futures.ThreadPoolExecutor(max_workers=batch_size) as executor:
raw_images = list(
tqdm(
executor.map(
functools.partial(send_request_and_receive_image, url=url),
prompts,
),
total=len(prompts),
)
)
if save_pictures:
print("Saving pictures to diffusion_results.png")
images = [Image.open(raw_image) for raw_image in raw_images]
grid = image_grid(images, 2, num_requests // 2)
grid.save("./diffusion_results.png")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Sends requests to Diffusion.")
parser.add_argument(
"--num_requests", help="Number of requests to send.",
default=8)
parser.add_argument(
"--batch_size", help="The number of requests to send at a time.",
default=8)
parser.add_argument(
"--save_pictures", default=False, action="store_true",
help="Whether to save the generated pictures to disk.")
parser.add_argument(
"--ip", help="The IP address to send the requests to.")
args = parser.parse_args()
send_requests(
num_requests=int(args.num_requests), batch_size=int(args.batch_size),
save_pictures=bool(args.save_pictures))