Skip to content

Commit 2e33165

Browse files
committed
Added ASR and unit tests
1 parent 6e857f4 commit 2e33165

File tree

5 files changed

+874
-0
lines changed

5 files changed

+874
-0
lines changed

src/eegprep/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,4 @@
1919
from .clean_drifts import clean_drifts
2020
from .clean_channels_nolocs import clean_channels_nolocs
2121
from .clean_channels import clean_channels
22+
from .clean_asr import clean_asr

src/eegprep/clean_asr.py

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
import logging
2+
from typing import Dict, Any, Optional, Union, Tuple
3+
4+
import numpy as np
5+
6+
# Assuming these utilities exist and are correctly ported/placed
7+
from .utils.asr import asr_calibrate, asr_process
8+
try:
    # clean_windows is an optional dependency, used only for the automatic
    # selection of calibration data; degrade gracefully when it is absent.
    from .utils.clean_windows import clean_windows
except ImportError:
    _has_clean_windows = False

    def clean_windows(*args, **kwargs):
        """Stub installed when the optional clean_windows helper is unavailable."""
        raise ImportError("The 'clean_windows' function is required for automatic calibration data selection but was not found.")
else:
    _has_clean_windows = True
17+
18+
19+
logger = logging.getLogger(__name__)
20+
21+
22+
def _select_calibration_data(
    EEG: Dict[str, Any],
    data: np.ndarray,
    n_channels: int,
    ref_maxbadchannels: Union[float, str, np.ndarray],
    ref_tolerances: Union[Tuple[float, float], str],
    ref_wndlen: Union[float, str],
) -> np.ndarray:
    """Pick the calibration (reference) section of the data for ASR.

    Numeric selection parameters trigger an automatic clean-window search via
    clean_windows; any parameter set to 'off' disables the search and uses the
    whole recording; an ndarray is taken as user-supplied calibration data.

    Returns:
        np.ndarray: Channels x samples calibration data (float64).

    Raises:
        ImportError: If automatic selection is requested but clean_windows is missing.
        ValueError: If a user-supplied array has the wrong shape, or the
            parameter combination is unsupported.
    """
    if (isinstance(ref_maxbadchannels, (int, float))
            and isinstance(ref_tolerances, (tuple, list))
            and isinstance(ref_wndlen, (int, float))):
        if not _has_clean_windows:
            raise ImportError("clean_windows is needed for automatic calibration data selection (ref_maxbadchannels is numeric) but was not found.")
        logger.info('Finding a clean section of the data for calibration...')
        try:
            # clean_windows takes an EEG dict (like the other clean_* funcs)
            # and returns one whose 'data' holds only the retained windows.
            temp_EEG_for_cleanwin = EEG.copy()
            temp_EEG_for_cleanwin['data'] = data  # ensure it has the float64 data
            cleaned_EEG = clean_windows(temp_EEG_for_cleanwin, ref_maxbadchannels, ref_tolerances, ref_wndlen)
            ref_section = np.asarray(cleaned_EEG['data'], dtype=np.float64)
            if ref_section.size == 0 or ref_section.shape[1] == 0:
                logger.warning("clean_windows returned no data. Falling back to using all data for calibration.")
                return data
            return ref_section
        except Exception as e:
            # Best-effort: calibration on the full data is still meaningful,
            # so fall back instead of aborting the whole cleaning run.
            logger.error(f"An error occurred during clean_windows: {e}")
            logger.warning("Could not automatically identify clean calibration data. Falling back to using the entire data for calibration.")
            return data
    if isinstance(ref_maxbadchannels, str) and ref_maxbadchannels.lower() == 'off':
        logger.info("Using the entire data for calibration ('ref_maxbadchannels' set to 'off').")
        return data
    if isinstance(ref_tolerances, str) and ref_tolerances.lower() == 'off':
        logger.info("Using the entire data for calibration ('ref_tolerances' set to 'off').")
        return data
    if isinstance(ref_wndlen, str) and ref_wndlen.lower() == 'off':
        logger.info("Using the entire data for calibration ('ref_wndlen' set to 'off').")
        return data
    if isinstance(ref_maxbadchannels, np.ndarray):
        logger.info("Using user-supplied data array for calibration.")
        ref_section = np.asarray(ref_maxbadchannels, dtype=np.float64)
        if ref_section.ndim != 2 or ref_section.shape[0] != n_channels:
            raise ValueError(f"User-supplied calibration data must be a 2D array with shape ({n_channels}, n_samples).")
        return ref_section
    raise ValueError(f"Unsupported value or type for 'ref_maxbadchannels': {ref_maxbadchannels}. Must be float, 'off', or numpy array.")


