
Commit 25ca5ac

Merge pull request #18 from redis-performance/restore-performance-optimizations
Restore performance optimizations
2 parents a8d26cd + 343e906 commit 25ca5ac

File tree: 3 files changed, +156 -33 lines

  engine/base_client/client.py
  engine/base_client/search.py
  test_multiprocessing.py


engine/base_client/client.py

Lines changed: 6 additions & 3 deletions
@@ -40,8 +40,9 @@ def save_search_results(
     ):
         now = datetime.now()
         timestamp = now.strftime("%Y-%m-%d-%H-%M-%S")
+        pid = os.getpid()  # Get the current process ID
         experiments_file = (
-            f"{self.name}-{dataset_name}-search-{search_id}-{timestamp}.json"
+            f"{self.name}-{dataset_name}-search-{search_id}-{pid}-{timestamp}.json"
         )
         result_path = RESULTS_DIR / experiments_file
         with open(result_path, "w") as out:
@@ -99,7 +100,8 @@ def run_experiment(
         reader = dataset.get_reader(execution_params.get("normalize", False))

         if skip_if_exists:
-            glob_pattern = f"{self.name}-{dataset.config.name}-search-*-*.json"
+            pid = os.getpid()  # Get the current process ID
+            glob_pattern = f"{self.name}-{dataset.config.name}-search-*-{pid}-*.json"
             existing_results = list(RESULTS_DIR.glob(glob_pattern))
             if len(existing_results) == len(self.searchers):
                 print(
@@ -137,8 +139,9 @@ def run_experiment(
         print("Experiment stage: Search")
         for search_id, searcher in enumerate(self.searchers):
             if skip_if_exists:
+                pid = os.getpid()  # Get the current process ID
                 glob_pattern = (
-                    f"{self.name}-{dataset.config.name}-search-{search_id}-*.json"
+                    f"{self.name}-{dataset.config.name}-search-{search_id}-{pid}-*.json"
                 )
                 existing_results = list(RESULTS_DIR.glob(glob_pattern))
                 print("Pattern", glob_pattern, "Results:", existing_results)

engine/base_client/search.py

Lines changed: 114 additions & 30 deletions
@@ -1,7 +1,8 @@
 import functools
 import time
-from multiprocessing import get_context
+from multiprocessing import Process, Queue
 from typing import Iterable, List, Optional, Tuple
+from itertools import islice

 import numpy as np
 import tqdm
@@ -83,53 +84,118 @@ def search_all(

         # Handle num_queries parameter
         if num_queries > 0:
-            # If we need more queries than available, cycle through the list
+            # If we need more queries than available, use a cycling generator
             if num_queries > len(queries_list) and len(queries_list) > 0:
                 print(f"Requested {num_queries} queries but only {len(queries_list)} are available.")
-                print(f"Extending queries by cycling through the available ones.")
-                # Calculate how many complete cycles and remaining items we need
-                complete_cycles = num_queries // len(queries_list)
-                remaining = num_queries % len(queries_list)
-
-                # Create the extended list
-                extended_queries = []
-                for _ in range(complete_cycles):
-                    extended_queries.extend(queries_list)
-                extended_queries.extend(queries_list[:remaining])
-
-                used_queries = extended_queries
+                print(f"Using a cycling generator to efficiently process queries.")
+
+                # Create a cycling generator function
+                def cycling_query_generator(queries, total_count):
+                    """Generate queries by cycling through the available ones."""
+                    count = 0
+                    while count < total_count:
+                        for query in queries:
+                            if count < total_count:
+                                yield query
+                                count += 1
+                            else:
+                                break
+
+                # Use the generator instead of creating a full list
+                used_queries = cycling_query_generator(queries_list, num_queries)
+                # We need to know the total count for the progress bar
+                total_query_count = num_queries
             else:
                 used_queries = queries_list[:num_queries]
+                total_query_count = len(used_queries)
             print(f"Using {num_queries} queries")
         else:
             used_queries = queries_list
+            total_query_count = len(used_queries)

         if parallel == 1:
+            # Single-threaded execution
             start = time.perf_counter()
-            precisions, latencies = list(
-                zip(*[search_one(query) for query in tqdm.tqdm(used_queries)])
-            )
+
+            # Create a progress bar with the correct total
+            pbar = tqdm.tqdm(total=total_query_count, desc="Processing queries", unit="queries")
+
+            # Process queries with progress updates
+            results = []
+            for query in used_queries:
+                results.append(search_one(query))
+                pbar.update(1)
+
+            # Close the progress bar
+            pbar.close()
+
+            total_time = time.perf_counter() - start
         else:
-            ctx = get_context(self.get_mp_start_method())
+            # Dynamically calculate chunk size based on total_query_count
+            chunk_size = max(1, total_query_count // parallel)
+
+            # If used_queries is a generator, we need to handle it differently
+            if hasattr(used_queries, '__next__'):
+                # For generators, we'll create chunks on-the-fly
+                query_chunks = []
+                remaining = total_query_count
+                while remaining > 0:
+                    current_chunk_size = min(chunk_size, remaining)
+                    chunk = [next(used_queries) for _ in range(current_chunk_size)]
+                    query_chunks.append(chunk)
+                    remaining -= current_chunk_size
+            else:
+                # For lists, we can use the chunked_iterable function
+                query_chunks = list(chunked_iterable(used_queries, chunk_size))

-            with ctx.Pool(
-                processes=parallel,
-                initializer=self.__class__.init_client,
-                initargs=(
+            # Function to be executed by each worker process
+            def worker_function(chunk, result_queue):
+                self.__class__.init_client(
                     self.host,
                     distance,
                     self.connection_params,
                     self.search_params,
-                ),
-            ) as pool:
-                if parallel > 10:
-                    time.sleep(15)  # Wait for all processes to start
-                start = time.perf_counter()
-                precisions, latencies = list(
-                    zip(*pool.imap_unordered(search_one, iterable=tqdm.tqdm(used_queries)))
                 )
+                self.setup_search()
+                results = process_chunk(chunk, search_one)
+                result_queue.put(results)
+
+            # Create a queue to collect results
+            result_queue = Queue()
+
+            # Create and start worker processes
+            processes = []
+            for chunk in query_chunks:
+                process = Process(target=worker_function, args=(chunk, result_queue))
+                processes.append(process)
+                process.start()
+
+            # Start measuring time for the critical work
+            start = time.perf_counter()
+
+            # Create a progress bar for the total number of queries
+            pbar = tqdm.tqdm(total=total_query_count, desc="Processing queries", unit="queries")

-        total_time = time.perf_counter() - start
+            # Collect results from all worker processes
+            results = []
+            for _ in processes:
+                chunk_results = result_queue.get()
+                results.extend(chunk_results)
+                # Update the progress bar with the number of processed queries in this chunk
+                pbar.update(len(chunk_results))
+
+            # Close the progress bar
+            pbar.close()
+
+            # Wait for all worker processes to finish
+            for process in processes:
+                process.join()
+
+            # Stop measuring time for the critical work
+            total_time = time.perf_counter() - start
+
+            # Extract precisions and latencies (outside the timed section)
+            precisions, latencies = zip(*results)

         self.__class__.delete_client()

@@ -157,3 +223,21 @@ def post_search(self):
     @classmethod
     def delete_client(cls):
         pass
+
+
+def chunked_iterable(iterable, size):
+    """Yield successive chunks of a given size from an iterable."""
+    it = iter(iterable)
+    while chunk := list(islice(it, size)):
+        yield chunk
+
+
+def process_chunk(chunk, search_one):
+    """Process a chunk of queries using the search_one function."""
+    # No progress bar in worker processes to avoid cluttering the output
+    return [search_one(query) for query in chunk]
+
+
+def process_chunk_wrapper(chunk, search_one):
+    """Wrapper to process a chunk of queries."""
+    return process_chunk(chunk, search_one)
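Taken together, the new parallel path in search.py pre-builds one chunk of queries per worker (via chunked_iterable for lists, or by draining the cycling generator chunk by chunk) and replaces the multiprocessing Pool with explicitly managed Process workers that push their per-chunk results onto a shared Queue. A stripped-down sketch of the same fan-out/fan-in shape, with a dummy function standing in for search_one (illustrative only, not code from the repository):

    from itertools import islice
    from multiprocessing import Process, Queue

    def chunked_iterable(iterable, size):
        """Yield successive chunks of a given size from an iterable."""
        it = iter(iterable)
        while chunk := list(islice(it, size)):
            yield chunk

    def fake_search_one(query):
        return 1.0, 0.001  # (precision, latency) placeholder

    def worker(chunk, result_queue):
        # Each worker processes its whole chunk, then ships one list of results back.
        result_queue.put([fake_search_one(q) for q in chunk])

    if __name__ == "__main__":
        queries = list(range(100))  # hypothetical query payloads
        parallel = 4
        chunk_size = max(1, len(queries) // parallel)
        query_chunks = list(chunked_iterable(queries, chunk_size))

        result_queue = Queue()
        processes = [Process(target=worker, args=(c, result_queue)) for c in query_chunks]
        for p in processes:
            p.start()

        # Drain one result list per worker, then join.
        results = []
        for _ in processes:
            results.extend(result_queue.get())
        for p in processes:
            p.join()

        precisions, latencies = zip(*results)
        print(len(results), "queries,", sum(latencies), "s simulated latency")

The ordering matters: results are collected from the Queue before the workers are joined. A process that still has unconsumed items buffered in a multiprocessing.Queue may not exit, so joining first can deadlock; the diff follows the same collect-then-join order.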

test_multiprocessing.py

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+from engine.base_client.search import BaseSearcher
+from dataset_reader.base_reader import Query
+import time
+
+class TestSearcher(BaseSearcher):
+    @classmethod
+    def init_client(cls, host, distance, connection_params, search_params):
+        pass
+
+    @classmethod
+    def search_one(cls, vector, meta_conditions, top):
+        return []
+
+    @classmethod
+    def _search_one(cls, query, top=None):
+        # Add a small delay to simulate real work
+        time.sleep(0.001)
+        return 1.0, 0.1
+
+    def setup_search(self):
+        pass
+
+# Create a small set of test queries
+queries = [Query(vector=[0.1]*10, meta_conditions=None, expected_result=None) for _ in range(10)]
+
+# Create a searcher with parallel=10
+searcher = TestSearcher('localhost', {}, {'parallel': 10})
+
+# Run the search_all method with a large num_queries parameter
+start = time.perf_counter()
+results = searcher.search_all('cosine', queries, num_queries=1000)
+total_time = time.perf_counter() - start
+
+print(f'Number of queries: {len(results["latencies"])}')
+print(f'Total time: {total_time:.6f} seconds')
+print(f'Throughput: {results["rps"]:.2f} queries/sec')
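For scale, the test cycles its 10 synthetic queries up to num_queries=1000 and runs them with parallel=10, so each worker handles roughly 100 queries at about 1 ms of simulated work each (the time.sleep(0.001) in _search_one): on the order of 0.1 s of work per worker, versus roughly 1 s if all 1,000 queries ran in a single process. The measured total time will come in higher than that because it also includes process start-up and queue overhead. The script presumably has to be launched from the repository root (e.g. python test_multiprocessing.py) so that the engine and dataset_reader packages are importable.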
