Skip to content

Commit

Permalink
Add function for detecting spindles in one call
Browse files Browse the repository at this point in the history
  • Loading branch information
edeno committed Jun 5, 2019
1 parent bb48a1d commit e6480a7
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 0 deletions.
2 changes: 2 additions & 0 deletions spindle_detector/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# flake8: noqa
from .core import detect_spindle
52 changes: 52 additions & 0 deletions spindle_detector/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import numpy as np

from hmmlearn import hmm
from spindle_detector.lfp_likelihood import (_DEFAULT_MULTITAPER_PARAMS,
estimate_spindle_band_power)


def atleast_2d(x):
return np.atleast_2d(x).T if x.ndim < 2 else x


def detect_spindle(lfps, sampling_frequency, start_time=0.0,
multitaper_params=_DEFAULT_MULTITAPER_PARAMS):
'''Finds spindle times using spectral power between 10-16 Hz and an HMM.
Parameters
----------
lfps : ndarray, shape (n_time, n_signals)
sampling_frequency : float
start_time : float
Returns
-------
time : ndarray, shape (n_time_windows,)
spindle_probability : ndarray, shape (n_time_windows, 2)
is_spindle : ndarray, shape (n_time_windows,)
model : hmmlearn.GaussianHMM instance
'''
time, spindle_band_power = estimate_spindle_band_power(
atleast_2d(lfps), sampling_frequency, start_time=start_time,
multitaper_params=multitaper_params)
startprob_prior = np.log(np.array([np.spacing(1), 1.0 - np.spacing(1)]))
model = hmm.GaussianHMM(n_components=2, covariance_type='full',
startprob_prior=startprob_prior, n_iter=100)
model = model.fit(np.log(spindle_band_power))

state_index = model.predict(np.log(spindle_band_power))

if (spindle_band_power[state_index == 0].mean() >
spindle_band_power[state_index == 1].mean()):
spindle_ind = 0
else:
spindle_ind = 1

is_spindle = np.zeros_like(state_index, dtype=np.bool)
is_spindle[state_index == spindle_ind] = True

spindle_probability = model.predict_proba(np.log(spindle_band_power))
spindle_probability = spindle_probability[:, spindle_ind]

return time, spindle_probability, is_spindle, model

0 comments on commit e6480a7

Please sign in to comment.