Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions hypyp/analyses.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import statsmodels.stats.multitest
import copy
from collections import namedtuple
from typing import Union
from typing import Union, List, Tuple
from astropy.stats import circmean
import matplotlib.pyplot as plt

Expand Down Expand Up @@ -218,7 +218,7 @@ def behav_corr(data: np.ndarray, behav: np.ndarray, data_name: str, behav_name:
return corr_tuple(r=r, pvalue=pvalue, strat=strat)


def indices_connectivity_intrabrain(epochs: mne.Epochs) -> list:
def indices_connectivity_intrabrain(epochs: mne.Epochs) -> List[Tuple[int, int]]:
"""
Computes indices for connectivity analysis between all EEG
channels for one participant. Can be used instead of
Expand Down Expand Up @@ -252,7 +252,7 @@ def indices_connectivity_intrabrain(epochs: mne.Epochs) -> list:
return channels


def indices_connectivity_interbrain(epoch_hyper: mne.Epochs) -> list:
def indices_connectivity_interbrain(epoch_hyper: mne.Epochs) -> List[Tuple[int, int]]:
"""
Computes indices for interbrain connectivity analyses between all EEG
sensors for 2 participants (merge data).
Expand Down Expand Up @@ -642,7 +642,7 @@ def compute_conn_mvar(complex_signal: np.ndarray, mvar_params: dict, ica_params:
return np.asarray(aux_3, dtype=d_type)


def compute_single_freq(data: np.ndarray, sampling_rate: int, freq_range: list) -> np.ndarray:
def compute_single_freq(data: np.ndarray, sampling_rate: int, freq_range: List[float]) -> np.ndarray:
"""
Computes analytic signal per frequency bin using the multitaper method.

Expand Down Expand Up @@ -718,7 +718,7 @@ def compute_freq_bands(data: np.ndarray, sampling_rate: int, freq_bands: dict, f
return complex_signal


def compute_nmPLV(data: np.ndarray, sampling_rate: int, freq_range1: list, freq_range2: list, **filter_options) -> np.ndarray:
def compute_nmPLV(data: np.ndarray, sampling_rate: int, freq_range1: List[float], freq_range2: List[float], **filter_options) -> np.ndarray:
"""
Computes the n:m PLV for a dyad with two different frequency ranges.

Expand Down
42 changes: 27 additions & 15 deletions hypyp/prep.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,22 @@
import copy
import matplotlib.pyplot as plt
import mne
from autoreject import get_rejection_threshold, AutoReject
from autoreject import get_rejection_threshold, AutoReject, RejectLog
from mne.preprocessing import ICA, corrmap
from typing import List, Tuple, TypedDict

class DicAR(TypedDict):
"""
Epoch rejection info
"""
strategy: str
threshold: float
S1: float
S2: float
dyad: float

def filt(raw_S: list) -> list:

def filt(raw_S: List[mne.io.Raw]) -> List[mne.io.Raw]:
"""
Filters list of raw data to remove slow drifts.

Expand All @@ -32,14 +43,14 @@ def filt(raw_S: list) -> list:
raws: list of high-pass filtered raws.
"""
# TODO: l_freq and h_freq as param
raws = []
raws: List[mne.io.Raw] = []
for raw in raw_S:
raws.append(mne.io.Raw.filter(raw, l_freq=2., h_freq=None))

return raws


def ICA_choice_comp(icas: list, epochs: list) -> list:
def ICA_choice_comp(icas: List[ICA], epochs: List[mne.Epochs]) -> List[mne.Epochs]:
"""
Plots Independent Components for each participant (calculated from Epochs),
let the user choose the relevant components for artifact rejection
Expand Down Expand Up @@ -88,13 +99,13 @@ def ICA_choice_comp(icas: list, epochs: list) -> list:
return cleaned_epochs_ICA


def ICA_apply(icas: int, subj_number: int, comp_number: int, epochs: list) -> list:
def ICA_apply(icas: List[ICA], subj_number: int, comp_number: int, epochs: List[mne.Epochs]) -> List[mne.Epochs]:
"""
Applies ICA with template model from 1 participant in the dyad.
See ICA_choice_comp for a detailed description of the parameters and output.
"""

cleaned_epochs_ICA = []
cleaned_epochs_ICA: List[ICA] = []
# selecting which ICs corresponding to the template
template_eog_component = icas[subj_number].get_components()[:, comp_number]

Expand Down Expand Up @@ -127,7 +138,7 @@ def ICA_apply(icas: int, subj_number: int, comp_number: int, epochs: list) -> li
return cleaned_epochs_ICA


def ICA_fit(epochs: list, n_components: int, method: str, fit_params: dict, random_state: int) -> list:
def ICA_fit(epochs: List[mne.Epochs], n_components: int, method: str, fit_params: dict, random_state: int) -> List[ICA]:
"""
Computes global Autorejection to fit Independent Components Analysis
on Epochs, for each participant.
Expand Down Expand Up @@ -166,7 +177,7 @@ def ICA_fit(epochs: list, n_components: int, method: str, fit_params: dict, rand
icas: list of Independant Components for each participant (IC are MNE
objects, see MNE documentation for more details).
"""
icas = []
icas: List[ICA] = []
for epoch in epochs:
# per subj
# applying AR to find global rejection threshold
Expand All @@ -188,7 +199,7 @@ def ICA_fit(epochs: list, n_components: int, method: str, fit_params: dict, rand
return icas


def AR_local(cleaned_epochs_ICA: list, strategy:str = 'union', threshold:float = 50.0, verbose: bool = False) -> list:
def AR_local(cleaned_epochs_ICA: List[mne.Epochs], strategy: str = 'union', threshold: float = 50.0, verbose: bool = False) -> Tuple[mne.Epochs, DicAR]:
"""
Applies local Autoreject to repair or reject bad epochs.

Expand All @@ -215,9 +226,10 @@ def AR_local(cleaned_epochs_ICA: list, strategy:str = 'union', threshold:float =
dic_AR: dictionnary with the percentage of epochs rejection
for each subject and for the intersection of the them.
"""
bad_epochs_AR = []
AR = []
dic_AR = {}

bad_epochs_AR: List[RejectLog] = []
AR: List[AutoReject] = []
dic_AR: DicAR = {}
dic_AR['strategy'] = strategy
dic_AR['threshold'] = threshold

Expand Down Expand Up @@ -268,7 +280,7 @@ def AR_local(cleaned_epochs_ICA: list, strategy:str = 'union', threshold:float =
dic_AR['S2'] = float((len(bad2[0].tolist())/len(cleaned_epochs_ICA[1]))*100)

# picking good epochs for the two subj
cleaned_epochs_AR = []
cleaned_epochs_AR: List[mne.Epochs] = []
for clean_epochs in cleaned_epochs_ICA: # per subj
# keep a copy of the original data
clean_epochs_ep = copy.deepcopy(clean_epochs)
Expand All @@ -289,11 +301,11 @@ def AR_local(cleaned_epochs_ICA: list, strategy:str = 'union', threshold:float =
print('%s percent of bad epochs' % dic_AR['dyad'])

# Vizualisation before after AR
evoked_before = []
evoked_before: List[mne.Evoked] = []
for clean_epochs in cleaned_epochs_ICA: # per subj
evoked_before.append(clean_epochs.average())

evoked_after_AR = []
evoked_after_AR: List[mne.Evoked] = []
for clean in cleaned_epochs_AR:
evoked_after_AR.append(clean.average())

Expand Down
9 changes: 5 additions & 4 deletions hypyp/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@


from collections import namedtuple
from typing import List, Tuple
import numpy as np
import scipy
import matplotlib.pylab as plt
Expand Down Expand Up @@ -99,7 +100,7 @@ def statsCond(data: np.ndarray, epochs: mne.Epochs, n_permutations: int, alpha:
T_obs_plot=T_obs_plot)


def con_matrix(epochs: mne.Epochs, freqs_mean: list, draw: bool = False) -> tuple:
def con_matrix(epochs: mne.Epochs, freqs_mean: List[float], draw: bool = False) -> tuple:
"""
Computes a priori channel connectivity across space and frequencies.

Expand Down Expand Up @@ -158,7 +159,7 @@ def con_matrix(epochs: mne.Epochs, freqs_mean: list, draw: bool = False) -> tupl
ch_con_freq=ch_con_freq)


def metaconn_matrix_2brains(electrodes: list, ch_con: scipy.sparse.csr_matrix, freqs_mean: list, plot: bool = False) -> tuple:
def metaconn_matrix_2brains(electrodes: List[Tuple[int, int]], ch_con: scipy.sparse.csr_matrix, freqs_mean: List[float], plot: bool = False) -> tuple:
"""
Computes a priori connectivity across space and frequencies
between pairs of channels for which connectivity indices have
Expand Down Expand Up @@ -231,7 +232,7 @@ def metaconn_matrix_2brains(electrodes: list, ch_con: scipy.sparse.csr_matrix, f
metaconn_freq=metaconn_freq)


def metaconn_matrix(electrodes: list, ch_con: scipy.sparse.csr_matrix, freqs_mean: list) -> tuple:
def metaconn_matrix(electrodes: List[Tuple[int, int]], ch_con: scipy.sparse.csr_matrix, freqs_mean: List[float]) -> tuple:
"""
Computes a priori connectivity between pairs of sensors for which
connectivity indices have been calculated, across space and frequencies
Expand Down Expand Up @@ -363,7 +364,7 @@ def statscondCluster(data: list, freqs_mean: list, ch_con_freq: scipy.sparse.csr
F_obs_plot=F_obs_plot)


def statscluster(data: list, test: str, factor_levels: list, ch_con_freq: scipy.sparse.csr_matrix, tail: int, n_permutations: int, alpha: float = 0.05) -> tuple:
def statscluster(data: list, test: str, factor_levels: List[int], ch_con_freq: scipy.sparse.csr_matrix, tail: int, n_permutations: int, alpha: float = 0.05) -> tuple:
"""
Computes cluster-level statistical permutation test, corrected with
channel connectivity across space and frequencies to compare groups
Expand Down
9 changes: 5 additions & 4 deletions hypyp/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"""


from typing import Tuple, List
import numpy as np
import pandas as pd
from scipy.integrate import solve_ivp
Expand All @@ -20,7 +21,7 @@
from mne import create_info, EpochsArray


def create_epochs(raw_S1: mne.io.Raw, raw_S2: mne.io.Raw, duration: float) -> list:
def create_epochs(raw_S1: mne.io.Raw, raw_S2: mne.io.Raw, duration: float) -> Tuple[mne.Epochs, mne.Epochs]:
"""
Creates Epochs from Raws and vizualize Power Spectral Density (PSD)
on average Epochs (option).
Expand All @@ -43,8 +44,8 @@ def create_epochs(raw_S1: mne.io.Raw, raw_S2: mne.io.Raw, duration: float) -> li
Returns:
epoch_S1, epoch_S2: list of Epochs for each participant.
"""
epoch_S1 = []
epoch_S2 = []
epoch_S1: List[mne.Epochs] = []
epoch_S2: List[mne.Epochs] = []

for raw1, raw2 in zip(raw_S1, raw_S2):
# creating fixed events
Expand Down Expand Up @@ -259,7 +260,7 @@ def split(raw_merge: mne.io.Raw) -> mne.io.Raw:
return raw_1020_S1, raw_1020_S2


def concatenate_epochs(epoch_S1: mne.Epochs, epoch_S2: mne.Epochs) -> mne.Epochs:
def concatenate_epochs(epoch_S1: mne.Epochs, epoch_S2: mne.Epochs) -> Tuple[mne.Epochs, mne.Epochs]:
"""
Concatenates a list of Epochs in one Epochs object.

Expand Down