-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathclip_server.py
209 lines (185 loc) · 7.75 KB
/
clip_server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
import asyncio
import collections
import io
import json
import queue
import sys
import threading
import time
import traceback

import aiohttp
import msgpack
import open_clip
import torch
import torchvision.transforms.transforms as transforms
from aiohttp import web
from PIL import Image
from prometheus_client import CONTENT_TYPE_LATEST, Counter, Histogram, REGISTRY, generate_latest
# Server configuration: a JSON file whose path is the first command-line argument.
with open(sys.argv[1], "r") as config_file:
    CONFIG = json.load(config_file)

device = torch.device(CONFIG["device"])
# open_clip.list_pretrained() yields (architecture, tag) pairs; dict() over them keeps only the
# LAST tag listed for each architecture. An explicit "pretrained" config key (optional) lets the
# operator choose a specific tag instead of relying on that listing order.
PRETRAINED = CONFIG.get("pretrained") or dict(open_clip.list_pretrained())[CONFIG["model"]]
model, _, preprocess = open_clip.create_model_and_transforms(
    CONFIG["model"],
    device=device,
    pretrained=PRETRAINED,
    precision="fp16",
)
model.eval()
tokenizer = open_clip.get_tokenizer(CONFIG["model"])
print("Model loaded")

BS = CONFIG["max_batch_size"]  # maximum number of texts or images accepted per request
MODELNAME = CONFIG["model_name"]  # label value used on the Prometheus metrics
# Maps a compiled batch size to a callable that embeds exactly that many images with an
# AITemplate-compiled visual encoder. Stays empty unless "aitemplate_image_models" is configured,
# in which case do_inference routes image batches through these instead of model.encode_image.
fast_image_fns = {}
# ugly hack, sorry
if CONFIG.get("aitemplate_image_models"):
    from aitemplate.compiler import Model
    from aitemplate.testing import detect_target

    USE_CUDA = detect_target().name() == "cuda"
    state = model.state_dict()
    # Patch-embedding conv weights permuted (0, 2, 3, 1) to a channels-last layout, matching the
    # (0, 2, 3, 1) permutation applied to the input batch in wrapper() below
    # (presumably NCHW -> NHWC for AIT; confirm against the compiled graph).
    conv_weights = state["visual.trunk.patch_embed.proj.weight"].permute((0, 2, 3, 1)).contiguous().cuda().half()

    # Ordered (old, new) substring renames from open_clip visual parameter names to the AIT
    # model's constant names. Order matters: e.g. "trunk.attn_pool.latent" must be rewritten
    # before the more general "trunk.attn_pool", and "pool.norm" only exists after that step.
    _AIT_KEY_RENAMES = (
        ("trunk.patch_embed", "patch_embed"),
        ("trunk.blocks", "encoder.layers"),
        (".attn.", ".mha."),
        (".norm1.", ".ln1."),
        (".norm2.", ".ln2."),
        ("trunk.pos_embed", "pos_emb_pos_emb"),
        ("trunk.norm.", "encoder.ln."),
        ("trunk.attn_pool.latent", "pool.probe"),
        ("trunk.attn_pool", "pool"),
        ("pool.norm", "pool.ln"),
    )

    def load_pretrained():
        """Build the AIT constant dict (name -> CUDA tensor) from the visual part of the state dict.

        Only "visual."-prefixed parameters are used. The patch-embedding weight is replaced by
        the pre-permuted ``conv_weights``. AIT constant names use "_" where the state dict uses ".".
        """
        params = {}
        for key, value in state.items():
            if not key.startswith("visual."):
                continue
            key = key.removeprefix("visual.")
            for old, new in _AIT_KEY_RENAMES:
                key = key.replace(old, new)
            if "patch_embed.proj.weight" not in key:
                params[key.replace(".", "_")] = value.cuda()
        params["patch_embed_proj_weight"] = conv_weights
        return params

    def generate_wrapper(path):
        """Load the compiled AIT model at ``path``, bind weights, and return a batch-embedding callable."""
        ait_model = Model(path)
        ait_model.set_many_constants_with_tensors(load_pretrained())
        ait_model.fold_constants(sync=True)
        # The maximum output shapes do not change between calls (presumably fixed when the model
        # was compiled), so look them up once instead of on every batch.
        output_shapes = [
            ait_model.get_output_maximum_shape(i)
            for i in range(len(ait_model.get_output_name_to_index_map()))
        ]

        def wrapper(batch):
            xs = [batch.permute((0, 2, 3, 1)).contiguous()]
            # Allocate fp16 outputs directly on the GPU (previously fp32 on CPU, then copied and cast).
            # Buffers are fresh per call because the returned value is a view into ys[0].
            ys = [torch.empty(shape, dtype=torch.half, device="cuda") for shape in output_shapes]
            ait_model.run_with_tensors(xs, ys)
            # Token 0 of the first output, for every batch row.
            return ys[0][:, 0, :]

        return wrapper

    for batch_size, path in CONFIG["aitemplate_image_models"]:
        fast_image_fns[batch_size] = generate_wrapper(path)
        print("loaded", batch_size, path)
