Skip to content

Commit

Permalink
faster indexing, SigLIP models
Browse files Browse the repository at this point in the history
  • Loading branch information
osmarks committed Oct 8, 2023
1 parent 2c9ce67 commit 46fca3e
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 111 deletions.
49 changes: 30 additions & 19 deletions clip_server.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import torch
import os
import time
import threading
from aiohttp import web
Expand All @@ -8,21 +8,34 @@
import umsgpack
import collections
import queue
import open_clip
from PIL import Image
from prometheus_client import Counter, Histogram, REGISTRY, generate_latest
import io
import json
import sys
import torch
from transformers import SiglipImageProcessor, T5Tokenizer, SiglipModel, SiglipConfig
from accelerate import init_empty_weights
from accelerate.utils.modeling import set_module_tensor_to_device
from safetensors import safe_open
import numpy

with open(sys.argv[1], "r") as config_file:
CONFIG = json.load(config_file)

device = torch.device(CONFIG["device"])
model, _, preprocess = open_clip.create_model_and_transforms(CONFIG["model"], device=device, pretrained=dict(open_clip.list_pretrained())[CONFIG["model"]], precision="fp16")
model.eval()
tokenizer = open_clip.get_tokenizer(CONFIG["model"])
print("Model loaded")
DEVICE = "cuda:0"

# So400m/14@384
with init_empty_weights():
model = SiglipModel(config=SiglipConfig.from_pretrained(CONFIG["model"])).half().eval()
with safe_open(os.path.join(CONFIG["model"], "model.safetensors"), framework="pt", device=DEVICE) as f:
for key in f.keys():
set_module_tensor_to_device(model, key, device=DEVICE, value=f.get_tensor(key))
model = model.to(DEVICE)
EMBDIM = model.config.vision_config.hidden_size # NOT projection_dim, why is that even there
RES = model.config.vision_config.image_size
tokenizer = T5Tokenizer(vocab_file=os.path.join(CONFIG["model"], "sentencepiece.model"), extra_ids=0, model_max_length=64, pad_token="</s>", legacy=False)
image_processor = SiglipImageProcessor(size={"height": RES, "width":RES})

BS = CONFIG["max_batch_size"]
MODELNAME = CONFIG["model_name"]
Expand All @@ -33,27 +46,24 @@
inference_time_hist = Histogram("modelserver_inftime", "Time running inference", ["model", "batch_size"])
batch_count_ctr = Counter("modelserver_batchcount", "Inference batches run", ["model"])

torch.set_grad_enabled(False)
def do_inference(params: InferenceParameters):
with torch.no_grad():
try:
text, images, callback = params
if text is not None:
items_ctr.labels(MODELNAME, "text").inc(text.shape[0])
with inference_time_hist.labels(MODELNAME + "-text", text.shape[0]).time():
features = model.encode_text(text)
features = model.text_model.forward(input_ids=torch.tensor(text, device=DEVICE)).pooler_output
elif images is not None:
items_ctr.labels(MODELNAME, "image").inc(images.shape[0])
with inference_time_hist.labels(MODELNAME + "-image", images.shape[0]).time():
items_ctr.labels(MODELNAME, "image").inc(images.shape[0])
features = model.encode_image(images)
batch_count_ctr.labels(MODELNAME).inc()
features = model.vision_model.forward(torch.tensor(images, device=DEVICE)).pooler_output
features /= features.norm(dim=-1, keepdim=True)
batch_count_ctr.labels(MODELNAME).inc()
callback(True, features.cpu().numpy())
except Exception as e:
traceback.print_exc()
callback(False, str(e))
finally:
torch.cuda.empty_cache()

iq = queue.Queue(10)
def infer_thread():
Expand All @@ -67,10 +77,11 @@ def preprocessing_thread():
try:
if text:
assert len(text) <= BS, f"max batch size is {BS}"
text = tokenizer(text).to(device)
# I feel like this ought to be batchable but I can't see how to do that
text = numpy.array(tokenizer(text, padding="max_length", truncation=True)["input_ids"])
elif images:
assert len(images) <= BS, f"max batch size is {BS}"
images = torch.stack([ preprocess(Image.open(io.BytesIO(im))).half() for im in images ]).to(device)
images = numpy.array(image_processor([ Image.open(io.BytesIO(bs)) for bs in images ])["pixel_values"]).astype("float16")
else:
assert False, "images or text required"
iq.put(InferenceParameters(text, images, callback))
Expand Down Expand Up @@ -105,10 +116,10 @@ def callback(*argv):
@routes.get("/config")
async def config(request):
return web.Response(body=umsgpack.dumps({
"model": CONFIG["model"],
"model": MODELNAME,
"batch": BS,
"image_size": model.visual.image_size,
"embedding_size": model.visual.output_dim
"image_size": (RES, RES),
"embedding_size": EMBDIM
}), status=200, content_type="application/msgpack")