def _extrapolate_signal_end(data: np.ndarray, n_extrap: int) -> np.ndarray:
    """Append a point-reflected copy of the signal tail as lookahead padding.

    asr_process needs lookahead samples beyond the end of the recording.
    Equivalent to the MATLAB expression:
        sig = [data, 2*data(:,end) - data(:,end-1:-1:end-n_extrap)]

    When the signal is too short to reflect, the data is returned unchanged.
    """
    n_samples = data.shape[1]
    if n_extrap <= 0:
        return data
    # Cannot reflect more samples than exist before the last one.
    extrap_len = min(n_extrap, n_samples - 1 if n_samples > 1 else 0)
    if extrap_len <= 0:
        return data
    # Indices from the second-to-last sample walking back 'extrap_len' steps.
    extrap_indices = np.arange(n_samples - 2, n_samples - extrap_len - 2, -1)
    # Reflect around the last sample: 2*last - preceding samples.
    extrap_part = 2 * data[:, [-1]] - data[:, extrap_indices]
    return np.concatenate((data, extrap_part), axis=1)


def clean_asr(
    EEG: Dict[str, Any],
    cutoff: float = 5.0,
    window_len: Optional[float] = None,
    step_size: Optional[int] = None,
    max_dims: float = 0.66,
    ref_maxbadchannels: Union[float, str, np.ndarray] = 0.075,
    ref_tolerances: Union[Tuple[float, float], str] = (-3.5, 5.5),
    ref_wndlen: Union[float, str] = 1.0,
    use_gpu: bool = False,
    useriemannian: bool = False,
    maxmem: Optional[int] = 64
) -> Dict[str, Any]:
    """Run the Artifact Subspace Reconstruction (ASR) method on EEG data.

    This is an automated artifact rejection function that ensures that the data
    contains no events that have abnormally strong power; the subspaces on which
    those events occur are reconstructed (interpolated) based on the rest of the
    EEG signal during these time periods.

    Args:
        EEG (Dict[str, Any]): EEG data structure. Expected fields:
            'data' (np.ndarray): Channels x Samples matrix.
            'srate' (float): Sampling rate in Hz.
            'nbchan' (int): Number of channels.
            It's assumed the data is zero-mean (e.g., high-pass filtered).
        cutoff (float, optional): Standard deviation cutoff for rejection. Data
            portions whose variance is larger than this threshold relative to
            the calibration data are considered artifactual and removed.
            Aggressive: 3, Default: 5, Conservative: 20.
        window_len (float, optional): Length of the statistics window in seconds.
            Should not be much longer than artifact timescale; samples in window
            should be >= 1.5x channels. Default: max(0.5, 1.5 * nbchan / srate).
        step_size (int, optional): Step size for processing in samples; the
            reconstruction matrix is updated every `step_size` samples.
            If None, defaults to window_len / 2 in samples.
        max_dims (float, optional): Maximum dimensionality/fraction of dimensions
            to reconstruct. Default: 0.66.
        ref_maxbadchannels (Union[float, str, np.ndarray], optional): Calibration
            data selection. float: max fraction (0-1) of bad channels tolerated
            in a window for it to be used as calibration data (lower is more
            aggressive, e.g. 0.05); default 0.075. 'off': use all data.
            np.ndarray: directly provides the calibration data (channels x samples).
        ref_tolerances (Union[Tuple[float, float], str], optional): Power
            tolerances (lower, upper) in SDs from robust EEG power for a channel
            to be considered 'bad' during calibration data selection.
            Default: (-3.5, 5.5). Use 'off' to disable.
        ref_wndlen (Union[float, str], optional): Window length in seconds for
            calibration data selection granularity. Default: 1.0. 'off' disables.
        use_gpu (bool, optional): Whether to try using GPU (currently ignored).
        useriemannian (bool, optional): Riemannian ASR variant (NOT IMPLEMENTED).
        maxmem (Optional[int], optional): Maximum memory in MB, passed through to
            asr_calibrate/asr_process. Default: 64.

    Returns:
        Dict[str, Any]: The EEG dictionary with 'data' replaced by cleaned data.

    Raises:
        NotImplementedError: If useriemannian is True.
        ImportError: If automatic calibration data selection is needed
            (`ref_maxbadchannels` is float) but `clean_windows` cannot be imported.
        ValueError: If input arguments are invalid or calibration fails critically.
    """
    if useriemannian:
        raise NotImplementedError("The Riemannian ASR variant is not implemented in this Python port.")

    if 'data' not in EEG or 'srate' not in EEG or 'nbchan' not in EEG:
        raise ValueError("EEG dictionary must contain 'data', 'srate', and 'nbchan'.")

    data = np.asarray(EEG['data'], dtype=np.float64)
    if data.ndim != 2:
        # Explicit check: a 1D array would otherwise fail with an opaque
        # unpacking error below.
        raise ValueError(f"EEG['data'] must be a 2D channels x samples array, got ndim={data.ndim}.")
    srate = float(EEG['srate'])
    nbchan = int(EEG['nbchan'])
    C, S = data.shape

    if C != nbchan:
        logger.warning(f"Mismatch between EEG['nbchan'] ({nbchan}) and EEG['data'].shape[0] ({C}). Using shape[0].")
        nbchan = C  # Use the actual dimension from data

    # Default statistics window: at least 0.5 s and >= 1.5 samples per channel.
    if window_len is None:
        window_len = max(0.5, 1.5 * nbchan / srate)

    # --- Determine Reference/Calibration Data ---
    ref_section_data = _select_calibration_data(
        EEG, data, C, ref_maxbadchannels, ref_tolerances, ref_wndlen
    )

    # --- Calibrate ASR ---
    logger.info('Estimating ASR calibration statistics...')
    # The Python asr_calibrate uses its own defaults for blocksize, filters,
    # etc.; only the core parameters of clean_asr are forwarded.
    try:
        state = asr_calibrate(ref_section_data, srate, cutoff=cutoff, maxmem=maxmem)
    except ValueError as e:
        # Catch specific errors like not enough calibration data
        raise ValueError(f"ASR calibration failed: {e}")
    except Exception as e:
        # Catch unexpected errors during calibration
        logger.exception("An unexpected error occurred during ASR calibration.")
        raise RuntimeError(f"ASR calibration failed unexpectedly: {e}")

    del ref_section_data  # Free memory

    # --- Prepare for Processing ---
    if step_size is None:
        step_size = int(round(srate * window_len / 2))  # Samples

    # --- Extrapolate Signal End ---
    # asr_process needs lookahead data beyond the signal end.
    N_extrap = int(round(window_len / 2 * srate))
    sig = _extrapolate_signal_end(data, N_extrap)

    # --- Process Signal using ASR ---
    logger.info('Applying ASR processing...')
    lookahead_sec = window_len / 2.0  # asr_process expects lookahead in seconds
    outdata, _ = asr_process(
        sig,
        srate,
        state,
        window_len=window_len,
        lookahead=lookahead_sec,
        step_size=step_size,
        max_dims=max_dims,
        max_mem=maxmem,
        use_gpu=use_gpu  # Passed but ignored in current Python port
    )

    # --- Finalize ---
    # asr_process returns data adjusted for lookahead, matching original length S.
    if outdata.shape[1] != S:
        logger.warning(f"Output data length ({outdata.shape[1]}) does not match input length ({S}). Truncating/padding output.")
        # Shouldn't happen if asr_process works correctly; handle defensively.
        if outdata.shape[1] > S:
            outdata = outdata[:, :S]
        else:
            padding = np.zeros((C, S - outdata.shape[1]))
            outdata = np.concatenate((outdata, padding), axis=1)

    EEG['data'] = outdata
    logger.info('ASR cleaning finished.')

    return EEG

0 commit comments

Comments
 (0)