Skip to content

Commit

Permalink
Implement SpikeTrainList class, and make Segment.spiketrains an insta…
Browse files Browse the repository at this point in the history
…nce of this class.
  • Loading branch information
apdavison committed Jun 11, 2021
1 parent 0766940 commit ffc0e16
Show file tree
Hide file tree
Showing 8 changed files with 646 additions and 46 deletions.
14 changes: 12 additions & 2 deletions neo/core/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

from copy import deepcopy
from neo.core.baseneo import BaseNeo, _reference_name, _container_name
from neo.core.spiketrain import SpikeTrain
from neo.core.spiketrainlist import SpikeTrainList


def unique_objs(objs):
Expand Down Expand Up @@ -83,7 +85,11 @@ def filterdata(data, targdict=None, objects=None, **kwargs):
results = [result for result in results if
result.__class__ in objects or
result.__class__.__name__ in objects]
return results

if results and all(isinstance(obj, SpikeTrain) for obj in results):
return SpikeTrainList(results)
else:
return results


class Container(BaseNeo):
Expand Down Expand Up @@ -411,7 +417,11 @@ def filter(self, targdict=None, data=True, container=False, recursive=True,
data = True
container = True

children = []
if objects == SpikeTrain:
children = SpikeTrainList()
else:
children = []

# get the objects we want
if data:
if recursive:
Expand Down
5 changes: 3 additions & 2 deletions neo/core/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from copy import deepcopy

from neo.core.container import Container
from neo.core.spiketrainlist import SpikeTrainList


class Segment(Container):
Expand Down Expand Up @@ -89,8 +90,8 @@ def __init__(self, name=None, description=None, file_origin=None,
Initialize a new :class:`Segment` instance.
'''
super().__init__(name=name, description=description,
file_origin=file_origin, **annotations)

file_origin=file_origin, **annotations)
self.spiketrains = SpikeTrainList(segment=self)
self.file_datetime = file_datetime
self.rec_datetime = rec_datetime
self.index = index
Expand Down
95 changes: 57 additions & 38 deletions neo/core/spiketrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,57 @@ def _new_spiketrain(cls, signal, t_stop, units=None, dtype=None, copy=True,
return obj


def normalise_times_array(times, units, dtype=None, copy=True):
"""
Return a quantity array with the correct units.
There are four scenarios:
A. times (NumPy array), units given as string or Quantities units
B. times (Quantity array), units=None
C. times (Quantity), units given as string or Quantities units
D. times (NumPy array), units=None
In scenarios A-C we return a tuple (times as a Quantity array, dimensionality)
In scenario C, we rescale the original array to match `units`
In scenario D, we raise a ValueError
"""
if dtype is None:
if not hasattr(times, 'dtype'):
dtype = np.float
if units is None:
# No keyword units, so get from `times`
try:
dim = times.units.dimensionality
except AttributeError:
raise ValueError('you must specify units')
else:
if hasattr(units, 'dimensionality'):
dim = units.dimensionality
else:
dim = pq.quantity.validate_dimensionality(units)

if hasattr(times, 'dimensionality'):
if times.dimensionality.items() == dim.items():
units = None # units will be taken from times, avoids copying
else:
if not copy:
raise ValueError("cannot rescale and return view")
else:
# this is needed because of a bug in python-quantities
# see issue # 65 in python-quantities github
# remove this if it is fixed
times = times.rescale(dim)


# check to make sure the units are time
# this approach is orders of magnitude faster than comparing the
# reference dimensionality
if (len(dim) != 1 or list(dim.values())[0] != 1 or not isinstance(list(dim.keys())[0],
pq.UnitTime)):
ValueError("Units have dimensions %s, not [time]" % dim.simplified)
return pq.Quantity(times, units=units, dtype=dtype, copy=copy), dim


class SpikeTrain(DataObject):
'''
:class:`SpikeTrain` is a :class:`Quantity` array of spike times.
Expand Down Expand Up @@ -220,37 +271,7 @@ def __new__(cls, times, t_stop, units=None, dtype=None, copy=True, sampling_rate
# len(times)!=0 has been used to workaround a bug occuring during neo import
raise ValueError("the number of waveforms should be equal to the number of spikes")

# Make sure units are consistent
# also get the dimensionality now since it is much faster to feed
# that to Quantity rather than a unit
if units is None:
# No keyword units, so get from `times`
try:
dim = times.units.dimensionality
except AttributeError:
raise ValueError('you must specify units')
else:
if hasattr(units, 'dimensionality'):
dim = units.dimensionality
else:
dim = pq.quantity.validate_dimensionality(units)

if hasattr(times, 'dimensionality'):
if times.dimensionality.items() == dim.items():
units = None # units will be taken from times, avoids copying
else:
if not copy:
raise ValueError("cannot rescale and return view")
else:
# this is needed because of a bug in python-quantities
# see issue # 65 in python-quantities github
# remove this if it is fixed
times = times.rescale(dim)

if dtype is None:
if not hasattr(times, 'dtype'):
dtype = np.float
elif hasattr(times, 'dtype') and times.dtype != dtype:
if dtype is not None and hasattr(times, 'dtype') and times.dtype != dtype:
if not copy:
raise ValueError("cannot change dtype and return view")

Expand All @@ -264,15 +285,13 @@ def __new__(cls, times, t_stop, units=None, dtype=None, copy=True, sampling_rate
if hasattr(t_stop, 'dtype') and t_stop.dtype != times.dtype:
t_stop = t_stop.astype(times.dtype)

# check to make sure the units are time
# this approach is orders of magnitude faster than comparing the
# reference dimensionality
if (len(dim) != 1 or list(dim.values())[0] != 1 or not isinstance(list(dim.keys())[0],
pq.UnitTime)):
ValueError("Unit has dimensions %s, not [time]" % dim.simplified)
# Make sure units are consistent
# also get the dimensionality now since it is much faster to feed
# that to Quantity rather than a unit
times, dim = normalise_times_array(times, units, dtype, copy)

# Construct Quantity from data
obj = pq.Quantity(times, units=units, dtype=dtype, copy=copy).view(cls)
obj = times.view(cls)

# spiketrain times always need to be 1-dimensional
if len(obj.shape) > 1:
Expand Down
Loading

0 comments on commit ffc0e16

Please sign in to comment.