
Commit 534de8c

Restore performance optimizations from PR #16 (85a6bc7)
1 parent: 0209d91

File tree: 2 files changed, +48 −26 lines

- engine/base_client/client.py
- engine/base_client/search.py


engine/base_client/client.py

Lines changed: 6 additions & 3 deletions
```diff
@@ -40,8 +40,9 @@ def save_search_results(
     ):
         now = datetime.now()
         timestamp = now.strftime("%Y-%m-%d-%H-%M-%S")
+        pid = os.getpid()  # Get the current process ID
         experiments_file = (
-            f"{self.name}-{dataset_name}-search-{search_id}-{timestamp}.json"
+            f"{self.name}-{dataset_name}-search-{search_id}-{pid}-{timestamp}.json"
         )
         result_path = RESULTS_DIR / experiments_file
         with open(result_path, "w") as out:
@@ -99,7 +100,8 @@ def run_experiment(
         reader = dataset.get_reader(execution_params.get("normalize", False))

         if skip_if_exists:
-            glob_pattern = f"{self.name}-{dataset.config.name}-search-*-*.json"
+            pid = os.getpid()  # Get the current process ID
+            glob_pattern = f"{self.name}-{dataset.config.name}-search-*-{pid}-*.json"
             existing_results = list(RESULTS_DIR.glob(glob_pattern))
             if len(existing_results) == len(self.searchers):
                 print(
@@ -137,8 +139,9 @@ def run_experiment(
         print("Experiment stage: Search")
         for search_id, searcher in enumerate(self.searchers):
             if skip_if_exists:
+                pid = os.getpid()  # Get the current process ID
                 glob_pattern = (
-                    f"{self.name}-{dataset.config.name}-search-{search_id}-*.json"
+                    f"{self.name}-{dataset.config.name}-search-{search_id}-{pid}-*.json"
                 )
                 existing_results = list(RESULTS_DIR.glob(glob_pattern))
                 print("Pattern", glob_pattern, "Results:", existing_results)
```

engine/base_client/search.py

Lines changed: 42 additions & 23 deletions
```diff
@@ -1,6 +1,6 @@
 import functools
 import time
-from multiprocessing import get_context
+from multiprocessing import Process, Queue
 from typing import Iterable, List, Optional, Tuple
 from itertools import islice

@@ -106,40 +106,54 @@ def search_all(
             used_queries = queries_list

         if parallel == 1:
+            # Single-threaded execution
             start = time.perf_counter()
-            precisions, latencies = list(
-                zip(*[search_one(query) for query in tqdm.tqdm(used_queries)])
-            )
+            results = [search_one(query) for query in tqdm.tqdm(used_queries)]
+            total_time = time.perf_counter() - start
         else:
-            ctx = get_context(self.get_mp_start_method())
+            # Dynamically calculate chunk size
+            chunk_size = max(1, len(used_queries) // parallel)
+            query_chunks = list(chunked_iterable(used_queries, chunk_size))

-            def process_initializer():
-                """Initialize each process before starting the search."""
+            # Function to be executed by each worker process
+            def worker_function(chunk, result_queue):
                 self.__class__.init_client(
                     self.host,
                     distance,
                     self.connection_params,
                     self.search_params,
                 )
                 self.setup_search()
+                results = process_chunk(chunk, search_one)
+                result_queue.put(results)

-            # Dynamically chunk the generator
-            query_chunks = list(chunked_iterable(used_queries, max(1, len(used_queries) // parallel)))
-
-            with ctx.Pool(
-                processes=parallel,
-                initializer=process_initializer,
-            ) as pool:
-                if parallel > 10:
-                    time.sleep(15)  # Wait for all processes to start
-                start = time.perf_counter()
-                results = pool.starmap(
-                    process_chunk,
-                    [(chunk, search_one) for chunk in query_chunks],
-                )
-            precisions, latencies = zip(*[result for chunk in results for result in chunk])
+            # Create a queue to collect results
+            result_queue = Queue()
+
+            # Create and start worker processes
+            processes = []
+            for chunk in query_chunks:
+                process = Process(target=worker_function, args=(chunk, result_queue))
+                processes.append(process)
+                process.start()
+
+            # Start measuring time for the critical work
+            start = time.perf_counter()

-        total_time = time.perf_counter() - start
+            # Collect results from all worker processes
+            results = []
+            for _ in processes:
+                results.extend(result_queue.get())
+
+            # Wait for all worker processes to finish
+            for process in processes:
+                process.join()
+
+            # Stop measuring time for the critical work
+            total_time = time.perf_counter() - start
+
+        # Extract precisions and latencies (outside the timed section)
+        precisions, latencies = zip(*results)

         self.__class__.delete_client()

@@ -179,3 +193,8 @@ def chunked_iterable(iterable, size):
 def process_chunk(chunk, search_one):
     """Process a chunk of queries using the search_one function."""
     return [search_one(query) for query in chunk]
+
+
+def process_chunk_wrapper(chunk, search_one):
+    """Wrapper to process a chunk of queries."""
+    return process_chunk(chunk, search_one)
```
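
The new parallel path replaces the `multiprocessing.Pool` with manually managed `Process` workers and a shared `Queue`. The standalone sketch below illustrates that fan-out pattern under simplified assumptions: `search_one` is a dummy stand-in and all names are hypothetical, not the benchmark's real API. Note that results are drained from the queue before `join()`; a child that has put data on a `multiprocessing.Queue` may not exit until that data is consumed, so joining first can deadlock.

```python
# Toy illustration of the Process/Queue fan-out; names and workload are made up.
import time
from itertools import islice
from multiprocessing import Process, Queue


def chunked_iterable(iterable, size):
    """Yield successive chunks of `size` items from `iterable`."""
    it = iter(iterable)
    while chunk := list(islice(it, size)):
        yield chunk


def search_one(query):
    # Placeholder for a real search call; returns (precision, latency).
    t0 = time.perf_counter()
    _ = query * 2
    return 1.0, time.perf_counter() - t0


def worker(chunk, result_queue):
    # Each worker processes its whole chunk and pushes the results at once.
    result_queue.put([search_one(q) for q in chunk])


if __name__ == "__main__":
    queries = list(range(100))
    parallel = 4
    chunks = list(chunked_iterable(queries, max(1, len(queries) // parallel)))

    result_queue = Queue()
    processes = [Process(target=worker, args=(c, result_queue)) for c in chunks]
    for p in processes:
        p.start()

    start = time.perf_counter()
    # Drain the queue before join(): a child cannot finish while its queued
    # data is still unread, so collecting first avoids a deadlock.
    results = []
    for _ in processes:
        results.extend(result_queue.get())
    for p in processes:
        p.join()
    total_time = time.perf_counter() - start

    precisions, latencies = zip(*results)
    print(len(latencies), "queries in", round(total_time, 4), "s")
```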
