Here is an improved version of the receptive fields function in the receptive fields chapter:
```python
import numpy as np


def get_rf_vectorized(spike_times, rf_stim_table, xs=None, ys=None, response_window=0.2):
    """
    Calculate a receptive field using vectorized spike counting.

    Parameters
    ----------
    spike_times : array-like
        Spike times for the unit, in seconds. Must be sorted in ascending
        order, as required by np.searchsorted.
    rf_stim_table : pandas.DataFrame
        Stimulus table of Gabor presentations with x_position, y_position
        and start_time columns.
    xs : array-like, optional
        X-coordinates of the receptive field grid. Defaults to the sorted
        unique values of rf_stim_table.x_position.
    ys : array-like, optional
        Y-coordinates of the receptive field grid. Defaults to the sorted
        unique values of rf_stim_table.y_position.
    response_window : float, optional
        Time window after stimulus onset in which spikes are counted
        (default: 0.2 seconds).

    Returns
    -------
    unit_rf : numpy.ndarray
        2D array of shape (len(ys), len(xs)) holding the mean spike count per
        presentation at each position; positions with no presentations are NaN.
    """
    if xs is None:
        xs = np.sort(rf_stim_table.x_position.unique())
    if ys is None:
        ys = np.sort(rf_stim_table.y_position.unique())

    # Initialize the receptive field with NaN so unsampled positions stay NaN
    unit_rf = np.full((len(ys), len(xs)), np.nan)

    # Pre-compute stimulus onset times for each (x, y) position
    position_stim_times = {}
    for x in xs:
        for y in ys:
            position_stim_times[(x, y)] = rf_stim_table[
                (rf_stim_table.x_position == x) &
                (rf_stim_table.y_position == y)
            ].start_time.values

    # Process each position
    for xi, x in enumerate(xs):
        for yi, y in enumerate(ys):
            stim_times = position_stim_times[(x, y)]
            if len(stim_times) == 0:
                continue

            # Start and end times of the response window for every presentation
            start_times = stim_times
            end_times = stim_times + response_window

            # Locate both window edges for all presentations in a single call;
            # the result has shape (n_presentations, 2)
            all_indices = np.searchsorted(spike_times, np.column_stack((start_times, end_times)))

            # Number of spikes in [start, end) for each presentation
            spike_counts = all_indices[:, 1] - all_indices[:, 0]

            # Store the mean spike count per presentation for this position
            unit_rf[yi, xi] = np.mean(spike_counts)

    return unit_rf
```

Improvements:
- Improved Documentation: Adds a docstring describing each parameter and the return value, making the code easier to understand and maintain.
- Efficient Data Structure: Pre-computes the stimulus onset times for each position into a dictionary keyed by (x, y), so the stimulus table is filtered once per position and the counting loop works only on NumPy arrays.
- Vectorized Spike Counting: Uses np.column_stack and np.searchsorted to count spikes for all stimulus presentations at a given position in a single call, rather than one presentation at a time.
- Performance Optimization: Replaces the per-presentation Python loop with one NumPy call per position, which reduces Python-level overhead when each position has many presentations.
- Skip Empty Positions: Skips positions at which no stimulus was presented; these entries remain NaN in the output.
- Parameterized Response Window: Makes the response window duration a configurable parameter, allowing for easy adjustment without modifying the function code.
- Better Memory Usage: More efficient memory allocation by initializing the result array once and updating it in place, rather than repeatedly creating new arrays.
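- Default Coordinates: When xs and ys are not provided, uses the sorted unique x_position and y_position values from rf_stim_table.
- Mean Instead of Sum: Averages spike counts across trials rather than summing them, so positions presented a different number of times are directly comparable.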
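The per-position dictionary is still built by applying one boolean mask to rf_stim_table for every (x, y) pair. A possible alternative, sketched here under the assumption that the table has `x_position`, `y_position` and `start_time` columns, is to group the table once with `pandas.DataFrame.groupby`. The helper name `onsets_by_position` and the toy table are hypothetical.

```python
import numpy as np
import pandas as pd


def onsets_by_position(rf_stim_table):
    """Hypothetical helper: map each (x, y) position to its stimulus onset times.

    Groups the table once instead of applying a boolean mask per position.
    """
    grouped = rf_stim_table.groupby(["x_position", "y_position"])["start_time"]
    # Iterating a groupby on two keys yields ((x, y), Series) pairs;
    # row order within each group is preserved
    return {key: group.to_numpy() for key, group in grouped}


# Toy stimulus table (made-up values)
table = pd.DataFrame({
    "x_position": [0.0, 10.0, 0.0, 10.0, 0.0],
    "y_position": [0.0, 0.0, 10.0, 0.0, 0.0],
    "start_time": [1.0, 2.0, 3.0, 4.0, 5.0],
})

onsets = onsets_by_position(table)
# Positions that were never presented are absent, so fall back to an empty array
empty = np.array([])
print(onsets[(0.0, 0.0)])                     # [1. 5.]
print(onsets.get((10.0, 10.0), empty).size)   # 0
```

With this mapping, the lookup inside the counting loop would become `onsets.get((x, y), empty)`, and the existing empty-array check already covers positions that were never shown.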