@@ -1,6 +1,6 @@
 import functools
 import time
-from multiprocessing import get_context
+from multiprocessing import get_context, Barrier, Process, Queue
 from typing import Iterable, List, Optional, Tuple
 import itertools

@@ -65,6 +65,10 @@ def search_all(
     ):
         parallel = self.search_params.get("parallel", 1)
         top = self.search_params.get("top", None)
+
+        # Materialize queries into a list so len(queries) can be computed
+        queries = list(queries)
+
         # setup_search may require initialized client
         self.init_client(
             self.host, distance, self.connection_params, self.search_params
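A note on the `queries = list(queries)` line above: `queries` arrives as an `Iterable`, which may be a single-pass generator with no `len()`. Materializing it once is what lets the parallel branch below compute a chunk size and give each worker its own slice. A minimal illustration (the generator here is hypothetical):

queries_gen = (q for q in range(10))  # lazy single-pass iterable
# len(queries_gen)                    # would raise TypeError: no len() on a generator
queries = list(queries_gen)           # materialized once
assert len(queries) == 10             # length now known for chunk sizing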
@@ -80,31 +84,56 @@ def search_all(
             print(f"Limiting queries to [0:{MAX_QUERIES - 1}]")

         if parallel == 1:
+            # Single-threaded execution
             start = time.perf_counter()
-            precisions, latencies = list(
-                zip(*[search_one(query) for query in tqdm.tqdm(used_queries)])
-            )
+
+            results = [search_one(query) for query in tqdm.tqdm(queries)]
+            total_time = time.perf_counter() - start
+
         else:
-            ctx = get_context(self.get_mp_start_method())
+            # Dynamically calculate the chunk size
+            chunk_size = max(1, len(queries) // parallel)
+            query_chunks = list(chunked_iterable(queries, chunk_size))

-            with ctx.Pool(
-                processes=parallel,
-                initializer=self.__class__.init_client,
-                initargs=(
+            # Function to be executed by each worker process
+            def worker_function(chunk, result_queue):
+                self.__class__.init_client(
                     self.host,
                     distance,
                     self.connection_params,
                     self.search_params,
-                ),
-            ) as pool:
-                if parallel > 10:
-                    time.sleep(15)  # Wait for all processes to start
-                start = time.perf_counter()
-                precisions, latencies = list(
-                    zip(*pool.imap_unordered(search_one, iterable=tqdm.tqdm(used_queries)))
                 )
+                self.setup_search()
+                results = process_chunk(chunk, search_one)
+                result_queue.put(results)
+
+            # Create a queue to collect results
+            result_queue = Queue()
+
+            # Create and start one worker process per chunk
+            processes = []
+            for chunk in query_chunks:
+                process = Process(target=worker_function, args=(chunk, result_queue))
+                processes.append(process)
+                process.start()
+
+            # Start measuring time for the critical work
+            start = time.perf_counter()

-        total_time = time.perf_counter() - start
+            # Collect results from all worker processes
+            results = []
+            for _ in processes:
+                results.extend(result_queue.get())
+
+            # Wait for all worker processes to finish
+            for process in processes:
+                process.join()
+
+            # Stop measuring time for the critical work
+            total_time = time.perf_counter() - start
+
+        # Extract precisions and latencies (outside the timed section)
+        precisions, latencies = zip(*results)

         self.__class__.delete_client()

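The hunk above replaces the `multiprocessing.Pool` with hand-rolled `Process` workers that report back over a shared `Queue`. Two details are easy to miss. First, the queue is drained before `join()`: a child that has put data on a `Queue` may not exit until that data is consumed, so joining first can deadlock. Second, `worker_function` is a closure, which only pickles under the fork start method; the previous code's `get_context(self.get_mp_start_method())` made the start method explicit, and under spawn (the default on Windows and macOS) the `Process` target must be a module-level function. Below is a minimal standalone sketch of the same fan-out/fan-in pattern; `worker` and its squaring payload are placeholders standing in for `search_one`:

from multiprocessing import Process, Queue

def worker(chunk, result_queue):
    # Stand-in for the per-query work done by search_one
    result_queue.put([x * x for x in chunk])

if __name__ == "__main__":
    data = list(range(8))
    chunks = [data[i:i + 4] for i in range(0, len(data), 4)]
    result_queue = Queue()
    processes = [Process(target=worker, args=(c, result_queue)) for c in chunks]
    for p in processes:
        p.start()
    # Drain the queue before join(): children flush their queue buffers
    # on exit, so get() everything first, then join()
    results = []
    for _ in processes:
        results.extend(result_queue.get())
    for p in processes:
        p.join()
    assert sorted(results) == [0, 1, 4, 9, 16, 25, 36, 49]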
@@ -132,3 +161,20 @@ def post_search(self):
     @classmethod
     def delete_client(cls):
         pass
+
+
+def chunked_iterable(iterable, size):
+    """Yield successive chunks of a given size from an iterable."""
+    it = iter(iterable)
+    while chunk := list(itertools.islice(it, size)):
+        yield chunk
+
+
+def process_chunk(chunk, search_one):
+    """Process a chunk of queries using the search_one function."""
+    return [search_one(query) for query in chunk]
+
+
+def process_chunk_wrapper(chunk, search_one):
+    """Wrapper to process a chunk of queries."""
+    return process_chunk(chunk, search_one)
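The `:=` operator in `chunked_iterable` requires Python 3.8+. One behavioural consequence of the sizing logic worth knowing: `chunk_size = max(1, len(queries) // parallel)` floors the division, so any remainder yields one extra chunk, and therefore one more worker process than `parallel`. A small check, assuming `chunked_iterable` is imported from this module:

# from <this module> import chunked_iterable  # import path depends on repo layout
queries = list(range(10))
parallel = 4
chunk_size = max(1, len(queries) // parallel)        # 10 // 4 == 2
chunks = list(chunked_iterable(queries, chunk_size))
assert chunks[-1] == [8, 9]   # the short final chunk is kept, not dropped
assert len(chunks) == 5       # one more chunk (and process) than `parallel`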