Code for "Stabilizing Subject Transfer in EEG Classification with Divergence Estimation" by Niklas Smedemark-Margulies, Ye Wang, Toshiaki Koike-Akino, Jing Liu, Kieran Parsons, Yunus Bicer, and Deniz Erdogmus.
In this work, we provide two new censoring methods based on density ratio estimation and Wasserstein-1 divergence. We use neural network critic functions to estimate a regularization penalty based on one of these methods, and evaluate their performance against an adversarial classifier baseline using a large EEG benchmark dataset.
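The repo estimates these divergences with learned neural critic functions; as a toy illustration of the quantity being penalized (not the method itself), the Wasserstein-1 distance between two equal-size 1-D empirical samples has an exact closed form — the mean absolute difference of the sorted values:

```python
def wasserstein1_1d(xs, ys):
    """Exact Wasserstein-1 distance between two equal-size 1-D empirical
    samples: mean absolute difference of the sorted values.
    (Toy illustration only -- the repo uses neural critics in higher dims.)"""
    assert len(xs) == len(ys)
    xs, ys = sorted(xs), sorted(ys)
    return sum(abs(a - b) for a, b in zip(xs, ys)) / len(xs)

# Two toy "subjects" whose feature distributions differ by a shift of 1.0
subj_a = [0.0, 1.0, 2.0, 3.0]
subj_b = [1.0, 2.0, 3.0, 4.0]
print(wasserstein1_1d(subj_a, subj_b))  # 1.0
```

A penalty built on this quantity shrinks as the per-subject feature distributions become indistinguishable, which is the censoring goal.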
Set up the project using:

```shell
make
```

The default Makefile target will set up the project (if not already set up) and format/lint all files.
Note that it uses `python3.8`; if this is not available, or python from a specific path should be used, change `BASE_PYTHON` to a suitable path (e.g. `/usr/bin/python3.8`).
The project should work with roughly python3.7 or newer.
Makefile targets:
- `make setup` - create virtualenv, install python packages, and install pre-commit hooks
- `make lint` - format and lint all files, using the same rules that are part of `pre-commit`
- `make test` - run unit tests for a subset of the code (just some sanity checks). Not all tests pass, and some tests expect that data has been downloaded
- `make destroy-setup` - delete the python virtual environment in order to create it again from scratch
Use `source venv/bin/activate` to activate the virtual environment for all commands below.
Run the pre-processing script using:

```shell
python scripts/preprocess_thu_data.subset.py
```
The preprocessing script does the following steps:
- Downloads raw files for the THU RSVP benchmark dataset
- Applies signal filtering/pre-processing steps (notch filter, bandpass filter, downsample) and extracts labeled trials using:

```python
transform = get_default_transform(...)
thu_rsvp = THU_RSVP_Dataset(..., transform=transform)
```

NOTE - `force_extract=False` will just re-use previously extracted trials (for speed). In order to adjust the signal filtering stage, change the transform and use `force_extract=True`:

```python
new_transform = get_default_transform(...)  # Suppose you make changes here...
thu_rsvp = THU_RSVP_Dataset(
    ...,
    transform=new_transform,
    force_extract=True,  # ...Then you must set this flag
)
```

- Divides data into several train/val/test folds and saves them as `*.npy` files for fast loading during training runs

Note that `src/data.py` has a hard-coded path for the location of these `*.npy` files (currently `path = DATA_DIR / "thu_rsvp" / "subset" / f"fold{fold_idx}"`). If using different train/val/test splits, be sure to update this line in `src/data.py`.
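The hard-coded fold path in `src/data.py` can be sketched with a small helper (the `data/` root assigned to `DATA_DIR` here is an assumption for illustration):

```python
from pathlib import Path

DATA_DIR = Path("data")  # assumption: wherever the repo stores data

def fold_dir(fold_idx: int) -> Path:
    # Mirrors the hard-coded pattern in src/data.py
    return DATA_DIR / "thu_rsvp" / "subset" / f"fold{fold_idx}"

print(fold_dir(0).as_posix())  # data/thu_rsvp/subset/fold0
```

If you change the fold layout, this is the single place the loader's path construction must be kept in sync.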
The first experiment uses the final checkpoint for testing:

```shell
python scripts/run_overfitting_experiment.long.py
```

The second experiment uses the best val checkpoint for testing:

```shell
python scripts/run_censoring_experiment.py
```
NOTE - these grids are very large, and the slurm queue can only hold ~10K jobs.
- Set `DRY_RUN=True` at the top of the script to see the jobs that will be run, and try a few manually
- Comment out subsets of params within the script as needed
- Try not to run the baseline multiple times. It should be fine (you can select only the final baseline run when loading results), but it wastes some time
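The dry-run pattern described above can be sketched as follows (the flag name `DRY_RUN` comes from the scripts; the grid values and `scripts/train.py` command are hypothetical stand-ins):

```python
DRY_RUN = True  # set at the top of the script

# Hypothetical grid of regularization strengths; the real grids are much larger
grid = [0.0, 0.1, 1.0]
jobs = [f"python scripts/train.py --lam {lam}" for lam in grid]

for cmd in jobs:
    if DRY_RUN:
        print(cmd)  # inspect the command and try a few manually
    else:
        # A real run would submit to slurm instead, e.g.:
        # subprocess.run(["sbatch", "--wrap", cmd], check=True)
        pass
```

Printing the commands first makes it easy to spot a grid that would overflow the ~10K-job queue before submitting anything.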
Results will go into `results/<FOLDER>`, with a different name depending on the script:
- `results/censoring__last__100epoch` (we used this folder name for the main results in the paper)
- `results/censoring` (we used this folder name for results that combine censoring with early stopping in the paper appendix)
Plots are created in two steps, to allow for easier development. First, scrape results into a pickle file:

```shell
python scripts/make_plots_1.py --experiment_name <NAME>  # <NAME> is "censoring" or "censoring__last__100epoch"
```

Next, create plots from the pickle file:

```shell
python scripts/make_plots_2.py --experiment_name <NAME>  # <NAME> is "censoring" or "censoring__last__100epoch"
```
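The reason for the two-step split is that scraping is slow and plotting is iterated on often; the pattern amounts to caching scraped results in a pickle. A minimal sketch (the `results` dict here is a hypothetical stand-in for whatever the scrape step saves):

```python
import os
import pickle
import tempfile

# Hypothetical scraped-results dict, stand-in for the real scrape output
results = {"censoring": {"val_acc": [0.81, 0.83]}}

pkl_path = os.path.join(tempfile.mkdtemp(), "scraped.pkl")
with open(pkl_path, "wb") as f:
    pickle.dump(results, f)  # step 1: scrape (slow) runs once

with open(pkl_path, "rb") as f:
    loaded = pickle.load(f)  # step 2: re-plot (fast) reads the cache
print(loaded == results)  # True
```

With the cache in place, plot styling can be tweaked repeatedly without re-reading every results folder.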
As mentioned above, since the full THU dataset is very large, it is pre-processed into individual numpy files for train/val/test so that data can be quickly loaded during a single training run. The preprocessing is done using `scripts/preprocess_thu_data.subset.py`.
- Data is split by subjects: 24 train subjects / 4 val / 4 test.
- Data is also down-sampled so that the class label distribution is 1 target : 10 non-target. (The original class ratio is ~1 target : 60 non-target.)

Other dataset considerations:
- Using all available subjects (54 train / 5 val / 5 test) with no class downsampling was problematic: training was too slow for very large experiment grids (~1 hour per run), with occasional OOM issues.
- Class downsampling seems to be important (the raw class ratio is ~1 target : 60 non-target). With excessively imbalanced classes, the potential benefit of regularization is obscured by the task difficulty.
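The class downsampling described above can be sketched as keeping every target trial and subsampling non-target trials to the desired ratio (this helper and its names are illustrative, not the repo's code):

```python
import random

def downsample_nontarget(labels, ratio=10, seed=0):
    """Keep all target (label 1) trials and randomly subsample non-target
    (label 0) trials so the ratio is ~1 target : `ratio` non-target.
    Returns sorted indices of trials to keep.
    (Hypothetical helper for illustration, not the repo's implementation.)"""
    rng = random.Random(seed)
    targets = [i for i, y in enumerate(labels) if y == 1]
    nontargets = [i for i, y in enumerate(labels) if y == 0]
    n_keep = min(len(nontargets), ratio * len(targets))
    kept = rng.sample(nontargets, n_keep)
    return sorted(targets + kept)

labels = [1] * 5 + [0] * 300  # toy ~1:60 imbalance, like the raw dataset
idx = downsample_nontarget(labels)
print(len(idx))  # 55 = 5 targets + 50 non-targets
```

Fixing the random seed keeps the subsample identical across runs, so all grid points train on the same trials.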
To read our paper, see: https://arxiv.org/pdf/2310.08762.pdf
If you use the software, please cite the following:
```bibtex
@article{smedemark2023stabilizing,
  title={Stabilizing Subject Transfer in EEG Classification with Divergence Estimation},
  author={
    Smedemark-Margulies, Niklas and
    Wang, Ye and
    Koike-Akino, Toshiaki and
    Liu, Jing and
    Parsons, Kieran and
    Bicer, Yunus and
    Erdogmus, Deniz
  },
  journal={arXiv preprint arXiv:2310.08762},
  year={2023}
}
```
Contact: Ye Wang <yewang@merl.com>
See CONTRIBUTING.md for our policy on contributions.
Released under the `AGPL-3.0-or-later` license, as found in the LICENSE.md file.
All files:
Copyright (C) 2023 Mitsubishi Electric Research Laboratories (MERL).
SPDX-License-Identifier: AGPL-3.0-or-later