Skip to content

Commit

Permalink
SPADE: New way to count patterns for multiple testing (#347)
Browse files Browse the repository at this point in the history
Co-authored-by: stellalessandra <a.stella@fz-juelich.de>
Co-authored-by: p-bouss <peter.bouss@googlemail.com>
  • Loading branch information
3 people authored Sep 7, 2020
1 parent 9d869b2 commit 6f342aa
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 22 deletions.
43 changes: 30 additions & 13 deletions elephant/spade.py
Original file line number Diff line number Diff line change
Expand Up @@ -1515,9 +1515,24 @@ def _mask_pvalue_spectrum(pv_spec, concepts, spectrum, winlen):
signatures = {(len(concept[0]), len(concept[1]),
max(np.array(concept[0]) % winlen))
for concept in concepts}
mask = np.array([tuple(pvs[:-1]) in signatures
and not np.isclose(pvs[-1], [1])
for pvs in pv_spec])
mask = np.zeros(len(pv_spec), dtype=bool)
for index, pv_entry in enumerate(pv_spec):
if tuple(pv_entry[:-1]) in signatures \
and not np.isclose(pv_entry[-1], [1]):
# select the highest number of occurrences for size and duration
mask[index] = True
if mask[index-1]:
if spectrum == '#':
size = pv_spec[index][0]
prev_size = pv_spec[index-1][0]
if prev_size == size:
mask[index-1] = False
else:
size, duration = pv_spec[index][[0, 2]]
prev_size, prev_duration = pv_spec[index-1][[0, 2]]
if prev_size == size and duration == prev_duration:
mask[index-1] = False

return mask


Expand Down Expand Up @@ -1624,6 +1639,7 @@ def test_signature_significance(pv_spec, concepts, alpha, winlen,
pv_spec = np.array(pv_spec)
mask = _mask_pvalue_spectrum(pv_spec, concepts, spectrum, winlen)
pvalues = pv_spec[:, -1]

pvalues_totest = pvalues[mask]

# Initialize test array to False
Expand All @@ -1646,34 +1662,35 @@ def test_signature_significance(pv_spec, concepts, alpha, winlen,
method=corr)[0]

# assign each corrected pvalue to its corresponding entry
# this breaks
for index, value in zip(mask.nonzero()[0], tests_selected):
tests[index] = value

# Return the specified results:
if spectrum == '#':
if report == 'spectrum':
sig_spectrum = [(size, supp, test)
for (size, supp, pv), test in zip(pv_spec, tests)]
sig_spectrum = [(size, occ, test)
for (size, occ, pv), test in zip(pv_spec, tests)]
elif report == 'significant':
sig_spectrum = [(size, supp) for ((size, supp, pv), test)
sig_spectrum = [(size, occ) for ((size, occ, pv), test)
in zip(pv_spec, tests) if test]
else: # report == 'non_significant'
sig_spectrum = [(size, supp)
for ((size, supp, pv), test) in zip(pv_spec, tests)
sig_spectrum = [(size, occ)
for ((size, occ, pv), test) in zip(pv_spec, tests)
if not test]

else: # spectrum == '3d#'
if report == 'spectrum':
sig_spectrum =\
[(size, supp, l, test)
for (size, supp, l, pv), test in zip(pv_spec, tests)]
[(size, occ, l, test)
for (size, occ, l, pv), test in zip(pv_spec, tests)]
elif report == 'significant':
sig_spectrum = [(size, supp, l) for ((size, supp, l, pv), test)
sig_spectrum = [(size, occ, l) for ((size, occ, l, pv), test)
in zip(pv_spec, tests) if test]
else: # report == 'non_significant'
sig_spectrum =\
[(size, supp, l)
for ((size, supp, l, pv), test) in zip(pv_spec, tests)
[(size, occ, l)
for ((size, occ, l, pv), test) in zip(pv_spec, tests)
if not test]
return sig_spectrum

Expand Down
24 changes: 15 additions & 9 deletions elephant/spike_train_surrogates.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def dither_spikes(spiketrain, dither, n_surrogates=1, decimals=None,
# Return the surrogates as list of neo.SpikeTrain
return [neo.SpikeTrain(
train, t_start=t_start, t_stop=t_stop)
for train in dithered_spiketrains]
for train in dithered_spiketrains]


@deprecated_alias(n='n_surrogates')
Expand Down Expand Up @@ -638,7 +638,7 @@ def bin_shuffling(


class JointISI(object):
"""
r"""
The class :class:`JointISI` is implemented for Joint-ISI dithering
as a continuation of the ideas of [1]_ and [2]_.
Expand Down Expand Up @@ -728,8 +728,10 @@ def __init__(self,
cutoff=True,
refractory_period=4. * pq.ms,
isi_dithering=False):

if not isinstance(spiketrain, neo.SpikeTrain):
raise TypeError('spiketrain must be of type neo.SpikeTrain')

self.spiketrain = spiketrain
self.truncation_limit = self._get_magnitude(truncation_limit)
self.n_bins = n_bins
Expand Down Expand Up @@ -1096,6 +1098,7 @@ def trial_shifting(spiketrains, dither, n_surrogates=1):
It shifts by a random uniform amount independently different trials,
which are the elements of a list of spiketrains.
The shifting is done independently for each surrogate.
Parameters
Expand Down Expand Up @@ -1169,9 +1172,11 @@ def _trial_shifting_of_concatenated_spiketrain(
spiketrain, dither, trial_length, trial_separation, n_surrogates=1):
"""
Generates surrogates of a spike train by trial shifting.
It shifts by a random uniform amount independently different trials,
individuated by the `trial_length` and the possible buffering period
`trial_separation` present in between trials.
The shifting is done independently for each surrogate.
Parameters
Expand All @@ -1180,7 +1185,7 @@ def _trial_shifting_of_concatenated_spiketrain(
A single spike train, where the trials are concatenated.
dither : pq.Quantity
Amount of dithering.
trial_length : pq.Quantity
trial_length : pq.Quantity
The length of the single-trial spiketrain.
trial_separation : pq.Quantity
Buffering in between trials in the concatenation of the spiketrain.
Expand Down Expand Up @@ -1213,10 +1218,10 @@ def _trial_shifting_of_concatenated_spiketrain(
spiketrains, dither, t_starts, t_stops, n_surrogates)

surrogate_spiketrains = [neo.SpikeTrain(
np.hstack(surrogate_spiketrain) * pq.s,
t_start=t_start * pq.s,
t_stop=t_stop * pq.s,
units=units)
np.hstack(surrogate_spiketrain) * pq.s,
t_start=t_start * pq.s,
t_stop=t_stop * pq.s,
units=units)
for surrogate_spiketrain in surrogate_spiketrains]
return surrogate_spiketrains

Expand Down Expand Up @@ -1291,7 +1296,8 @@ def surrogates(
ValueError
If `method` is not one of the surrogate methods defined in this module.
If `dt` is None and `method` is not 'randomise_spikes' nor 'shuffle_isis'.
If `dt` is None and `method` is not 'randomise_spikes' nor
'shuffle_isis'.
"""

if isinstance(spiketrain, list):
Expand Down Expand Up @@ -1356,7 +1362,7 @@ def surrogates(
t_start=spiketrain.t_start,
t_stop=spiketrain.t_stop,
units=spiketrain.units)
for binned_surrogate in binned_surrogates]
for binned_surrogate in binned_surrogates]
return surrogate_spiketrains
# surr_method is 'joint_isi_dithering' or isi_dithering:
return method(n_surrogates)

0 comments on commit 6f342aa

Please sign in to comment.