Skip to content

Commit 0209d91

Browse files
rootfcostaoliveira
root
authored and committed
cd /home/fco/redislabs/vector-db-benchmark && git status
cd /home/fco/redislabs/vector-db-benchmark && git add engine/base_client/search.py engine/base_client/client.py chunk up the iterable before starting the processes
1 parent a8d26cd commit 0209d91

File tree

1 file changed

+29
-7
lines changed

1 file changed

+29
-7
lines changed

engine/base_client/search.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import time
33
from multiprocessing import get_context
44
from typing import Iterable, List, Optional, Tuple
5+
from itertools import islice
56

67
import numpy as np
78
import tqdm
@@ -112,22 +113,31 @@ def search_all(
112113
else:
113114
ctx = get_context(self.get_mp_start_method())
114115

115-
with ctx.Pool(
116-
processes=parallel,
117-
initializer=self.__class__.init_client,
118-
initargs=(
116+
def process_initializer():
117+
"""Initialize each process before starting the search."""
118+
self.__class__.init_client(
119119
self.host,
120120
distance,
121121
self.connection_params,
122122
self.search_params,
123-
),
123+
)
124+
self.setup_search()
125+
126+
# Dynamically chunk the generator
127+
query_chunks = list(chunked_iterable(used_queries, max(1, len(used_queries) // parallel)))
128+
129+
with ctx.Pool(
130+
processes=parallel,
131+
initializer=process_initializer,
124132
) as pool:
125133
if parallel > 10:
126134
time.sleep(15) # Wait for all processes to start
127135
start = time.perf_counter()
128-
precisions, latencies = list(
129-
zip(*pool.imap_unordered(search_one, iterable=tqdm.tqdm(used_queries)))
136+
results = pool.starmap(
137+
process_chunk,
138+
[(chunk, search_one) for chunk in query_chunks],
130139
)
140+
precisions, latencies = zip(*[result for chunk in results for result in chunk])
131141

132142
total_time = time.perf_counter() - start
133143

@@ -157,3 +167,15 @@ def post_search(self):
157167
@classmethod
158168
def delete_client(cls):
159169
pass
170+
171+
172+
def chunked_iterable(iterable, size):
173+
"""Yield successive chunks of a given size from an iterable."""
174+
it = iter(iterable)
175+
while chunk := list(islice(it, size)):
176+
yield chunk
177+
178+
179+
def process_chunk(chunk, search_one):
    """Run *search_one* over every query in *chunk*.

    Args:
        chunk: A sequence of query objects to evaluate.
        search_one: Callable applied to each query in order.

    Returns:
        list: The result of *search_one* for each query, preserving order.
    """
    return list(map(search_one, chunk))

0 commit comments

Comments
 (0)