diff --git a/clip_server.py b/clip_server.py
index 0715958..c937a54 100644
--- a/clip_server.py
+++ b/clip_server.py
@@ -1,4 +1,4 @@
-import torch
+import os
 import time
 import threading
 from aiohttp import web
@@ -8,21 +8,34 @@
 import umsgpack
 import collections
 import queue
-import open_clip
 from PIL import Image
 from prometheus_client import Counter, Histogram, REGISTRY, generate_latest
 import io
 import json
 import sys
+import torch
+from transformers import SiglipImageProcessor, T5Tokenizer, SiglipModel, SiglipConfig
+from accelerate import init_empty_weights
+from accelerate.utils.modeling import set_module_tensor_to_device
+from safetensors import safe_open
+import numpy
 
 with open(sys.argv[1], "r") as config_file:
     CONFIG = json.load(config_file)
 
-device = torch.device(CONFIG["device"])
-model, _, preprocess = open_clip.create_model_and_transforms(CONFIG["model"], device=device, pretrained=dict(open_clip.list_pretrained())[CONFIG["model"]], precision="fp16")
-model.eval()
-tokenizer = open_clip.get_tokenizer(CONFIG["model"])
-print("Model loaded")
+DEVICE = "cuda:0"
+
+# So400m/14@384
+with init_empty_weights():
+    model = SiglipModel(config=SiglipConfig.from_pretrained(CONFIG["model"])).half().eval()
+with safe_open(os.path.join(CONFIG["model"], "model.safetensors"), framework="pt", device=DEVICE) as f:
+    for key in f.keys():
+        set_module_tensor_to_device(model, key, device=DEVICE, value=f.get_tensor(key))
+model = model.to(DEVICE)
+EMBDIM = model.config.vision_config.hidden_size # NOT projection_dim, why is that even there
+RES = model.config.vision_config.image_size
+tokenizer = T5Tokenizer(vocab_file=os.path.join(CONFIG["model"], "sentencepiece.model"), extra_ids=0, model_max_length=64, pad_token="</s>", legacy=False)
+image_processor = SiglipImageProcessor(size={"height": RES, "width": RES})
 
 BS = CONFIG["max_batch_size"]
 MODELNAME = CONFIG["model_name"]
@@ -33,7 +46,6 @@
 inference_time_hist = Histogram("modelserver_inftime", "Time running inference", ["model", "batch_size"])
 batch_count_ctr = Counter("modelserver_batchcount", "Inference batches run", ["model"])
 
-torch.set_grad_enabled(False)
 def do_inference(params: InferenceParameters):
     with torch.no_grad():
         try:
@@ -41,19 +53,17 @@
             text, images, callback = params
             if text is not None:
                 items_ctr.labels(MODELNAME, "text").inc(text.shape[0])
                 with inference_time_hist.labels(MODELNAME + "-text", text.shape[0]).time():
-                    features = model.encode_text(text)
+                    features = model.text_model.forward(input_ids=torch.tensor(text, device=DEVICE)).pooler_output
             elif images is not None:
+                items_ctr.labels(MODELNAME, "image").inc(images.shape[0])
                 with inference_time_hist.labels(MODELNAME + "-image", images.shape[0]).time():
-                    items_ctr.labels(MODELNAME, "image").inc(images.shape[0])
-                    features = model.encode_image(images)
-                batch_count_ctr.labels(MODELNAME).inc()
+                    features = model.vision_model.forward(torch.tensor(images, device=DEVICE)).pooler_output
             features /= features.norm(dim=-1, keepdim=True)
+            batch_count_ctr.labels(MODELNAME).inc()
             callback(True, features.cpu().numpy())
         except Exception as e:
             traceback.print_exc()
             callback(False, str(e))
-        finally:
-            torch.cuda.empty_cache()
 iq = queue.Queue(10)
 def infer_thread():
@@ -67,10 +77,11 @@ def preprocessing_thread():
         try:
             if text:
                 assert len(text) <= BS, f"max batch size is {BS}"
-                text = tokenizer(text).to(device)
+                # I feel like this ought to be batchable but I can't see how to do that
+                text = numpy.array(tokenizer(text, padding="max_length", truncation=True)["input_ids"])
             elif images:
                 assert len(images) <= BS, f"max batch size is {BS}"
-                images = torch.stack([ preprocess(Image.open(io.BytesIO(im))).half() for im in images ]).to(device)
+                images = numpy.array(image_processor([ Image.open(io.BytesIO(bs)) for bs in images ])["pixel_values"]).astype("float16")
             else:
                 assert False, "images or text required"
             iq.put(InferenceParameters(text, images, callback))
@@ -105,10 +116,10 @@ def callback(*argv):
 @routes.get("/config")
 async def config(request):
     return web.Response(body=umsgpack.dumps({
-        "model": CONFIG["model"],
+        "model": MODELNAME,
         "batch": BS,
-        "image_size": model.visual.image_size,
-        "embedding_size": model.visual.output_dim
+        "image_size": (RES, RES),
+        "embedding_size": EMBDIM
     }), status=200, content_type="application/msgpack")
 
 @routes.get("/")
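A note on the inference changes above: SigLIP's useful embedding is the pooled transformer output (hence EMBDIM coming from hidden_size rather than projection_dim), and the handwritten text_model/vision_model forwards read pooler_output directly. transformers versions that ship SigLIP also expose the same pooled output through SiglipModel.get_text_features/get_image_features, so the server's text path can be cross-checked with a sketch like this (reuses the model/tokenizer/DEVICE globals defined in the patch; illustrative, not part of the patch):

    import numpy, torch

    def embed_texts(texts):
        # Same max_length padding/truncation as preprocessing_thread().
        ids = numpy.array(tokenizer(texts, padding="max_length", truncation=True)["input_ids"])
        with torch.no_grad():
            # get_text_features returns text_model(...).pooler_output for SigLIP,
            # i.e. the same tensor do_inference() extracts by hand.
            feats = model.get_text_features(input_ids=torch.tensor(ids, device=DEVICE))
        # Match the server's normalisation step.
        return torch.nn.functional.normalize(feats, dim=-1).cpu().numpy()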
diff --git a/clip_server_config.json b/clip_server_config.json
index 251d312..f3d401a 100644
--- a/clip_server_config.json
+++ b/clip_server_config.json
@@ -1,7 +1,6 @@
 {
-    "device": "cuda:0",
-    "model": "ViT-H-14",
-    "model_name": "openclip-ViT-H-14",
+    "model": "./out",
+    "model_name": "siglip-so400m/14@384",
     "max_batch_size": 128,
     "port": 1708
 }
\ No newline at end of file
diff --git a/mse.py b/mse.py
index 37a7c85..4c60cb2 100644
--- a/mse.py
+++ b/mse.py
@@ -13,6 +13,7 @@
 import json
 import io
 import sys
+from concurrent.futures import ProcessPoolExecutor
 
 with open(sys.argv[1], "r") as config_file:
     CONFIG = json.load(config_file)
@@ -36,7 +37,8 @@ async def run_query(request):
     data = await request.json()
     embeddings = []
     if images := data.get("images", []):
-        embeddings.extend(await clip_server({ "images": [ base64.b64decode(x) for x, w in images ] }))
+        target_image_size = app["index"].inference_server_config["image_size"]
+        embeddings.extend(await clip_server({ "images": [ load_image(io.BytesIO(base64.b64decode(x)), target_image_size)[0] for x, w in images ] }))
     if text := data.get("text", []):
         embeddings.extend(await clip_server({ "text": [ x for x, w in text ] }))
     weights = [ w for x, w in images ] + [ w for x, w in text ]
@@ -54,6 +56,13 @@ async def reload_index_route(request):
     await request.app["index"].reload()
     return web.json_response(True)
 
+def load_image(path, image_size):
+    im = Image.open(path)
+    im.draft("RGB", image_size)
+    buf = io.BytesIO()
+    im.resize(image_size).convert("RGB").save(buf, format="BMP")
+    return buf.getvalue(), path
+
 class Index:
     def __init__(self, inference_server_config):
         self.faiss_index = faiss.IndexFlatIP(inference_server_config["embedding_size"])
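The new load_image helper is doing decode-time downscaling: for JPEGs, Image.draft() configures the decoder to decode at a reduced scale (1/2, 1/4 or 1/8) instead of decoding at full resolution and resizing afterwards, and it is a no-op for formats without that support. Re-encoding as BMP keeps the hand-off to the inference server nearly free to decode. A standalone illustration of the draft effect (file name and sizes are hypothetical):

    from PIL import Image

    im = Image.open("big_photo.jpg")           # say, a 4000x3000 JPEG
    im.draft("RGB", (384, 384))                # hint: we only need ~384x384
    print(im.size)                             # e.g. (1000, 750) after a 1/4-scale decode
    im = im.resize((384, 384)).convert("RGB")  # exact target size, as in load_image()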
@@ -72,96 +81,107 @@ def search(self, query):
 
     async def reload(self):
         async with self.lock:
-            print("Indexing")
-            conn = await aiosqlite.connect(CONFIG["db_path"], parent_loop=asyncio.get_running_loop())
-            conn.row_factory = aiosqlite.Row
-            await conn.executescript("""
-            CREATE TABLE IF NOT EXISTS files (
-                filename TEXT PRIMARY KEY,
-                modtime REAL NOT NULL,
-                embedding_vector BLOB NOT NULL
-            );
-            """)
-            try:
-                async with asyncio.TaskGroup() as tg:
-                    batch_sem = asyncio.Semaphore(3)
-
-                    modified = set()
-
-                    async def do_batch(batch):
-                        try:
-                            query = { "images": [ arg[2] for arg in batch ] }
-                            embeddings = await clip_server(query, False)
-                            await conn.executemany("INSERT OR REPLACE INTO files VALUES (?, ?, ?)", [
-                                (filename, modtime, embedding) for (filename, modtime, _), embedding in zip(batch, embeddings)
-                            ])
-                            await conn.commit()
-                            for filename, _, _ in batch:
-                                modified.add(filename)
-                            sys.stdout.write(".")
-                        finally:
-                            batch_sem.release()
-
-                    async def dispatch_batch(batch):
-                        await batch_sem.acquire()
-                        tg.create_task(do_batch(batch))
-
-                    files = {}
-                    for filename, modtime in await conn.execute_fetchall("SELECT filename, modtime FROM files"):
-                        files[filename] = modtime
-                    await conn.commit()
-                    batch = []
-
-                    for dirpath, _, filenames in os.walk(CONFIG["files"]):
-                        for file in filenames:
-                            path = os.path.join(dirpath, file)
-                            file = os.path.relpath(path, CONFIG["files"])
-                            st = os.stat(path)
-                            if st.st_mtime != files.get(file):
+            with ProcessPoolExecutor(max_workers=12) as executor:
+                print("Indexing")
+                conn = await aiosqlite.connect(CONFIG["db_path"], parent_loop=asyncio.get_running_loop())
+                conn.row_factory = aiosqlite.Row
+                await conn.executescript("""
+                CREATE TABLE IF NOT EXISTS files (
+                    filename TEXT PRIMARY KEY,
+                    modtime REAL NOT NULL,
+                    embedding_vector BLOB NOT NULL
+                );
+                """)
+                try:
+                    async with asyncio.TaskGroup() as tg:
+                        batch_sem = asyncio.Semaphore(3)
+
+                        modified = set()
+
+                        async def do_batch(batch):
+                            try:
+                                query = { "images": [ arg[2] for arg in batch ] }
+                                embeddings = await clip_server(query, False)
+                                await conn.executemany("INSERT OR REPLACE INTO files VALUES (?, ?, ?)", [
+                                    (filename, modtime, embedding) for (filename, modtime, _), embedding in zip(batch, embeddings)
+                                ])
+                                await conn.commit()
+                                for filename, _, _ in batch:
+                                    modified.add(filename)
+                                sys.stdout.write(".")
+                                sys.stdout.flush()
+                            finally:
+                                batch_sem.release()
+
+                        async def dispatch_batch(batch):
+                            await batch_sem.acquire()
+                            tg.create_task(do_batch(batch))
+
+                        files = {}
+                        for filename, modtime in await conn.execute_fetchall("SELECT filename, modtime FROM files"):
+                            files[filename] = modtime
+                        await conn.commit()
+                        batch = []
+
+                        failed = set()
+                        for dirpath, _, filenames in os.walk(CONFIG["files"]):
+                            paths = set()
+                            done = set()
+                            for file in filenames:
+                                path = os.path.join(dirpath, file)
+                                file = os.path.relpath(path, CONFIG["files"])
+                                st = os.stat(path)
+                                if st.st_mtime != files.get(file):
+                                    paths.add(path)
+                            for task in asyncio.as_completed([ asyncio.get_running_loop().run_in_executor(executor, load_image, path, self.inference_server_config["image_size"]) for path in paths ]):
                                 try:
-                                    im = Image.open(path)
-                                    im.draft("RGB", self.inference_server_config["image_size"])
-                                    buf = io.BytesIO()
-                                    im.resize(self.inference_server_config["image_size"]).convert("RGB").save(buf, format="BMP")
-                                    b = buf.getvalue()
+                                    b, path = await task
+                                    st = os.stat(path)
+                                    file = os.path.relpath(path, CONFIG["files"])
+                                    done.add(path)
                                 except Exception as e:
-                                    print(file, "failed", e)
+                                    # print(file, "failed", e) we can't have access to file when we need it, oops
                                     continue
                                 batch.append((file, st.st_mtime, b))
-                                if len(batch) % self.inference_server_config["batch"] == self.inference_server_config["batch"] - 1:
+                                if len(batch) == self.inference_server_config["batch"]:
                                     await dispatch_batch(batch)
                                     batch = []
-                    if batch:
-                        await dispatch_batch(batch)
-
-                    remove_indices = []
-                    for index, filename in enumerate(self.associated_filenames):
-                        if filename not in files or filename in modified:
-                            remove_indices.append(index)
-                            self.associated_filenames[index] = None
-                        if filename not in files:
-                            await conn.execute("DELETE FROM files WHERE filename = ?", (filename,))
-                    await conn.commit()
-                    # TODO concurrency
-                    # TODO understand what that comment meant
-                    if remove_indices:
-                        self.faiss_index.remove_ids(numpy.array(remove_indices))
-                        self.associated_filenames = [ x for x in self.associated_filenames if x is not None ]
-
-                    filenames_set = set(self.associated_filenames)
-                    new_data = []
-                    new_filenames = []
-                    async with conn.execute("SELECT * FROM files") as csr:
-                        while row := await csr.fetchone():
-                            filename, modtime, embedding_vector = row
-                            if filename not in filenames_set:
-                                new_data.append(numpy.frombuffer(embedding_vector, dtype="float16"))
-                                new_filenames.append(filename)
-                    new_data = numpy.array(new_data)
-                    self.associated_filenames.extend(new_filenames)
-                    self.faiss_index.add(new_data)
-            finally:
-                await conn.close()
+                            failed |= paths - done
+                        if batch:
+                            await dispatch_batch(batch)
+
+                        print()
+                        for failed_ in failed:
+                            print(failed_, "failed")
+
+                        remove_indices = []
+                        for index, filename in enumerate(self.associated_filenames):
+                            if filename not in files or filename in modified:
+                                remove_indices.append(index)
+                                self.associated_filenames[index] = None
+                            if filename not in files:
+                                await conn.execute("DELETE FROM files WHERE filename = ?", (filename,))
+                        await conn.commit()
+                        # TODO concurrency
+                        # TODO understand what that comment meant
+                        if remove_indices:
+                            self.faiss_index.remove_ids(numpy.array(remove_indices))
+                            self.associated_filenames = [ x for x in self.associated_filenames if x is not None ]
+
+                        filenames_set = set(self.associated_filenames)
+                        new_data = []
+                        new_filenames = []
+                        async with conn.execute("SELECT * FROM files") as csr:
+                            while row := await csr.fetchone():
+                                filename, modtime, embedding_vector = row
+                                if filename not in filenames_set:
+                                    new_data.append(numpy.frombuffer(embedding_vector, dtype="float16"))
+                                    new_filenames.append(filename)
+                        new_data = numpy.array(new_data)
+                        self.associated_filenames.extend(new_filenames)
+                        self.faiss_index.add(new_data)
+                finally:
+                    await conn.close()
 
 app.router.add_routes(routes)
 
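The indexing pipeline above bounds concurrency by pairing an asyncio.TaskGroup with a counted Semaphore: dispatch_batch blocks the producer until a slot frees, and do_batch releases the slot when its batch finishes, so at most three embedding batches are in flight while the os.walk/ProcessPoolExecutor side keeps producing. The same back-pressure idiom in isolation (a minimal sketch; needs Python 3.11+ for TaskGroup, and the sleep stands in for the real batch work):

    import asyncio

    async def run_all(items, limit=3):
        sem = asyncio.Semaphore(limit)

        async def worker(item):
            try:
                await asyncio.sleep(0.1)  # stand-in for clip_server()/DB work
            finally:
                sem.release()             # free the slot even if the work fails

        async with asyncio.TaskGroup() as tg:
            for item in items:
                await sem.acquire()           # producer waits for a free slot
                tg.create_task(worker(item))  # at most `limit` tasks in flight

    asyncio.run(run_all(range(10)))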
conn.execute("SELECT * FROM files") as csr: - while row := await csr.fetchone(): - filename, modtime, embedding_vector = row - if filename not in filenames_set: - new_data.append(numpy.frombuffer(embedding_vector, dtype="float16")) - new_filenames.append(filename) - new_data = numpy.array(new_data) - self.associated_filenames.extend(new_filenames) - self.faiss_index.add(new_data) - finally: - await conn.close() + failed |= paths - done + if batch: + await dispatch_batch(batch) + + print() + for failed_ in failed: + print(failed_, "failed") + + remove_indices = [] + for index, filename in enumerate(self.associated_filenames): + if filename not in files or filename in modified: + remove_indices.append(index) + self.associated_filenames[index] = None + if filename not in files: + await conn.execute("DELETE FROM files WHERE filename = ?", (filename,)) + await conn.commit() + # TODO concurrency + # TODO understand what that comment meant + if remove_indices: + self.faiss_index.remove_ids(numpy.array(remove_indices)) + self.associated_filenames = [ x for x in self.associated_filenames if x is not None ] + + filenames_set = set(self.associated_filenames) + new_data = [] + new_filenames = [] + async with conn.execute("SELECT * FROM files") as csr: + while row := await csr.fetchone(): + filename, modtime, embedding_vector = row + if filename not in filenames_set: + new_data.append(numpy.frombuffer(embedding_vector, dtype="float16")) + new_filenames.append(filename) + new_data = numpy.array(new_data) + self.associated_filenames.extend(new_filenames) + self.faiss_index.add(new_data) + finally: + await conn.close() app.router.add_routes(routes) @@ -195,7 +215,8 @@ async def main(): site = web.TCPSite(runner, "", CONFIG["port"]) await site.start() -loop = asyncio.new_event_loop() -asyncio.set_event_loop(loop) -loop.run_until_complete(main()) -loop.run_forever() \ No newline at end of file +if __name__ == "__main__": + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(main()) + loop.run_forever() \ No newline at end of file