# One unit of work passed from the HTTP handler to the preprocessing thread and then to the
# inference thread: exactly one of text/images is used, and callback(ok, payload) reports back.
InferenceParameters = collections.namedtuple("InferenceParameters", "text images callback")

# Prometheus metrics on the default registry (registration order kept as-is).
items_ctr = Counter(
    "modelserver_total_items",
    "Items run through model server",
    labelnames=("model", "modality"),
)
inference_time_hist = Histogram(
    "modelserver_inftime",
    "Time running inference",
    labelnames=("model", "batch_size"),
)
batch_count_ctr = Counter(
    "modelserver_batchcount",
    "Inference batches run",
    labelnames=("model",),
)

# Gradients are never needed: this process only runs forward passes.
torch.set_grad_enabled(False)
def _encode_images_fast(images):
    """Embed ``images`` with the compiled AITemplate encoders in ``fast_image_fns``.

    The batch is consumed greedily: each step uses the largest compiled batch size that still
    fits in the remaining rows. If no compiled size fits the tail (e.g. only size 2 is compiled
    and one image remains), the tail is run through ``model.encode_image`` instead of crashing
    on ``max()`` of an empty sequence. Returns an unnormalized tensor of shape
    (batch, embedding_size) created with torch's default dtype/device.
    """
    batch = images.shape[0]
    features = torch.zeros((batch, model.text.text_projection.out_features))
    progress = 0
    while progress < batch:
        remaining = batch - progress
        fitting = [size for size in fast_image_fns if size <= remaining]
        if not fitting:
            features[progress:] = model.encode_image(images[progress:])
            break
        size = max(fitting)
        features[progress:progress + size] = fast_image_fns[size](images[progress:progress + size])
        progress += size
    return features


def do_inference(params: InferenceParameters):
    """Run one preprocessed batch through the model and report the outcome via its callback.

    ``params.text`` (token tensor) takes precedence over ``params.images`` (image tensor).
    Embeddings are L2-normalized along the last dimension and converted to a NumPy array.
    The callback receives ``(True, features)`` on success or ``(False, message)`` if any
    exception was raised. The CUDA cache is emptied after every batch.
    """
    with torch.no_grad():
        try:
            text, images, callback = params
            if text is not None:
                items_ctr.labels(MODELNAME, "text").inc(text.shape[0])
                with inference_time_hist.labels(MODELNAME + "-text", text.shape[0]).time():
                    features = model.encode_text(text)
                    features /= features.norm(dim=-1, keepdim=True)
                    features = features.cpu().numpy()
            elif images is not None:
                with inference_time_hist.labels(MODELNAME + "-image", images.shape[0]).time():
                    items_ctr.labels(MODELNAME, "image").inc(images.shape[0])
                    if fast_image_fns:
                        features = _encode_images_fast(images)
                    else:
                        features = model.encode_image(images)
                    features /= features.norm(dim=-1, keepdim=True)
                    features = features.cpu().numpy()
            else:
                # Previously fell through to an UnboundLocalError on `features`.
                raise ValueError("images or text required")
            batch_count_ctr.labels(MODELNAME).inc()
            callback(True, features)
        except Exception as e:
            traceback.print_exc()
            callback(False, str(e))
        finally:
            torch.cuda.empty_cache()
# Preprocessed jobs waiting for the model. Bounded at 10 entries, so the preprocessing thread's
# blocking put() stalls when inference falls behind.
iq = queue.Queue(10)


def infer_thread():
    """Worker loop: take preprocessed jobs off ``iq`` forever and run them one at a time."""
    while True:
        job = iq.get()
        do_inference(job)
# Raw requests waiting for tokenization / image preprocessing. Bounded at 10; the HTTP handler
# enqueues with put_nowait, so a full queue is reported to the client immediately.
pq = queue.Queue(10)