@routes.get("/")
Expand Down
5 changes: 2 additions & 3 deletions clip_server_config.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
{
"device": "cuda:0",
"model": "ViT-H-14",
"model_name": "openclip-ViT-H-14",
"model": "./out",
"model_name": "siglip-so400m/14@384",
"max_batch_size": 128,
"port": 1708
}
199 changes: 110 additions & 89 deletions mse.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import json
import io
import sys
from concurrent.futures import ProcessPoolExecutor

with open(sys.argv[1], "r") as config_file:
CONFIG = json.load(config_file)
Expand All @@ -36,7 +37,8 @@ async def run_query(request):
data = await request.json()
embeddings = []
if images := data.get("images", []):
embeddings.extend(await clip_server({ "images": [ base64.b64decode(x) for x, w in images ] }))
target_image_size = app["index"].inference_server_config["image_size"]
embeddings.extend(await clip_server({ "images": [ load_image(io.BytesIO(base64.b64decode(x)), target_image_size)[0] for x, w in images ] }))
if text := data.get("text", []):
embeddings.extend(await clip_server({ "text": [ x for x, w in text ] }))
weights = [ w for x, w in images ] + [ w for x, w in text ]
Expand All @@ -54,6 +56,13 @@ async def reload_index_route(request):
await request.app["index"].reload()
return web.json_response(True)

def load_image(path, image_size):
im = Image.open(path)
im.draft("RGB", image_size)
buf = io.BytesIO()
im.resize(image_size).convert("RGB").save(buf, format="BMP")
return buf.getvalue(), path

class Index:
def __init__(self, inference_server_config):
self.faiss_index = faiss.IndexFlatIP(inference_server_config["embedding_size"])
Expand All @@ -72,96 +81,107 @@ def search(self, query):

async def reload(self):
async with self.lock:
print("Indexing")
conn = await aiosqlite.connect(CONFIG["db_path"], parent_loop=asyncio.get_running_loop())
conn.row_factory = aiosqlite.Row
await conn.executescript("""
CREATE TABLE IF NOT EXISTS files (
filename TEXT PRIMARY KEY,
modtime REAL NOT NULL,
embedding_vector BLOB NOT NULL
);
""")
try:
async with asyncio.TaskGroup() as tg:
batch_sem = asyncio.Semaphore(3)

modified = set()

async def do_batch(batch):
try:
query = { "images": [ arg[2] for arg in batch ] }
embeddings = await clip_server(query, False)
await conn.executemany("INSERT OR REPLACE INTO files VALUES (?, ?, ?)", [
(filename, modtime, embedding) for (filename, modtime, _), embedding in zip(batch, embeddings)
])
await conn.commit()
for filename, _, _ in batch:
modified.add(filename)
sys.stdout.write(".")
finally:
batch_sem.release()

async def dispatch_batch(batch):
await batch_sem.acquire()
tg.create_task(do_batch(batch))

files = {}
for filename, modtime in await conn.execute_fetchall("SELECT filename, modtime FROM files"):
files[filename] = modtime
await conn.commit()
batch = []

for dirpath, _, filenames in os.walk(CONFIG["files"]):
for file in filenames:
path = os.path.join(dirpath, file)
file = os.path.relpath(path, CONFIG["files"])
st = os.stat(path)
if st.st_mtime != files.get(file):
with ProcessPoolExecutor(max_workers=12) as executor:
print("Indexing")
conn = await aiosqlite.connect(CONFIG["db_path"], parent_loop=asyncio.get_running_loop())
conn.row_factory = aiosqlite.Row
await conn.executescript("""
CREATE TABLE IF NOT EXISTS files (
filename TEXT PRIMARY KEY,
modtime REAL NOT NULL,
embedding_vector BLOB NOT NULL
);
""")
try:
async with asyncio.TaskGroup() as tg:
batch_sem = asyncio.Semaphore(3)

modified = set()

