Skip to content

Commit

Permalink
Move plotting to plot_distribution function
Browse files Browse the repository at this point in the history
  • Loading branch information
endolith committed Aug 5, 2024
1 parent d175b26 commit 2f8ec55
Showing 1 changed file with 22 additions and 26 deletions.
48 changes: 22 additions & 26 deletions examples/distributions_by_method_2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,27 @@ def simulate_batch(n_cands):

# %% Plotting


def plot_distribution(ax, data, title, max_lim):
heatmap, xedges, yedges = np.histogram2d(data[:, 0], data[:, 1], bins=50,
range=[[-max_lim, max_lim],
[-max_lim, max_lim]])
extent = [-max_lim, max_lim, -max_lim, max_lim]
ax.imshow(heatmap.T, cmap="Blues", origin='lower',
aspect='auto', extent=extent)
ax.set_xlim([-max_lim, max_lim])
ax.set_ylim([-max_lim, max_lim])
ax.set_aspect('equal') # Set the aspect ratio to be equal (square)
ax.set_title(title, loc='left')
ax.tick_params(left=False, bottom=False) # Remove tick marks
ax.set_xticks([]) # Remove x-axis ticks
ax.set_yticks([]) # Remove y-axis ticks
ax.set_xlabel("") # Remove x-axis label
ax.set_ylabel("") # Remove y-axis label
for spine in ax.spines.values():
spine.set_visible(True)


title = f'{human_format(n_elections)} 2D elections, '
title += f'{human_format(n_voters)} voters, '
title += f'{human_format(n_cands)} candidates'
Expand All @@ -162,32 +183,7 @@ def simulate_batch(n_cands):
max_lim = 1.5

for n, method in enumerate(winners.keys()):
coordinates = winners[method]
# Create a 2D histogram with specified range
heatmap, xedges, yedges = np.histogram2d(coordinates[:, 0],
coordinates[:, 1], bins=50,
range=[[-max_lim, max_lim],
[-max_lim, max_lim]])

# Calculate extent from xedges and yedges
extent = [-max_lim, max_lim, -max_lim, max_lim]

ax[n].imshow(heatmap.T, cmap="Blues", origin='lower',
aspect='auto', extent=extent)

ax[n].set_xlim([-max_lim, max_lim])
ax[n].set_ylim([-max_lim, max_lim])
ax[n].set_aspect('equal') # Set the aspect ratio to be equal (square)
ax[n].set_title(method, loc='left')
ax[n].tick_params(left=False, bottom=False) # Remove tick marks
ax[n].set_xticks([]) # Remove x-axis ticks
ax[n].set_yticks([]) # Remove y-axis ticks
ax[n].set_xlabel("") # Remove x-axis label
ax[n].set_ylabel("") # Remove y-axis label

# Add borders to the plots
for spine in ax[n].spines.values():
spine.set_visible(True)
plot_distribution(ax[n], winners[method], method, max_lim)

# Add standard deviation text in the lower right corner
std = winners_stats[method][1]
Expand Down

0 comments on commit 2f8ec55

Please sign in to comment.