def preprocessing_thread():
    """Worker loop: turn raw requests from ``pq`` into device tensors and forward them to ``iq``.

    Text (a list of strings) takes precedence over images (a list of encoded image bytes).
    Validation errors and decode failures are reported through the request's callback as
    ``(False, message)``. Validation raises ``ValueError`` rather than using ``assert``, so the
    checks still apply when Python runs with ``-O``.
    """
    while True:
        text, images, callback = pq.get()
        try:
            if text:
                if len(text) > BS:
                    raise ValueError(f"max batch size is {BS}")
                text = tokenizer(text).to(device)
                # Drop the unused modality so do_inference sees only the one being processed.
                images = None
            elif images:
                if len(images) > BS:
                    raise ValueError(f"max batch size is {BS}")
                images = torch.stack([preprocess(Image.open(io.BytesIO(im))).half() for im in images]).to(device)
                # An empty-but-present text list (e.g. []) must not reach do_inference, whose
                # `text is not None` check would otherwise pick the text branch and crash.
                text = None
            else:
                raise ValueError("images or text required")
            iq.put(InferenceParameters(text, images, callback))
        except Exception as e:
            traceback.print_exc()
            callback(False, str(e))
# Accept request bodies up to 64 MiB (== 2**26 bytes).
app = web.Application(client_max_size=64 * 1024 * 1024)
routes = web.RouteTableDef()
@routes.post("/")
async def run_inference(request):
    """Embed a batch of texts or images sent as a msgpack map with "text" and/or "images" keys.

    On success, responds 200 with a msgpack list of float16 byte strings, one per input row.
    On failure, responds with a msgpack error string and one of these statuses:
    400 if the body is not a msgpack map, 503 if the preprocessing queue is full, and
    500 if preprocessing or inference failed.
    """
    loop = asyncio.get_running_loop()
    data = msgpack.loads(await request.read())
    if not isinstance(data, dict):
        return web.Response(body=msgpack.dumps("request body must be a msgpack map"), status=400, content_type="application/msgpack")
    event = asyncio.Event()
    results = None

    def callback(*argv):
        # Called from a worker thread: record (ok, payload), then wake this handler on the
        # event loop's own thread (asyncio.Event is not thread-safe).
        nonlocal results
        results = argv
        loop.call_soon_threadsafe(event.set)

    try:
        pq.put_nowait(InferenceParameters(data.get("text"), data.get("images"), callback))
    except queue.Full:
        # Previously escaped as an unhandled exception (opaque 500); report overload explicitly.
        return web.Response(body=msgpack.dumps("server busy"), status=503, content_type="application/msgpack")
    await event.wait()
    ok, body_data = results
    if ok:
        status = 200
        body_data = [row.astype("float16").tobytes() for row in body_data]
    else:
        status = 500
        print(body_data)
    return web.Response(body=msgpack.dumps(body_data), status=status, content_type="application/msgpack")
@routes.get("/config")
async def config(request):
    """Describe the served model as a msgpack map: architecture, max batch, input size, embedding width."""
    resize_steps = [step for step in preprocess.transforms if isinstance(step, transforms.Resize)]
    description = {
        "model": CONFIG["model"],
        "batch": BS,
        # Target size of the first Resize step in the image preprocessing pipeline.
        "image_size": resize_steps[0].size,
        "embedding_size": model.text.text_projection.out_features,
    }
    return web.Response(body=msgpack.dumps(description), status=200, content_type="application/msgpack")
@routes.get("/")
async def health(request):
    """Liveness probe: always answers 204 No Content with an empty body."""
    no_content = 204
    return web.Response(status=no_content)
@routes.get("/metrics")
async def metrics(request):
    """Serve every metric in the default registry in the Prometheus text exposition format.

    The Content-Type is set explicitly. Without it, aiohttp labels the bytes body as
    application/octet-stream, which newer Prometheus versions (3.x) refuse to scrape unless a
    fallback scrape protocol is configured. It is passed via ``headers`` because aiohttp's
    ``content_type`` argument rejects values that contain a charset parameter.
    """
    return web.Response(body=generate_latest(REGISTRY), headers={"Content-Type": CONTENT_TYPE_LATEST})
# Register every handler collected in `routes` on the application (delegates to app.router).
app.add_routes(routes)
async def run_webserver():
    """Start the aiohttp app on all interfaces at CONFIG["port"], then announce readiness.

    Returns once the listening socket is bound; the caller keeps the loop running.
    """
    runner = web.AppRunner(app)
    await runner.setup()
    site = web.TCPSite(runner, "", CONFIG["port"])
    await site.start()
    # Print only after start() succeeds, so "Ready" is never shown when binding the port fails.
    print("Ready")
# Entry point: start the two worker threads, then serve HTTP on a fresh event loop until interrupted.
try:
    # Daemon threads: both loop forever on a blocking queue.get(), so non-daemon threads would
    # keep the interpreter alive after sys.exit() (or after a startup failure such as the port
    # already being in use), leaving the process hung.
    th = threading.Thread(target=infer_thread, daemon=True)
    th.start()
    th = threading.Thread(target=preprocessing_thread, daemon=True)
    th.start()
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(run_webserver())
    loop.run_forever()
except KeyboardInterrupt:
    sys.exit(0)