|
2 | 2 | import time
|
3 | 3 | from multiprocessing import get_context
|
4 | 4 | from typing import Iterable, List, Optional, Tuple
|
| 5 | +from itertools import islice |
5 | 6 |
|
6 | 7 | import numpy as np
|
7 | 8 | import tqdm
|
@@ -112,22 +113,31 @@ def search_all(
|
112 | 113 | else:
|
113 | 114 | ctx = get_context(self.get_mp_start_method())
|
114 | 115 |
|
115 |
| - with ctx.Pool( |
116 |
| - processes=parallel, |
117 |
| - initializer=self.__class__.init_client, |
118 |
| - initargs=( |
| 116 | + def process_initializer(): |
| 117 | + """Initialize each process before starting the search.""" |
| 118 | + self.__class__.init_client( |
119 | 119 | self.host,
|
120 | 120 | distance,
|
121 | 121 | self.connection_params,
|
122 | 122 | self.search_params,
|
123 |
| - ), |
| 123 | + ) |
| 124 | + self.setup_search() |
| 125 | + |
| 126 | + # Dynamically chunk the generator |
| 127 | + query_chunks = list(chunked_iterable(used_queries, max(1, len(used_queries) // parallel))) |
| 128 | + |
| 129 | + with ctx.Pool( |
| 130 | + processes=parallel, |
| 131 | + initializer=process_initializer, |
124 | 132 | ) as pool:
|
125 | 133 | if parallel > 10:
|
126 | 134 | time.sleep(15) # Wait for all processes to start
|
127 | 135 | start = time.perf_counter()
|
128 |
| - precisions, latencies = list( |
129 |
| - zip(*pool.imap_unordered(search_one, iterable=tqdm.tqdm(used_queries))) |
| 136 | + results = pool.starmap( |
| 137 | + process_chunk, |
| 138 | + [(chunk, search_one) for chunk in query_chunks], |
130 | 139 | )
|
| 140 | + precisions, latencies = zip(*[result for chunk in results for result in chunk]) |
131 | 141 |
|
132 | 142 | total_time = time.perf_counter() - start
|
133 | 143 |
|
@@ -157,3 +167,15 @@ def post_search(self):
|
157 | 167 | @classmethod
|
158 | 168 | def delete_client(cls):
|
159 | 169 | pass
|
| 170 | + |
| 171 | + |
def chunked_iterable(iterable, size):
    """Break *iterable* into consecutive lists of at most *size* items.

    The final chunk may be shorter than *size*; an exhausted iterable
    produces no chunks at all.
    """
    iterator = iter(iterable)
    while True:
        batch = list(islice(iterator, size))
        if not batch:
            # Nothing left to consume — stop the generator.
            return
        yield batch
| 177 | + |
| 178 | + |
def process_chunk(chunk, search_one):
    """Apply *search_one* to every query in *chunk* and collect the results.

    Returns a list of results in the same order as the input queries.
    """
    # map() keeps the per-query calls in input order, matching the
    # original list-comprehension behavior exactly.
    return list(map(search_one, chunk))
0 commit comments