Commit 85a6bc7

Merge pull request #16 from mpozniak95/fix-sync
Fixing scalability issues with vector db benchmark
2 parents (ba175b1 + 2c592a0) · commit 85a6bc7

File tree

3 files changed: +86 -20 lines changed

datasets/datasets.json
engine/base_client/client.py
engine/base_client/search.py

datasets/datasets.json

Lines changed: 17 additions & 0 deletions
@@ -1203,5 +1203,22 @@
     "type": "tar",
     "link": "https://storage.googleapis.com/ann-filtered-benchmark/datasets/random_keywords_1m_vocab_10_no_filters.tgz",
     "path": "random-100-match-kw-small-vocab/random_keywords_1m_vocab_10_no_filters"
+  },
+  {
+    "name": "laion-img-emb-512-1M-cosine",
+    "vector_size": 512,
+    "distance": "cosine",
+    "type": "h5",
+    "path": "laion-img-emb-512/laion-img-emb-512-1M-cosine.hdf5",
+    "link": "http://benchmarks.redislabs.s3.amazonaws.com/vecsim/laion400m/laion-img-emb-512-100M-cosine.hdf5"
+  },
+  {
+    "name": "laion-img-emb-512-1M-100ktrain-cosine",
+    "vector_size": 512,
+    "distance": "cosine",
+    "type": "h5",
+    "path": "laion-img-emb-512/laion-img-emb-512-1M-100ktrain-cosine.hdf5",
+    "link": "http://benchmarks.redislabs.s3.amazonaws.com/vecsim/laion400m/laion-img-emb-512-100M-cosine.hdf5"
   }
+
 ]
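
For reference, an entry like the ones added above can be picked up directly from datasets.json. The snippet below is a minimal sketch, not the repository's actual reader: it assumes the HDF5 file has already been downloaded to the entry's "path" and follows the common ann-benchmarks layout with "train"/"test"/"neighbors" datasets, which the diff does not state.

# Minimal sketch (illustrative only) of inspecting one of the new "h5" entries.
import json
import h5py

with open("datasets/datasets.json") as f:
    datasets = json.load(f)

config = next(d for d in datasets if d["name"] == "laion-img-emb-512-1M-cosine")
print(config["vector_size"], config["distance"], config["path"])

# Assumes the file exists locally and uses the ann-benchmarks-style layout.
with h5py.File(config["path"], "r") as data:
    print(list(data.keys()))      # e.g. ['distances', 'neighbors', 'test', 'train']
    print(data["train"].shape)    # (n_vectors, 512) if the layout matches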

engine/base_client/client.py

Lines changed: 6 additions & 3 deletions
@@ -40,8 +40,9 @@ def save_search_results(
     ):
         now = datetime.now()
         timestamp = now.strftime("%Y-%m-%d-%H-%M-%S")
+        pid = os.getpid()  # Get the current process ID
         experiments_file = (
-            f"{self.name}-{dataset_name}-search-{search_id}-{timestamp}.json"
+            f"{self.name}-{dataset_name}-search-{search_id}-{pid}-{timestamp}.json"
         )
         result_path = RESULTS_DIR / experiments_file
         with open(result_path, "w") as out:
@@ -97,7 +98,8 @@ def run_experiment(
         reader = dataset.get_reader(execution_params.get("normalize", False))
 
         if skip_if_exists:
-            glob_pattern = f"{self.name}-{dataset.config.name}-search-*-*.json"
+            pid = os.getpid()  # Get the current process ID
+            glob_pattern = f"{self.name}-{dataset.config.name}-search-{pid}-*-*.json"
             existing_results = list(RESULTS_DIR.glob(glob_pattern))
             if len(existing_results) == len(self.searchers):
                 print(
@@ -135,8 +137,9 @@ def run_experiment(
         print("Experiment stage: Search")
         for search_id, searcher in enumerate(self.searchers):
            if skip_if_exists:
+                pid = os.getpid()  # Get the current process ID
                 glob_pattern = (
-                    f"{self.name}-{dataset.config.name}-search-{search_id}-*.json"
+                    f"{self.name}-{dataset.config.name}-search-{search_id}-{pid}-*.json"
                 )
                 existing_results = list(RESULTS_DIR.glob(glob_pattern))
                 print("Pattern", glob_pattern, "Results:", existing_results)

engine/base_client/search.py

Lines changed: 63 additions & 17 deletions
@@ -1,6 +1,6 @@
 import functools
 import time
-from multiprocessing import get_context
+from multiprocessing import get_context, Barrier, Process, Queue
 from typing import Iterable, List, Optional, Tuple
 import itertools
 
@@ -65,6 +65,10 @@ def search_all(
     ):
         parallel = self.search_params.get("parallel", 1)
         top = self.search_params.get("top", None)
+
+        # Convert queries to a list to calculate its length
+        queries = list(queries)  # This allows us to calculate len(queries)
+
         # setup_search may require initialized client
         self.init_client(
             self.host, distance, self.connection_params, self.search_params
@@ -80,31 +84,56 @@ def search_all(
             print(f"Limiting queries to [0:{MAX_QUERIES-1}]")
 
         if parallel == 1:
+            # Single-threaded execution
             start = time.perf_counter()
-            precisions, latencies = list(
-                zip(*[search_one(query) for query in tqdm.tqdm(used_queries)])
-            )
+
+            results = [search_one(query) for query in tqdm.tqdm(queries)]
+            total_time = time.perf_counter() - start
+
         else:
-            ctx = get_context(self.get_mp_start_method())
+            # Dynamically calculate chunk size
+            chunk_size = max(1, len(queries) // parallel)
+            query_chunks = list(chunked_iterable(queries, chunk_size))
 
-            with ctx.Pool(
-                processes=parallel,
-                initializer=self.__class__.init_client,
-                initargs=(
+            # Function to be executed by each worker process
+            def worker_function(chunk, result_queue):
+                self.__class__.init_client(
                     self.host,
                     distance,
                     self.connection_params,
                     self.search_params,
-                ),
-            ) as pool:
-                if parallel > 10:
-                    time.sleep(15)  # Wait for all processes to start
-                start = time.perf_counter()
-                precisions, latencies = list(
-                    zip(*pool.imap_unordered(search_one, iterable=tqdm.tqdm(used_queries)))
                 )
+                self.setup_search()
+                results = process_chunk(chunk, search_one)
+                result_queue.put(results)
+
+            # Create a queue to collect results
+            result_queue = Queue()
+
+            # Create and start worker processes
+            processes = []
+            for chunk in query_chunks:
+                process = Process(target=worker_function, args=(chunk, result_queue))
+                processes.append(process)
+                process.start()
+
+            # Start measuring time for the critical work
+            start = time.perf_counter()
 
-        total_time = time.perf_counter() - start
+            # Collect results from all worker processes
+            results = []
+            for _ in processes:
+                results.extend(result_queue.get())
+
+            # Wait for all worker processes to finish
+            for process in processes:
+                process.join()
+
+            # Stop measuring time for the critical work
+            total_time = time.perf_counter() - start
+
+        # Extract precisions and latencies (outside the timed section)
+        precisions, latencies = zip(*results)
 
         self.__class__.delete_client()
 
@@ -132,3 +161,20 @@ def post_search(self):
     @classmethod
     def delete_client(cls):
         pass
+
+
+def chunked_iterable(iterable, size):
+    """Yield successive chunks of a given size from an iterable."""
+    it = iter(iterable)
+    while chunk := list(itertools.islice(it, size)):
+        yield chunk
+
+
+def process_chunk(chunk, search_one):
+    """Process a chunk of queries using the search_one function."""
+    return [search_one(query) for query in chunk]
+
+
+def process_chunk_wrapper(chunk, search_one):
+    """Wrapper to process a chunk of queries."""
+    return process_chunk(chunk, search_one)
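
To see the new parallel path in isolation, here is a self-contained sketch of the chunked Process/Queue pattern the diff introduces, with a dummy search_one in place of a real engine client. All names below are illustrative; only the structure mirrors the committed code.

# Standalone sketch of the chunk-per-process pattern (illustrative only).
import itertools
import time
from multiprocessing import Process, Queue


def chunked_iterable(iterable, size):
    """Yield successive chunks of a given size from an iterable."""
    it = iter(iterable)
    while chunk := list(itertools.islice(it, size)):
        yield chunk


def search_one(query):
    return 1.0, 0.001  # (precision, latency) placeholder


def worker(chunk, result_queue):
    # Each process handles its own slice of the queries and reports back once.
    result_queue.put([search_one(q) for q in chunk])


if __name__ == "__main__":
    queries = list(range(100))
    parallel = 4
    chunk_size = max(1, len(queries) // parallel)

    result_queue = Queue()
    processes = [
        Process(target=worker, args=(chunk, result_queue))
        for chunk in chunked_iterable(queries, chunk_size)
    ]
    for p in processes:
        p.start()

    start = time.perf_counter()
    results = []
    for _ in processes:
        results.extend(result_queue.get())  # one .get() per worker
    for p in processes:
        p.join()
    total_time = time.perf_counter() - start

    precisions, latencies = zip(*results)
    print(len(precisions), f"{total_time:.4f}s")

Two properties of this structure are worth noting. Draining the queue before joining, as the committed code does, avoids a deadlock: a child process cannot exit cleanly while its results are still buffered in the queue's underlying pipe. Also, when len(queries) is not an exact multiple of parallel, the chunking yields one extra, smaller chunk and therefore one extra process.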