async def do_batch(batch):
try:
query = { "images": [ arg[2] for arg in batch ] }
embeddings = await clip_server(query, False)
await conn.executemany("INSERT OR REPLACE INTO files VALUES (?, ?, ?)", [
(filename, modtime, embedding) for (filename, modtime, _), embedding in zip(batch, embeddings)
])
await conn.commit()
for filename, _, _ in batch:
modified.add(filename)
sys.stdout.write(".")
sys.stdout.flush()
finally:
batch_sem.release()

async def dispatch_batch(batch):
await batch_sem.acquire()
tg.create_task(do_batch(batch))

files = {}
for filename, modtime in await conn.execute_fetchall("SELECT filename, modtime FROM files"):
files[filename] = modtime
await conn.commit()
batch = []

failed = set()
for dirpath, _, filenames in os.walk(CONFIG["files"]):
paths = set()
done = set()
for file in filenames:
path = os.path.join(dirpath, file)
file = os.path.relpath(path, CONFIG["files"])
st = os.stat(path)
if st.st_mtime != files.get(file):
paths.add(path)
for task in asyncio.as_completed([ asyncio.get_running_loop().run_in_executor(executor, load_image, path, self.inference_server_config["image_size"]) for path in paths ]):
try:
im = Image.open(path)
im.draft("RGB", self.inference_server_config["image_size"])
buf = io.BytesIO()
im.resize(self.inference_server_config["image_size"]).convert("RGB").save(buf, format="BMP")
b = buf.getvalue()
b, path = await task
st = os.stat(path)
file = os.path.relpath(path, CONFIG["files"])
done.add(path)
except Exception as e:
print(file, "failed", e)
# print(file, "failed", e) we can't have access to file when we need it, oops
continue
batch.append((file, st.st_mtime, b))
if len(batch) % self.inference_server_config["batch"] == self.inference_server_config["batch"] - 1:
if len(batch) == self.inference_server_config["batch"]:
await dispatch_batch(batch)
batch = []
if batch:
await dispatch_batch(batch)

remove_indices = []
for index, filename in enumerate(self.associated_filenames):
if filename not in files or filename in modified:
remove_indices.append(index)
self.associated_filenames[index] = None
if filename not in files:
await conn.execute("DELETE FROM files WHERE filename = ?", (filename,))
await conn.commit()
# TODO concurrency
# TODO understand what that comment meant
if remove_indices:
self.faiss_index.remove_ids(numpy.array(remove_indices))
self.associated_filenames = [ x for x in self.associated_filenames if x is not None ]

filenames_set = set(self.associated_filenames)
new_data = []
new_filenames = []
async with conn.execute("SELECT * FROM files") as csr:
while row := await csr.fetchone():
filename, modtime, embedding_vector = row
if filename not in filenames_set:
new_data.append(numpy.frombuffer(embedding_vector, dtype="float16"))
new_filenames.append(filename)
new_data = numpy.array(new_data)
self.associated_filenames.extend(new_filenames)
self.faiss_index.add(new_data)
finally:
await conn.close()
failed |= paths - done
if batch:
await dispatch_batch(batch)

print()
for failed_ in failed:
print(failed_, "failed")

remove_indices = []
for index, filename in enumerate(self.associated_filenames):
if filename not in files or filename in modified:
remove_indices.append(index)
self.associated_filenames[index] = None
if filename not in files:
await conn.execute("DELETE FROM files WHERE filename = ?", (filename,))
await conn.commit()
# TODO concurrency
# TODO understand what that comment meant
if remove_indices:
self.faiss_index.remove_ids(numpy.array(remove_indices))
self.associated_filenames = [ x for x in self.associated_filenames if x is not None ]

filenames_set = set(self.associated_filenames)
new_data = []
new_filenames = []
async with conn.execute("SELECT * FROM files") as csr:
while row := await csr.fetchone():
filename, modtime, embedding_vector = row
if filename not in filenames_set:
new_data.append(numpy.frombuffer(embedding_vector, dtype="float16"))
new_filenames.append(filename)
new_data = numpy.array(new_data)
self.associated_filenames.extend(new_filenames)
self.faiss_index.add(new_data)
finally:
await conn.close()

app.router.add_routes(routes)

Expand Down Expand Up @@ -195,7 +215,8 @@ async def main():
site = web.TCPSite(runner, "", CONFIG["port"])
await site.start()

loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(main())
loop.run_forever()
if __name__ == "__main__":
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(main())
loop.run_forever()

0 comments on commit 46fca3e

Please sign in to comment.