Skip to content

Commit

Permalink
Merge pull request NeuralEnsemble#1258 from h-mayorquin/improve_mearec
Browse files Browse the repository at this point in the history
Add option to `MEArecRawIO` for loading only recordings or only sorting data
  • Loading branch information
samuelgarcia authored May 2, 2023
2 parents a351f4e + 1863d55 commit 33b05b6
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 37 deletions.
8 changes: 6 additions & 2 deletions neo/io/mearecio.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ class MEArecIO(MEArecRawIO, BaseFromRaw):
__doc__ = MEArecRawIO.__doc__
mode = 'file'

def __init__(self, filename):
MEArecRawIO.__init__(self, filename=filename)
def __init__(self, filename, load_spiketrains=True, load_analogsignal=True):
MEArecRawIO.__init__(self,
filename=filename,
load_spiketrains=load_spiketrains,
load_analogsignal=load_analogsignal
)
BaseFromRaw.__init__(self, filename)
14 changes: 13 additions & 1 deletion neo/rawio/baserawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,18 @@ def get_analogsignal_chunk(self, block_index=0, seg_index=0, i_start=None, i_sto
np.ndarray and are contiguous
:return: array with raw signal samples
"""

signal_streams = self.header['signal_streams']
signal_channels = self.header['signal_channels']
no_signal_streams = signal_streams.size == 0
no_channels = signal_channels.size == 0
if no_signal_streams or no_channels:
error_message = (
"get_analogsignal_chunk can't be called on a file with no signal streams or channels."
"Double check that your file contains signal streams and channels."
)
raise AttributeError(error_message)

stream_index = self._get_stream_index_from_arg(stream_index)
channel_indexes = self._get_channel_indexes(stream_index, channel_indexes,
channel_names, channel_ids)
Expand Down Expand Up @@ -579,7 +591,7 @@ def rescale_signal_raw_to_float(self, raw_signal, dtype='float32', stream_index=
channel_indexes=None, channel_names=None, channel_ids=None):
"""
Rescale a chunk of raw signals which are provided as a Numpy array. These are normally
returned by a call to get_analog_signal_chunk. The channels are specified either by
returned by a call to get_analogsignal_chunk. The channels are specified either by
channel_names, if provided, otherwise by channel_ids, if provided, otherwise by
channel_indexes, if provided, otherwise all channels are selected.
Expand Down
109 changes: 75 additions & 34 deletions neo/rawio/mearecrawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,18 @@ class MEArecRawIO(BaseRawIO):
"""
Class for "reading" fake data from a MEArec file.
This class provides a convenient way to read data from a MEArec file.
Parameters
----------
filename : str
The filename of the MEArec file to read.
load_spiketrains : bool, optional
Whether or not to load spike train data. Defaults to `True`.
load_analogsignal : bool, optional
Whether or not to load continuous recording data. Defaults to `True`.
Usage:
>>> import neo.rawio
>>> r = neo.rawio.MEArecRawIO(filename='mearec.h5')
Expand All @@ -36,52 +48,75 @@ class MEArecRawIO(BaseRawIO):
extensions = ['h5']
rawmode = 'one-file'

def __init__(self, filename=''):
def __init__(self, filename='', load_spiketrains=True, load_analogsignal=True):
BaseRawIO.__init__(self)
self.filename = filename

self.load_spiketrains = load_spiketrains
self.load_analogsignal = load_analogsignal

def _source_name(self):
return self.filename

def _parse_header(self):
load = ["channel_positions"]
if self.load_analogsignal:
load.append("recordings")
if self.load_spiketrains:
load.append("spiketrains")

import MEArec as mr
self._recgen = mr.load_recordings(recordings=self.filename, return_h5_objects=True,
check_suffix=False,
load=['recordings', 'spiketrains', 'channel_positions'],
load=load,
load_waveforms=False)
self._sampling_rate = self._recgen.info['recordings']['fs']
self._recordings = self._recgen.recordings
self._num_frames, self._num_channels = self._recordings.shape

signal_streams = np.array([('Signals', '0')], dtype=_signal_stream_dtype)

self.info_dict = deepcopy(self._recgen.info)
self.channel_positions = self._recgen.channel_positions
if self.load_analogsignal:
self._recordings = self._recgen.recordings
if self.load_spiketrains:
self._spiketrains = self._recgen.spiketrains

self._sampling_rate = self.info_dict['recordings']['fs']
self.duration_seconds = self.info_dict["recordings"]["duration"]
self._num_frames = int(self._sampling_rate * self.duration_seconds)
self._num_channels = self.channel_positions.shape[0]
self._dtype = self.info_dict["recordings"]["dtype"]

signals = [('Signals', '0')] if self.load_analogsignal else []
signal_streams = np.array(signals, dtype=_signal_stream_dtype)


sig_channels = []
for c in range(self._num_channels):
ch_name = 'ch{}'.format(c)
chan_id = str(c + 1)
sr = self._sampling_rate # Hz
dtype = self._recordings.dtype
units = 'uV'
gain = 1.
offset = 0.
stream_id = '0'
sig_channels.append((ch_name, chan_id, sr, dtype, units, gain, offset, stream_id))
if self.load_analogsignal:
for c in range(self._num_channels):
ch_name = 'ch{}'.format(c)
chan_id = str(c + 1)
sr = self._sampling_rate # Hz
dtype = self._dtype
units = 'uV'
gain = 1.
offset = 0.
stream_id = '0'
sig_channels.append((ch_name, chan_id, sr, dtype, units, gain, offset, stream_id))

sig_channels = np.array(sig_channels, dtype=_signal_channel_dtype)

# creating units channels
spike_channels = []
self._spiketrains = self._recgen.spiketrains
for c in range(len(self._spiketrains)):
unit_name = 'unit{}'.format(c)
unit_id = '#{}'.format(c)
# if spiketrains[c].waveforms is not None:
wf_units = ''
wf_gain = 1.
wf_offset = 0.
wf_left_sweep = 0
wf_sampling_rate = self._sampling_rate
spike_channels.append((unit_name, unit_id, wf_units, wf_gain,
wf_offset, wf_left_sweep, wf_sampling_rate))
if self.load_spiketrains:
for c in range(len(self._spiketrains)):
unit_name = 'unit{}'.format(c)
unit_id = '#{}'.format(c)
# if spiketrains[c].waveforms is not None:
wf_units = ''
wf_gain = 1.
wf_offset = 0.
wf_left_sweep = 0
wf_sampling_rate = self._sampling_rate
spike_channels.append((unit_name, unit_id, wf_units, wf_gain,
wf_offset, wf_left_sweep, wf_sampling_rate))

spike_channels = np.array(spike_channels, dtype=_spike_channel_dtype)

event_channels = []
Expand All @@ -98,7 +133,7 @@ def _parse_header(self):
self._generate_minimal_annotations()
for block_index in range(1):
bl_ann = self.raw_annotations['blocks'][block_index]
bl_ann['mearec_info'] = deepcopy(self._recgen.info)
bl_ann['mearec_info'] = self.info_dict

def _segment_t_start(self, block_index, seg_index):
all_starts = [[0.]]
Expand All @@ -119,6 +154,10 @@ def _get_signal_t_start(self, block_index, seg_index, stream_index):

def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop,
stream_index, channel_indexes):

if not self.load_analogsignal:
raise AttributeError("Recordings not loaded. Set load_analogsignal=True in MEArecRawIO constructor")

if i_start is None:
i_start = 0
if i_stop is None:
Expand All @@ -127,23 +166,25 @@ def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop,
if channel_indexes is None:
channel_indexes = slice(self._num_channels)
if isinstance(channel_indexes, slice):
raw_signals = self._recgen.recordings[i_start:i_stop, channel_indexes]
raw_signals = self._recordings[i_start:i_stop, channel_indexes]
else:
# sort channels because h5py neeeds sorted indexes
if np.any(np.diff(channel_indexes) < 0):
sorted_channel_indexes = np.sort(channel_indexes)
sorted_idx = np.array([list(sorted_channel_indexes).index(ch)
for ch in channel_indexes])
raw_signals = self._recgen.recordings[i_start:i_stop, sorted_channel_indexes]
raw_signals = self._recordings[i_start:i_stop, sorted_channel_indexes]
raw_signals = raw_signals[:, sorted_idx]
else:
raw_signals = self._recgen.recordings[i_start:i_stop, channel_indexes]
raw_signals = self._recordings[i_start:i_stop, channel_indexes]
return raw_signals

def _spike_count(self, block_index, seg_index, unit_index):

return len(self._spiketrains[unit_index])

def _get_spike_timestamps(self, block_index, seg_index, unit_index, t_start, t_stop):

spike_timestamps = self._spiketrains[unit_index].times.magnitude
if t_start is None:
t_start = self._segment_t_start(block_index, seg_index)
Expand Down
34 changes: 34 additions & 0 deletions neo/test/rawiotest/test_mearecrawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,40 @@ class TestMEArecRawIO(BaseTestRawIO, unittest.TestCase, ):
'mearec/mearec_test_10s.h5'
]

def test_not_loading_recordings(self):

filename = self.entities_to_test[0]
filename = self.get_local_path(filename)
rawio = self.rawioclass(filename=filename, load_analogsignal=False)
rawio.parse_header()

# Test that rawio does not have a _recordings attribute
self.assertFalse(hasattr(rawio, '_recordings'))

# Test that calling get_spike_timestamps works
rawio.get_spike_timestamps()

# Test that caling anlogsignal chunk raises the right error
with self.assertRaises(AttributeError):
rawio.get_analogsignal_chunk()


def test_not_loading_spiketrain(self):

filename = self.entities_to_test[0]
filename = self.get_local_path(filename)
rawio = self.rawioclass(filename=filename, load_spiketrains=False)
rawio.parse_header()

# Test that rawio does not have a _spiketrains attribute
self.assertFalse(hasattr(rawio, '_spiketrains'))

# Test that calling analogsignal chunk works
rawio.get_analogsignal_chunk()

# Test that calling get_spike_timestamps raises an the right error
with self.assertRaises(AttributeError):
rawio.get_spike_timestamps()

if __name__ == "__main__":
unittest.main()

0 comments on commit 33b05b6

Please sign in to comment.