Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to MEArecRawIO for loading only recordings or only sorting data #1258

Merged
merged 8 commits into from
May 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions neo/io/mearecio.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ class MEArecIO(MEArecRawIO, BaseFromRaw):
__doc__ = MEArecRawIO.__doc__
mode = 'file'

def __init__(self, filename):
MEArecRawIO.__init__(self, filename=filename)
def __init__(self, filename, load_spiketrains=True, load_analogsignal=True):
MEArecRawIO.__init__(self,
filename=filename,
load_spiketrains=load_spiketrains,
load_analogsignal=load_analogsignal
)
BaseFromRaw.__init__(self, filename)
14 changes: 13 additions & 1 deletion neo/rawio/baserawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,18 @@ def get_analogsignal_chunk(self, block_index=0, seg_index=0, i_start=None, i_sto
np.ndarray and are contiguous
:return: array with raw signal samples
"""

signal_streams = self.header['signal_streams']
signal_channels = self.header['signal_channels']
no_signal_streams = signal_streams.size == 0
no_channels = signal_channels.size == 0
if no_signal_streams or no_channels:
error_message = (
"get_analogsignal_chunk can't be called on a file with no signal streams or channels."
"Double check that your file contains signal streams and channels."
)
raise AttributeError(error_message)

stream_index = self._get_stream_index_from_arg(stream_index)
channel_indexes = self._get_channel_indexes(stream_index, channel_indexes,
channel_names, channel_ids)
Expand Down Expand Up @@ -579,7 +591,7 @@ def rescale_signal_raw_to_float(self, raw_signal, dtype='float32', stream_index=
channel_indexes=None, channel_names=None, channel_ids=None):
"""
Rescale a chunk of raw signals which are provided as a Numpy array. These are normally
returned by a call to get_analog_signal_chunk. The channels are specified either by
returned by a call to get_analogsignal_chunk. The channels are specified either by
JuliaSprenger marked this conversation as resolved.
Show resolved Hide resolved
channel_names, if provided, otherwise by channel_ids, if provided, otherwise by
channel_indexes, if provided, otherwise all channels are selected.

Expand Down
109 changes: 75 additions & 34 deletions neo/rawio/mearecrawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,18 @@ class MEArecRawIO(BaseRawIO):
"""
Class for "reading" fake data from a MEArec file.

This class provides a convenient way to read data from a MEArec file.

Parameters
----------
filename : str
The filename of the MEArec file to read.
load_spiketrains : bool, optional
Whether or not to load spike train data. Defaults to `True`.
load_analogsignal : bool, optional
Whether or not to load continuous recording data. Defaults to `True`.


Usage:
>>> import neo.rawio
>>> r = neo.rawio.MEArecRawIO(filename='mearec.h5')
Expand All @@ -36,52 +48,75 @@ class MEArecRawIO(BaseRawIO):
extensions = ['h5']
rawmode = 'one-file'

def __init__(self, filename=''):
def __init__(self, filename='', load_spiketrains=True, load_analogsignal=True):
BaseRawIO.__init__(self)
self.filename = filename

self.load_spiketrains = load_spiketrains
self.load_analogsignal = load_analogsignal

def _source_name(self):
return self.filename

def _parse_header(self):
load = ["channel_positions"]
if self.load_analogsignal:
load.append("recordings")
if self.load_spiketrains:
load.append("spiketrains")

import MEArec as mr
self._recgen = mr.load_recordings(recordings=self.filename, return_h5_objects=True,
check_suffix=False,
load=['recordings', 'spiketrains', 'channel_positions'],
load=load,
load_waveforms=False)
self._sampling_rate = self._recgen.info['recordings']['fs']
self._recordings = self._recgen.recordings
self._num_frames, self._num_channels = self._recordings.shape

signal_streams = np.array([('Signals', '0')], dtype=_signal_stream_dtype)

self.info_dict = deepcopy(self._recgen.info)
self.channel_positions = self._recgen.channel_positions
if self.load_analogsignal:
self._recordings = self._recgen.recordings
if self.load_spiketrains:
self._spiketrains = self._recgen.spiketrains

self._sampling_rate = self.info_dict['recordings']['fs']
self.duration_seconds = self.info_dict["recordings"]["duration"]
self._num_frames = int(self._sampling_rate * self.duration_seconds)
self._num_channels = self.channel_positions.shape[0]
self._dtype = self.info_dict["recordings"]["dtype"]

signals = [('Signals', '0')] if self.load_analogsignal else []
signal_streams = np.array(signals, dtype=_signal_stream_dtype)
samuelgarcia marked this conversation as resolved.
Show resolved Hide resolved


sig_channels = []
for c in range(self._num_channels):
ch_name = 'ch{}'.format(c)
chan_id = str(c + 1)
sr = self._sampling_rate # Hz
dtype = self._recordings.dtype
units = 'uV'
gain = 1.
offset = 0.
stream_id = '0'
sig_channels.append((ch_name, chan_id, sr, dtype, units, gain, offset, stream_id))
if self.load_analogsignal:
for c in range(self._num_channels):
ch_name = 'ch{}'.format(c)
chan_id = str(c + 1)
sr = self._sampling_rate # Hz
dtype = self._dtype
units = 'uV'
gain = 1.
offset = 0.
stream_id = '0'
sig_channels.append((ch_name, chan_id, sr, dtype, units, gain, offset, stream_id))

sig_channels = np.array(sig_channels, dtype=_signal_channel_dtype)
samuelgarcia marked this conversation as resolved.
Show resolved Hide resolved

# creating units channels
spike_channels = []
self._spiketrains = self._recgen.spiketrains
for c in range(len(self._spiketrains)):
unit_name = 'unit{}'.format(c)
unit_id = '#{}'.format(c)
# if spiketrains[c].waveforms is not None:
wf_units = ''
wf_gain = 1.
wf_offset = 0.
wf_left_sweep = 0
wf_sampling_rate = self._sampling_rate
spike_channels.append((unit_name, unit_id, wf_units, wf_gain,
wf_offset, wf_left_sweep, wf_sampling_rate))
if self.load_spiketrains:
for c in range(len(self._spiketrains)):
unit_name = 'unit{}'.format(c)
unit_id = '#{}'.format(c)
# if spiketrains[c].waveforms is not None:
wf_units = ''
wf_gain = 1.
wf_offset = 0.
wf_left_sweep = 0
wf_sampling_rate = self._sampling_rate
spike_channels.append((unit_name, unit_id, wf_units, wf_gain,
wf_offset, wf_left_sweep, wf_sampling_rate))

spike_channels = np.array(spike_channels, dtype=_spike_channel_dtype)

event_channels = []
Expand All @@ -98,7 +133,7 @@ def _parse_header(self):
self._generate_minimal_annotations()
for block_index in range(1):
bl_ann = self.raw_annotations['blocks'][block_index]
bl_ann['mearec_info'] = deepcopy(self._recgen.info)
bl_ann['mearec_info'] = self.info_dict

def _segment_t_start(self, block_index, seg_index):
all_starts = [[0.]]
Expand All @@ -119,6 +154,10 @@ def _get_signal_t_start(self, block_index, seg_index, stream_index):

def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop,
stream_index, channel_indexes):

if not self.load_analogsignal:
raise AttributeError("Recordings not loaded. Set load_analogsignal=True in MEArecRawIO constructor")

if i_start is None:
i_start = 0
if i_stop is None:
Expand All @@ -127,23 +166,25 @@ def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop,
if channel_indexes is None:
channel_indexes = slice(self._num_channels)
if isinstance(channel_indexes, slice):
raw_signals = self._recgen.recordings[i_start:i_stop, channel_indexes]
raw_signals = self._recordings[i_start:i_stop, channel_indexes]
else:
# sort channels because h5py neeeds sorted indexes
if np.any(np.diff(channel_indexes) < 0):
sorted_channel_indexes = np.sort(channel_indexes)
sorted_idx = np.array([list(sorted_channel_indexes).index(ch)
for ch in channel_indexes])
raw_signals = self._recgen.recordings[i_start:i_stop, sorted_channel_indexes]
raw_signals = self._recordings[i_start:i_stop, sorted_channel_indexes]
raw_signals = raw_signals[:, sorted_idx]
else:
raw_signals = self._recgen.recordings[i_start:i_stop, channel_indexes]
raw_signals = self._recordings[i_start:i_stop, channel_indexes]
return raw_signals

def _spike_count(self, block_index, seg_index, unit_index):

return len(self._spiketrains[unit_index])

def _get_spike_timestamps(self, block_index, seg_index, unit_index, t_start, t_stop):

spike_timestamps = self._spiketrains[unit_index].times.magnitude
if t_start is None:
t_start = self._segment_t_start(block_index, seg_index)
Expand Down
34 changes: 34 additions & 0 deletions neo/test/rawiotest/test_mearecrawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,40 @@ class TestMEArecRawIO(BaseTestRawIO, unittest.TestCase, ):
'mearec/mearec_test_10s.h5'
]

def test_not_loading_recordings(self):

filename = self.entities_to_test[0]
filename = self.get_local_path(filename)
rawio = self.rawioclass(filename=filename, load_analogsignal=False)
rawio.parse_header()

# Test that rawio does not have a _recordings attribute
self.assertFalse(hasattr(rawio, '_recordings'))

# Test that calling get_spike_timestamps works
rawio.get_spike_timestamps()

# Test that caling anlogsignal chunk raises the right error
with self.assertRaises(AttributeError):
rawio.get_analogsignal_chunk()


def test_not_loading_spiketrain(self):

filename = self.entities_to_test[0]
filename = self.get_local_path(filename)
rawio = self.rawioclass(filename=filename, load_spiketrains=False)
rawio.parse_header()

# Test that rawio does not have a _spiketrains attribute
self.assertFalse(hasattr(rawio, '_spiketrains'))

# Test that calling analogsignal chunk works
rawio.get_analogsignal_chunk()

# Test that calling get_spike_timestamps raises an the right error
with self.assertRaises(AttributeError):
rawio.get_spike_timestamps()

if __name__ == "__main__":
unittest.main()