Skip to content

Commit

Permalink
Load spike data only on demand and when not already cached
Browse files Browse the repository at this point in the history
  • Loading branch information
JuliaSprenger committed Jun 28, 2023
1 parent 2bc830b commit 4eb5fc1
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions neo/rawio/plexon2rawio/plexon2rawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,6 @@ def _parse_header(self):
spike_channels.append((unit_name, unit_id, wf_units, wf_gain,
wf_offset, wf_left_sweep, wf_sampling_rate))

# pre-loading spiking data
schannel_name = schannel_info.m_Name.decode()
self._spike_channel_cache[schannel_name] = self.pl2reader.pl2_get_spike_channel_data_by_name(schannel_name)

spike_channels = np.array(spike_channels, dtype=_spike_channel_dtype)

# creating event/epoch channel
Expand Down Expand Up @@ -366,6 +362,10 @@ def _spike_count(self, block_index, seg_index, spike_channel_index):
channel_name, channel_unit_id = channel_header['name'].split('.')
channel_unit_id = int(channel_unit_id)

# loading spike channel data on demand when not already cached
if channel_name not in self._spike_channel_cache:
self._spike_channel_cache[channel_name] = self.pl2reader.pl2_get_spike_channel_data_by_name(channel_name)

spike_timestamps, unit_ids, waveforms = self._spike_channel_cache[channel_name]
nb_spikes = np.count_nonzero(unit_ids == channel_unit_id)

Expand All @@ -376,6 +376,10 @@ def _get_spike_timestamps(self, block_index, seg_index, spike_channel_index, t_s
channel_name, channel_unit_id = channel_header['name'].split('.')
channel_unit_id = int(channel_unit_id)

# loading spike channel data on demand when not already cached
if channel_name not in self._spike_channel_cache:
self._spike_channel_cache[channel_name] = self.pl2reader.pl2_get_spike_channel_data_by_name(channel_name)

spike_timestamps, unit_ids, waveforms = self._spike_channel_cache[channel_name]

if t_start is not None or t_stop is not None:
Expand Down Expand Up @@ -417,6 +421,10 @@ def _get_spike_raw_waveforms(self, block_index, seg_index, spike_channel_index,
channel_header = self.header['spike_channels'][spike_channel_index]
channel_name, channel_unit_id = channel_header['name'].split('.')

# loading spike channel data on demand when not already cached
if channel_name not in self._spike_channel_cache:
self._spike_channel_cache[channel_name] = self.pl2reader.pl2_get_spike_channel_data_by_name(channel_name)

spike_timestamps, unit_ids, waveforms = self._spike_channel_cache[channel_name]

if t_start is not None or t_stop is not None:
Expand Down

0 comments on commit 4eb5fc1

Please sign in to comment.