Update proposal for phylib #49

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 9 commits into base: master
Changes from 2 commits
123 changes: 92 additions & 31 deletions phylib/io/model.py
@@ -57,7 +57,9 @@ def read_array(path, mmap_mode=None):
if np.any(errors):
n = np.sum(errors)
n_tot = errors.size
logger.warning('%d/%d values are %s in %s, replacing by zero.', n, n_tot, w, path)
logger.warning(
'%d/%d values are %s in %s, replacing by zero.', n, n_tot, w, path
)
out[errors] = 0
return out
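The zero-replacement above is a plain boolean-mask assignment; a minimal standalone illustration (simplified to one combined NaN/Inf mask instead of the separate per-kind checks the function performs):

    import numpy as np

    out = np.array([1.0, np.nan, 3.0, np.inf])
    errors = ~np.isfinite(out)   # True where the value is NaN or +/- Inf
    out[errors] = 0              # same masked assignment as in read_array
    print(out)                   # [1. 0. 3. 0.]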

@@ -84,7 +86,9 @@ def from_sparse(data, cols, channel_ids):
"""
# The axis in the data that contains the channels.
if len(channel_ids) != len(np.unique(channel_ids)):
raise NotImplementedError('Multiple identical requested channels in from_sparse().')
raise NotImplementedError(
'Multiple identical requested channels in from_sparse().'
)
channel_axis = 1
shape = list(data.shape)
assert data.ndim >= 2
@@ -98,7 +102,7 @@ def from_sparse(data, cols, channel_ids):

# NumPy 2.0 compatibility fix: Create channel_ids array with sentinel value
# that works with both uint32 channel_ids and -1 sentinel
channel_ids_with_sentinel = np.concatenate([channel_ids.astype(np.int64), [-1]])
channel_ids_with_sentinel = np.concatenate([channel_ids, [-1]])

assert np.all(np.isin(c, channel_ids_with_sentinel))
# Convert column indices to relative indices given the specified
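A minimal standalone sketch of the sentinel idea handled in this hunk, with made-up arrays; the signed cast is the crux, since an unsigned uint32 array cannot represent the -1 marker:

    import numpy as np

    channel_ids = np.array([3, 7, 12], dtype=np.uint32)  # requested channels
    # Append the "missing channel" sentinel; casting to a signed dtype first
    # keeps -1 representable in the concatenated array.
    ids_with_sentinel = np.concatenate([channel_ids.astype(np.int64), [-1]])
    print(ids_with_sentinel)        # [ 3  7 12 -1]
    print(ids_with_sentinel.dtype)  # int64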
@@ -494,7 +498,9 @@ def _load_data(self):
# Template features.
self.sparse_template_features = self._load_template_features()
self.template_features = (
self.sparse_template_features.data if self.sparse_template_features else None
self.sparse_template_features.data
if self.sparse_template_features
else None
)

# Spike attributes.
@@ -504,7 +510,9 @@ def _load_data(self):
self.metadata = self._load_metadata()

def _find_path(self, *names, multiple_ok=True, mandatory=True):
full_paths = [l[0] for l in [list(self.dir_path.glob(name)) for name in names] if l]
full_paths = [
l[0] for l in [list(self.dir_path.glob(name)) for name in names] if l
]
path = _find_first_existing_path(*full_paths, multiple_ok=multiple_ok)
if mandatory and not path:
raise OSError(
@@ -573,7 +581,9 @@ def _load_channel_map(self):
return out

def _load_channel_positions(self):
path = self._find_path('channel_positions.npy', 'channels.localCoordinates*.npy')
path = self._find_path(
'channel_positions.npy', 'channels.localCoordinates*.npy'
)
out = self._read_array(path)
out = np.atleast_2d(out)
assert out.ndim == 2
@@ -603,7 +613,9 @@ def _load_traces(self, channel_map=None):
if not self.dat_path:
if os.environ.get('PHY_VIRTUAL_RAW_DATA', None): # pragma: no cover
n_samples = int((self.spike_times[-1] + 1) * self.sample_rate)
return RandomEphysReader(n_samples, len(channel_map), sample_rate=self.sample_rate)
return RandomEphysReader(
n_samples, len(channel_map), sample_rate=self.sample_rate
)
return
n = self.n_channels_dat
# self.dat_path could be any object accepted by get_ephys_reader().
@@ -620,7 +632,9 @@ def _load_traces(self, channel_map=None):

def _load_amplitudes(self):
try:
out = self._read_array(self._find_path('amplitudes.npy', 'spikes.amps*.npy'))
out = self._read_array(
self._find_path('amplitudes.npy', 'spikes.amps*.npy')
)
assert out.ndim == 1
return out
except OSError:
Expand All @@ -634,7 +648,9 @@ def _load_spike_templates(self):
out = out.astype(np.int32)
uc = np.unique(out)
if np.max(uc) - np.min(uc) + 1 != uc.size:
logger.warning('Unreferenced clusters found in templates (generally not a problem)')
logger.warning(
'Unreferenced clusters found in templates (generally not a problem)'
)
assert out.dtype in (np.uint16, np.uint32, np.int32, np.int64)
assert out.ndim == 1
return out
@@ -692,7 +708,9 @@ def _load_spike_samples(self):
if samples_path:
samples = self._read_array(samples_path)
else:
logger.info('Loading spikes.times.npy in seconds, converting to samples.')
logger.info(
'Loading spikes.times.npy in seconds, converting to samples.'
)
samples = np.round(times * self.sample_rate).astype(np.uint64)
assert samples.ndim == times.ndim == 1
return samples, times
@@ -707,7 +725,9 @@ def _load_spike_waveforms(self): # pragma: no cover
'on the fly from the raw data as needed.'
)
return
logger.debug('Loading spikes subset waveforms to avoid fetching waveforms from raw data.')
logger.debug(
'Loading spikes subset waveforms to avoid fetching waveforms from raw data.'
)
try:
return Bunch(
waveforms=self._read_array(path, mmap_mode='r'),
@@ -752,7 +772,9 @@ def _load_templates(self):
# That means templates.npy is considered as a dense array.
# Proper fix would be to save templates.npy as a true sparse array, with proper
# template_ind.npy (without an s).
path = self._find_path('template_ind.npy', 'templates.waveformsChannels*.npy')
path = self._find_path(
'template_ind.npy', 'templates.waveformsChannels*.npy'
)
cols = self._read_array(path)
if cols.ndim != 2: # pragma: no cover
cols = np.atleast_2d(cols).T
@@ -815,7 +837,9 @@ def _load_features(self):
return

try:
cols = self._read_array(self._find_path('pc_feature_ind.npy'), mmap_mode='r')
cols = self._read_array(
self._find_path('pc_feature_ind.npy'), mmap_mode='r'
)
logger.debug('Features are sparse.')
if cols.ndim == 1: # pragma: no cover
# Deal with npcs = 1.
@@ -838,7 +862,9 @@ def _load_template_features(self):
# Sparse structure: regular array with row and col indices.
try:
logger.debug('Loading template features.')
data = self._read_array(self._find_path('template_features.npy'), mmap_mode='r')
data = self._read_array(
self._find_path('template_features.npy'), mmap_mode='r'
)
assert data.dtype in (np.float32, np.float64)
assert data.ndim == 2
n_spikes, n_channels_loc = data.shape
@@ -877,7 +903,9 @@ def _find_best_channels(self, template, amplitude_threshold=None):
max_amp = amplitude[best_channel]
# Find the channels whose amplitude is at least X% of the peak.
amplitude_threshold = (
amplitude_threshold if amplitude_threshold is not None else self.amplitude_threshold
amplitude_threshold
if amplitude_threshold is not None
else self.amplitude_threshold
)
peak_channels = np.nonzero(amplitude >= amplitude_threshold * max_amp)[0]
# Find N closest channels.
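A quick worked illustration of the thresholding just above, with made-up amplitudes: a channel is kept when its amplitude reaches at least amplitude_threshold times the peak amplitude.

    import numpy as np

    amplitude = np.array([0.10, 0.80, 1.00, 0.25])
    amplitude_threshold = 0.2   # keep channels above 20% of the peak
    peak_channels = np.nonzero(amplitude >= amplitude_threshold * amplitude.max())[0]
    print(peak_channels)        # [1 2 3]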
@@ -919,7 +947,9 @@ def _get_template_dense(
if not self.sparse_templates:
return
template_w = self.sparse_templates.data[template_id, ...]
template = self._unwhiten(template_w).astype(np.float32) if unwhiten else template_w
template = (
self._unwhiten(template_w).astype(np.float32) if unwhiten else template_w
)
assert template.ndim == 2
channel_ids_, amplitude, best_channel = self._find_best_channels(
template, amplitude_threshold=amplitude_threshold
@@ -956,7 +986,11 @@ def _get_template_sparse(self, template_id, unwhiten=True):
channel_ids = channel_ids.astype(np.uint32)

# Unwhiten.
template = self._unwhiten(template_w, channel_ids=channel_ids) if unwhiten else template_w
template = (
self._unwhiten(template_w, channel_ids=channel_ids)
if unwhiten
else template_w
)
template = template.astype(np.float32)
assert template.ndim == 2
assert template.shape[1] == len(channel_ids)
@@ -976,23 +1010,29 @@ def _get_template_sparse(self, template_id, unwhiten=True):

def get_merge_map(self):
""" "Gets the maps of merges and splits between spikes.clusters and spikes.templates"""
inverse_mapping_dict = {key: [] for key in range(np.max(self.spike_clusters) + 1)}
inverse_mapping_dict = {
key: [] for key in range(np.max(self.spike_clusters) + 1)
}
for temp in np.unique(self.spike_templates):
idx = np.where(self.spike_templates == temp)[0]
new_idx = self.spike_clusters[idx]
mapping = np.unique(new_idx)
for n in mapping:
inverse_mapping_dict[n].append(temp)

nan_idx = np.array([idx for idx, val in inverse_mapping_dict.items() if len(val) == 0])
nan_idx = np.array(
[idx for idx, val in inverse_mapping_dict.items() if len(val) == 0]
)

return inverse_mapping_dict, nan_idx

# --------------------------------------------------------------------------
# Data access methods
# --------------------------------------------------------------------------

def get_template(self, template_id, channel_ids=None, amplitude_threshold=None, unwhiten=True):
def get_template(
self, template_id, channel_ids=None, amplitude_threshold=None, unwhiten=True
):
"""Get data about a template."""
if self.sparse_templates and self.sparse_templates.cols is not None:
return self._get_template_sparse(template_id, unwhiten=unwhiten)
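A standalone sketch of what the reformatted get_merge_map above computes, with made-up spike_templates and spike_clusters arrays (hypothetical values, purely illustrative):

    import numpy as np

    spike_templates = np.array([0, 0, 1, 1, 2])  # template id per spike (original sorting)
    spike_clusters = np.array([3, 3, 3, 3, 4])   # cluster id per spike (after curation)

    inverse_mapping = {key: [] for key in range(np.max(spike_clusters) + 1)}
    for temp in np.unique(spike_templates):
        # clusters that received spikes from this template
        for n in np.unique(spike_clusters[spike_templates == temp]):
            inverse_mapping[n].append(int(temp))

    # cluster 3 stems from templates 0 and 1 (a merge), cluster 4 from template 2;
    # ids 0-2 receive nothing, so they would end up in the nan_idx array.
    print(inverse_mapping)                                   # {0: [], 1: [], 2: [], 3: [0, 1], 4: [2]}
    print([k for k, v in inverse_mapping.items() if not v])  # [0, 1, 2]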
@@ -1112,7 +1152,9 @@ def get_template_features(self, spike_ids):
cols = tf.cols[self.spike_templates[spike_ids]]
else:
cols = np.tile(np.arange(n_templates_loc), (len(spike_ids), 1))
template_features = from_sparse(template_features, cols, np.arange(self.n_templates))
template_features = from_sparse(
template_features, cols, np.arange(self.n_templates)
)

assert template_features.shape[0] == ns
return template_features
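For the dense branch above, np.tile simply fabricates the full set of column indices for every spike so that the same from_sparse path can be reused; a tiny illustration:

    import numpy as np

    n_templates_loc, n_spikes = 3, 4
    cols = np.tile(np.arange(n_templates_loc), (n_spikes, 1))
    print(cols.shape)  # (4, 3)
    print(cols[0])     # [0 1 2] -- every spike gets the same, complete column set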
@@ -1125,14 +1167,21 @@ def get_depths(self):
c = 0
spikes_depths = np.zeros_like(self.spike_times) * np.nan
nspi = spikes_depths.shape[0]
if self.sparse_features is None or self.sparse_features.data.shape[0] != self.n_spikes:
if (
self.sparse_features is None
or self.sparse_features.data.shape[0] != self.n_spikes
):
return None
while True:
ispi = np.arange(c, min(c + nbatch, nspi))
# take only first component
features = self.sparse_features.data[ispi, :, 0]
features = np.maximum(features, 0) ** 2 # takes only positive values into account
ichannels = self.sparse_features.cols[self.spike_templates[ispi]].astype(np.uint32)
features = (
np.maximum(features, 0) ** 2
) # takes only positive values into account
ichannels = self.sparse_features.cols[self.spike_templates[ispi]].astype(
np.uint32
)
# features = np.square(self.sparse_features.data[ispi, :, 0])
# ichannels = self.sparse_features.cols[self.spike_templates[ispi]].astype(np.int64)
ypos = self.channel_positions[ichannels, 1]
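The rest of the loop sits outside this hunk; presumably it reduces ypos and the squared positive features to one depth per spike as a feature-weighted average (an assumption about the intent, not shown here). A standalone sketch of that centre-of-mass idea with made-up numbers:

    import numpy as np

    # one spike, three feature channels
    features = np.array([4.0, 1.0, 0.0])     # squared positive PC-0 amplitudes (weights)
    ypos = np.array([100.0, 140.0, 180.0])   # y-position of each channel on the probe

    depth = np.sum(ypos * features) / np.sum(features)
    print(depth)  # 108.0 -> pulled towards the channel with the strongest feature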
@@ -1178,7 +1227,9 @@ def get_amplitudes_true(self, sample2unit=1.0, use='templates'):
templates_wfs[n, :, :] = np.matmul(sparse.data[n, :, :], self.wmi)

# The amplitude on each channel is the positive peak minus the negative
templates_ch_amps = np.max(templates_wfs, axis=1) - np.min(templates_wfs, axis=1)
templates_ch_amps = np.max(templates_wfs, axis=1) - np.min(
templates_wfs, axis=1
)

# The template arbitrary unit amplitude is the amplitude of its largest channel
# (but see below for true tempAmps)
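A quick worked example of the peak-to-peak amplitude computed just above (a made-up waveform with one template, four samples and two channels):

    import numpy as np

    templates_wfs = np.array([[[0.0, 0.1],
                               [1.0, -0.2],
                               [-0.5, 0.3],
                               [0.0, 0.0]]])  # (n_templates, n_samples, n_channels)

    # positive peak minus negative peak, per channel
    templates_ch_amps = np.max(templates_wfs, axis=1) - np.min(templates_wfs, axis=1)
    print(templates_ch_amps)               # [[1.5 0.5]]
    # the template's arbitrary-unit amplitude is its largest channel
    print(templates_ch_amps.max(axis=1))   # [1.5]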
@@ -1187,10 +1238,13 @@ def get_amplitudes_true(self, sample2unit=1.0, use='templates'):

with np.errstate(divide='ignore', invalid='ignore'):
# take the average spike amplitude per template
templates_amps_v = np.bincount(spikes, weights=spike_amps) / np.bincount(spikes)
templates_amps_v = np.bincount(spikes, weights=spike_amps) / np.bincount(
spikes
)
# scale back the template according to the spikes units
templates_physical_unit = (
templates_wfs * (templates_amps_v / templates_amps_au)[:, np.newaxis, np.newaxis]
templates_wfs
* (templates_amps_v / templates_amps_au)[:, np.newaxis, np.newaxis]
)

return (
@@ -1277,7 +1331,8 @@ def get_cluster_mean_waveforms(self, cluster_id, unwhiten=True):
channel_ids = template.channel_ids
# Get all templates from which this cluster stems.
templates = [
self.get_template(template_id, unwhiten=unwhiten) for template_id in template_ids
self.get_template(template_id, unwhiten=unwhiten)
for template_id in template_ids
]
# Construct the waveforms array.
ns = self.n_samples_waveforms
@@ -1318,7 +1373,9 @@ def _channels(self, sparse):
tmp = sparse.data
n_templates, n_samples, n_channels = tmp.shape
if sparse.cols is None:
template_peak_channels = np.argmax(tmp.max(axis=1) - tmp.min(axis=1), axis=1)
template_peak_channels = np.argmax(
tmp.max(axis=1) - tmp.min(axis=1), axis=1
)
else:
# when the templates are sparse, the first channel is the highest amplitude channel
template_peak_channels = sparse.cols[:, 0]
@@ -1405,7 +1462,9 @@ def save_metadata(self, name, values):

def save_spike_clusters(self, spike_clusters):
"""Save the spike clusters."""
path = self._find_path('spike_clusters.npy', 'spikes.clusters.npy', multiple_ok=False)
path = self._find_path(
'spike_clusters.npy', 'spikes.clusters.npy', multiple_ok=False
)
logger.debug('Save spike clusters to `%s`.', path)
np.save(path, spike_clusters)

@@ -1500,7 +1559,9 @@ def get_template_params(params_path):

if isinstance(params['dat_path'], str):
params['dat_path'] = [params['dat_path']]
params['dat_path'] = [_make_abs_path(_, params['dir_path']) for _ in params['dat_path']]
params['dat_path'] = [
_make_abs_path(_, params['dir_path']) for _ in params['dat_path']
]
return params

