
Commit b8b1ea4

jucor and claude committed
Optimize update_votes with vectorized pivot_table (5x speedup)
Replace the row-by-row for-loop in update_votes with a vectorized pivot_table approach. This dramatically speeds up vote loading for large datasets.

Performance on bg2050 dataset (1M+ votes, 7.8k participants, 7.7k comments):
- Before: 18.5s average, 56k votes/sec
- After: 3.5s average, 295k votes/sec
- Speedup: 5.3x overall, 16x for the batch update step

The optimization:
1. Use pivot_table to reshape long-form votes to a wide-form matrix
2. Use DataFrame.where() to merge with the existing matrix
3. Use float32 for the intermediate matrix to halve memory usage

Also adds a benchmark script at polismath/benchmarks/bench_update_votes.py for measuring update_votes performance.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
1 parent ee24452 commit b8b1ea4
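For readers skimming the diff below, here is a minimal, self-contained sketch of the pivot-then-merge pattern this commit applies (toy data, not the project's code; the 'row'/'col'/'value' column names mirror the updates_df used in the diff):

import numpy as np
import pandas as pd

# Long-form updates: one (participant, comment, vote) triple per row.
updates_df = pd.DataFrame({
    'row':   ['p1', 'p1', 'p2', 'p2'],
    'col':   ['c1', 'c2', 'c1', 'c1'],
    'value': [1.0, -1.0, 1.0, 0.0],   # p2 voted twice on c1; 'last' wins
})

# Step 1: long-form -> wide-form. aggfunc='last' keeps the most recent
# vote when the same (row, col) pair appears more than once.
update_matrix = updates_df.pivot_table(
    index='row', columns='col', values='value', aggfunc='last'
).astype('float32')

# An existing wide-form matrix that is missing row 'p2'.
existing = pd.DataFrame([[np.nan, 1.0]], index=['p1'],
                        columns=['c1', 'c2'], dtype='float32')

# Step 2: expand both frames to the union of labels, then merge with
# where(): keep `existing` where the update is NaN, else take the update.
all_rows = existing.index.union(update_matrix.index)
all_cols = existing.columns.union(update_matrix.columns)
existing = existing.reindex(index=all_rows, columns=all_cols)
update_matrix = update_matrix.reindex(index=all_rows, columns=all_cols)
merged = existing.where(update_matrix.isna(), update_matrix)

print(merged)
# col   c1   c2
# row
# p1   1.0 -1.0
# p2   0.0  NaN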

File tree

3 files changed (+123, -18 lines)
delphi/polismath/benchmarks/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+"""Benchmark scripts for polismath performance testing."""
delphi/polismath/benchmarks/bench_update_votes.py

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
+#!/usr/bin/env python3
+"""
+Benchmark script for update_votes performance.
+
+Usage:
+    cd delphi
+    ../.venv/bin/python -m polismath.benchmarks.bench_update_votes [dataset_name] [--runs N]
+
+Example:
+    ../.venv/bin/python -m polismath.benchmarks.bench_update_votes bg2050 --runs 3
+"""
+import argparse
+import time
+import sys
+
+
+def benchmark_update_votes(dataset_name: str = 'bg2050', runs: int = 3) -> dict:
+    """
+    Benchmark update_votes on a dataset.
+
+    Args:
+        dataset_name: Name of the dataset to benchmark
+        runs: Number of runs to average
+
+    Returns:
+        Dictionary with benchmark results
+    """
+    from polismath.conversation import Conversation
+    from polismath.regression.utils import prepare_votes_data
+
+    print(f"Loading dataset '{dataset_name}'...")
+    votes_dict, metadata = prepare_votes_data(dataset_name)
+    n_votes = len(votes_dict['votes'])
+    print(f"Loaded {n_votes:,} votes")
+    print()
+
+    times = []
+    for i in range(runs):
+        conv = Conversation(dataset_name)
+        start = time.perf_counter()
+        conv = conv.update_votes(votes_dict, recompute=False)
+        elapsed = time.perf_counter() - start
+        times.append(elapsed)
+        print(f"  Run {i+1}: {elapsed:.2f}s")
+
+    avg = sum(times) / len(times)
+    min_time = min(times)
+    max_time = max(times)
+
+    print()
+    print(f"Dataset: {dataset_name}")
+    print(f"Votes: {n_votes:,}")
+    print(f"Matrix shape: {conv.raw_rating_mat.shape}")
+    print(f"Average time: {avg:.2f}s")
+    print(f"Min/Max: {min_time:.2f}s / {max_time:.2f}s")
+    print(f"Throughput: {n_votes/avg:,.0f} votes/sec")
+
+    return {
+        'dataset': dataset_name,
+        'n_votes': n_votes,
+        'shape': conv.raw_rating_mat.shape,
+        'times': times,
+        'avg': avg,
+        'min': min_time,
+        'max': max_time,
+        'throughput': n_votes / avg,
+    }
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Benchmark update_votes performance')
+    parser.add_argument('dataset', nargs='?', default='bg2050',
+                        help='Dataset name (default: bg2050)')
+    parser.add_argument('--runs', type=int, default=3,
+                        help='Number of benchmark runs (default: 3)')
+    args = parser.parse_args()
+
+    try:
+        benchmark_update_votes(args.dataset, args.runs)
+    except Exception as e:
+        print(f"Error: {e}", file=sys.stderr)
+        sys.exit(1)
+
+
+if __name__ == '__main__':
+    main()
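Since benchmark_update_votes returns its results as a dict, it can also be driven programmatically, e.g. from a notebook (a hypothetical usage sketch, assuming the bg2050 dataset is available locally):

from polismath.benchmarks.bench_update_votes import benchmark_update_votes

results = benchmark_update_votes('bg2050', runs=5)
print(f"{results['throughput']:,.0f} votes/sec (avg of {len(results['times'])} runs)")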

delphi/polismath/conversation/conversation.py

Lines changed: 36 additions & 18 deletions
@@ -231,25 +231,43 @@ def update_votes(self,

         logger.info(f"[{time.time() - start_time:.2f}s] Found {len(new_rows)} new rows and {len(new_cols)} new columns")

-        # Apply all updates in a single batch operation for better performance
-        # Honestly, we should probably keep the matrix of votes in long-form,
-        # and only convert to wide-form when requested.
-
-        logger.info(f"[{time.time() - start_time:.2f}s] Applying {len(vote_updates)} votes as batch update...")
+        # Apply all updates using vectorized pivot_table approach.
+        # This is much faster than row-by-row iteration because pandas/numpy
+        # can use optimized C code for the reshape operation.
+
+        logger.info(f"[{time.time() - start_time:.2f}s] Applying {len(updates_df)} votes as batch update...")
         batch_start = time.time()
-        # For backward compatibility, sort the rows and columns by label.
-        result.raw_rating_mat = result.raw_rating_mat.reindex(index=all_rows, columns=all_cols, fill_value=np.nan)
-        # NOTE: we cannot use .loc[rows, cols] = values with rows, cols, and values being Series,
-        # for example `result.raw_rating_mat.loc[updates_df['row'], updates_df['col']] = updates_df['value'].values`,
-        # because pandas then tries to assign to the Cartesian product of rows and cols, and it gets very messy
-        # and is definitely *not* what we intended.
-        # We could convert to integer indices with get_loc, then use .values to use numpy assignment (which does not
-        # do any Cartesian product), but a/ it's less legible, b/ there is *no* guarantee at all that .values is always
-        # a view and not a copy, so we might end up modifying a copy of the data frame.
-        # Therefore, for simplicity and readability, sticking to an ugly for loop.
-        # If you have a better idea, let me know at julien@cornebise.com, I would love to know :)
-        for idx, row_data in updates_df.iterrows():
-            result.raw_rating_mat.at[row_data['row'], row_data['col']] = row_data['value']
+
+        # Build a wide-form matrix from the long-form updates using pivot_table.
+        # aggfunc='last' keeps the last vote if any duplicates remain after dedup.
+        update_matrix = updates_df.pivot_table(
+            index='row',
+            columns='col',
+            values='value',
+            aggfunc='last'
+        )
+        # Use float32 for the intermediate matrix to save memory (~200MB vs
+        # ~400MB for 8k comments and 8k participants). float32 can exactly
+        # represent -1, 0, +1 and NaN.
+        update_matrix = update_matrix.astype('float32')
+
+        # Expand the existing matrix to include any new rows/columns.
+        # fill_value=np.nan ensures new cells start as "no vote".
+        result.raw_rating_mat = result.raw_rating_mat.reindex(
+            index=all_rows, columns=all_cols, fill_value=np.nan
+        )
+
+        # Align the update matrix to the same shape (new cells become NaN).
+        update_matrix = update_matrix.reindex(index=all_rows, columns=all_cols)
+
+        # Merge: where update_matrix has a value, use it; otherwise keep original.
+        # DataFrame.where(cond, other) keeps self where cond is True, uses other where False.
+        # So: keep raw_rating_mat where update_matrix is NaN, else use update_matrix.
+        result.raw_rating_mat = result.raw_rating_mat.where(
+            update_matrix.isna(),  # condition: True where update has no value
+            update_matrix          # other: use update value where condition is False
+        )
+
         logger.info(f"[{time.time() - start_time:.2f}s] Batch update completed in {time.time() - batch_start:.2f}s")

         # Update last updated timestamp
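A back-of-envelope check of the float32 comment in the hunk above (sizes taken from the commit message, not measured):

# ~7.8k participants x ~7.7k comments => ~60M dense cells.
cells = 7_800 * 7_700
print(f"float64: {cells * 8 / 1e6:,.0f} MB")  # ~480 MB
print(f"float32: {cells * 4 / 1e6:,.0f} MB")  # ~240 MB, i.e. half

import numpy as np
# -1, 0, +1 (and NaN) round-trip through float32 without loss,
# so the narrower dtype is safe for agree/pass/disagree votes.
v = np.array([-1.0, 0.0, 1.0])
assert (v.astype(np.float32) == v).all()
assert np.isnan(np.float32(np.nan))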
