Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 60 additions & 31 deletions pynest/nest/raster_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def _from_memory(detec):
return ev["times"], ev["senders"]


def _make_plot(ts, ts1, node_ids, neurons, hist=True, hist_binwidth=5.0, grayscale=False, title=None, xlabel=None):
def _make_plot(ts, ts1, node_ids, neurons, hist=True, hist_binwidth=5.0, grayscale=False, title=None, xlabel=None, ax=None):
"""Generic plotting routine.

Constructs a raster plot along with an optional histogram (common part in
Expand All @@ -234,55 +234,84 @@ def _make_plot(ts, ts1, node_ids, neurons, hist=True, hist_binwidth=5.0, graysca
Plot title
xlabel : str, optional
Label for x-axis
ax : matplotlib.axes.Axes, optional
The axes object to draw the plot on. If None, a new figure
and axes will be created.
"""
import matplotlib.pyplot as plt

plt.figure()
# --- Axis Management ---
# If no axis is provided, create a new figure and axes.
# This block handles the creation of axes for both raster and histogram.
if ax is None:
import matplotlib.pyplot as plt

fig = plt.figure()
if hist:
# Manually define positions for raster plot and histogram
ax_raster = fig.add_axes([0.1, 0.32, 0.85, 0.6])
ax_hist = fig.add_axes([0.1, 0.1, 0.85, 0.2], sharex=ax_raster)
# Hide x-tick labels on the raster plot to avoid overlap
plt.setp(ax_raster.get_xticklabels(), visible=False)
else:
ax_raster = fig.add_subplot(111)
ax_hist = None
# If an axis is provided, use it for the raster plot.
# The histogram will not be plotted in this case.
else:
ax_raster = ax
ax_hist = None
if hist:
import warnings
warnings.warn("Histogram is disabled when an external axis is provided. Set hist=False to silence this warning.")

# --- Color settings ---
if grayscale:
color_marker = ".k"
color_bar = "gray"
else:
color_marker = "."
color_bar = "blue"

color_edge = "black"

# --- Label settings ---
if xlabel is None:
xlabel = "Time (ms)"

ylabel = "Neuron ID"

if hist:
ax1 = plt.axes([0.1, 0.3, 0.85, 0.6])
plotid = plt.plot(ts1, node_ids, color_marker)
plt.ylabel(ylabel)
plt.xticks([])
xlim = plt.xlim()
# --- Plotting ---
# Raster plot
plotid = ax_raster.plot(ts1, node_ids, color_marker)
ax_raster.set_ylabel(ylabel)

plt.axes([0.1, 0.1, 0.85, 0.17])
# Set title on the main raster plot
if title is None:
ax_raster.set_title("Raster plot")
else:
ax_raster.set_title(title)

# Histogram
if hist and ax_hist is not None:
t_bins = numpy.arange(numpy.amin(ts), numpy.amax(ts), float(hist_binwidth))
n, _ = _histogram(ts, bins=t_bins)
num_neurons = len(numpy.unique(neurons))
heights = 1000 * n / (hist_binwidth * num_neurons)

plt.bar(t_bins, heights, width=hist_binwidth, color=color_bar, edgecolor=color_edge)
plt.yticks([int(x) for x in numpy.linspace(0.0, int(max(heights) * 1.1) + 5, 4)])
plt.ylabel("Rate (spks/s)")
plt.xlabel(xlabel)
plt.xlim(xlim)
plt.axes(ax1)
else:
plotid = plt.plot(ts1, node_ids, color_marker)
plt.xlabel(xlabel)
plt.ylabel(ylabel)

if title is None:
plt.title("Raster plot")

# Avoid division by zero if no neurons are provided
if num_neurons > 0:
heights = 1000 * n / (hist_binwidth * num_neurons)
else:
heights = numpy.zeros_like(n)

# The number of bins is one less than the number of bin edges
ax_hist.bar(t_bins[:-1], heights, width=hist_binwidth, color=color_bar, edgecolor=color_edge, align='edge')

if heights.any() and max(heights) > 0:
ax_hist.set_yticks([int(x) for x in numpy.linspace(0.0, int(max(heights) * 1.1) + 5, 4)])

ax_hist.set_ylabel("Rate (spks/s)")
ax_hist.set_xlabel(xlabel)
ax_raster.set_xlim(ax_hist.get_xlim()) # Ensure x-limits match
else:
plt.title(title)

plt.draw()
# If no histogram, set the x-label on the raster plot itself
ax_raster.set_xlabel(xlabel)

return plotid

Expand Down