Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 38 additions & 40 deletions hypyp/analyses.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,10 +760,10 @@ def compute_nmPLV(data: np.ndarray, sampling_rate: int, freq_range1: list, freq_
return con


def xwt(sig1: mne.Epochs, sig2: mne.Epochs, sfreq: Union[int, float],
freqs: Union[int, np.ndarray], analysis: str) -> np.ndarray:
def xwt(sig1: mne.Epochs, sig2: mne.Epochs,
freqs: Union[int, np.ndarray], n_cycles=5.0, mode: str) -> np.ndarray:
"""
Perfroms a cross wavelet transform on two signals.
Performs a cross wavelet transform on two signals.

Arguments:

Expand All @@ -773,43 +773,42 @@ def xwt(sig1: mne.Epochs, sig2: mne.Epochs, sfreq: Union[int, float],
sig2 : mne.Epochs
Signal (eg. EEG data) of second participant.

sfreq: int | float
Sampling frequency of the data in Hz.

freqs: int | float
Range of frequencies of interest in Hz.

analysis: str
Sets the type of analyses
mode: str
Sets the type of analyses.

Note:
This function relies on MNE's mne.time_frequency.morlet
and mne.time_frequency.tfr.cwt functions.

Returns:
data:
Wavelet results. The shape is (n_chans1, n_chans2, n_epochs, n_freqs, n_samples)
Wavelet results. The shape is (n_chans1, n_chans2, n_epochs, n_freqs, n_samples).
Wavelet transform coherence calculated according to Maraun & Kurths (2004)
"""

# Set parameters for the output
n_freqs = len(freqs)
sfreq = sig1.info['sfreq']
assert sig1.info['sfreq'] == sig2.info['sfreq'], "Sig1 et sig2 should have the same sfreq value."

n_epochs1, n_chans1, n_samples1 = sig1.get_data().shape
n_epochs2, n_chans2, n_samples2 = sig2.get_data().shape

# Set the mother wavelet
Ws = mne.time_frequency.tfr.morlet(sfreq, freqs, n_cycles=5.0, sigma=None,
zero_mean=True)
assert n_epochs1 == n_epochs2, "n_epochs1 and n_epochs2 should have the same number of epochs."
assert n_chans1 == n_chans2, "n_chans1 and n_chans2 should have the same number of channels."
assert n_samples1 == n_samples2, "n_samples1 and n_samples2 should have the same number of samples."

# Set parameters for the output
n_freqs = len(freqs)
n_epochs, n_chans1, n_samples = sig1.get_data().shape
n_epochs, n_chans2, n_samples = sig2.get_data().shape
cross_sigs = np.zeros((n_chans1, n_chans2, n_epochs1, n_freqs, n_samples1), dtype=complex) * np.nan
wcts = np.zeros((n_chans1, n_chans2, n_epochs1, n_freqs, n_samples1), dtype=complex) * np.nan

cross_sigs = np.zeros(
(n_chans1, n_chans2, n_freqs, n_samples),
dtype=complex) * np.nan
wcts = np.zeros(
(n_chans1, n_chans2, n_freqs, n_samples),
dtype=complex) * np.nan
# Set the mother wavelet
Ws = mne.time_frequency.tfr.morlet(sfreq, freqs,
n_cycles=n_cycles, sigma=None, zero_mean=True)

# perform a continuous wavelet transform on all epochs of each signal
# Perform a continuous wavelet transform on all epochs of each signal
for ind1, ch_label1 in enumerate(sig1.ch_names):
for ind2, ch_label2 in enumerate(sig2.ch_names):
# Extract the channel's data for both participants and apply cwt
Expand All @@ -819,27 +818,26 @@ def xwt(sig1: mne.Epochs, sig2: mne.Epochs, sfreq: Union[int, float],
cur_sig2 = np.squeeze(sig2.get_data(mne.pick_channels(sig2.ch_names, [ch_label2])))
out2 = mne.time_frequency.tfr.cwt(cur_sig2, Ws, use_fft=True,
mode='same', decim=1)
# Average across epochs
tfr_cwt1 = out1.mean(0)
tfr_cwt2 = out2.mean(0)

# Compute cross-spectrum
wps1 = tfr_cwt1 * tfr_cwt1.conj()
wps2 = tfr_cwt2 * tfr_cwt2.conj()
cross_sig = (out1 * out2.conj()).mean(0)
cross_sigs[ind1, ind2, :, :] = cross_sig
wps1 = out1 * out1.conj()
wps2 = out2 * out2.conj()
cross_sig = out1 * out2.conj()
cross_sigs[ind1, ind2, :, :, :] = cross_sig
coh = (cross_sig) / (np.sqrt(wps1*wps2))
abs_coh = np.abs(coh)
wct = (abs_coh - np.min(abs_coh)) / (np.max(abs_coh) - np.min(abs_coh))
wcts[ind1, ind2, :, :] = wct
if analysis == 'power':
data = np.abs((cross_sigs))
elif analysis == 'phase':
wcts[ind1, ind2, :, :, :] = wct

if mode == 'power':
data = np.abs(cross_sigs)
elif mode == 'phase':
data = np.angle(cross_sigs)
elif analysis == 'xwt':
elif mode == 'xwt':
data = cross_sigs
elif analysis == 'wtc':
data = wcts
elif mode == 'wtc':
data = wcts
else:
data = 'Please specify a valid analysis: power, phase, xwt, or wtc.'
data = 'Please specify a valid mode: power, phase, xwt, or wtc.'
print(data)
return data
return data