This project is a custom implementation of SwiFT (Swin 4D fMRI Transformer), a 4D model for fMRI analysis. Based on the Swin Transformer, the model is designed to predict various biological and cognitive variables from fMRI scans. This implementation focuses on providing an easy-to-use, flexible framework for fMRI analysis by removing the PyTorch Lightning dependency and offering a lighter, PyTorch-only codebase.
You can find the original research paper on SwiFT here. Feel free to contact the authors regarding the original project.
This custom implementation of SwiFT offers the following:
- Pure PyTorch Implementation: Removed all PyTorch Lightning dependencies, resulting in a cleaner, simpler codebase.
- Custom Training: The PyTorch Lightning-based training has been replaced with a simplified, custom training loop, providing a more direct and understandable approach to model training (see the sketch after this list).
- Streamlined Data Handling: Simplified data preprocessing and loading pipelines for 4D fMRI datasets.
- Optimized for CUDA Devices: Easy and efficient utilization of CUDA GPUs for faster training and inference.
- Advanced Visualization Tools:
  - t-SNE & UMAP Visualization: Integrated t-Distributed Stochastic Neighbor Embedding (t-SNE) and Uniform Manifold Approximation and Projection (UMAP) for visualizing high-dimensional model predictions.
  - Integrated Gradients: A simplified implementation compared to the original project, making it easier to apply for model interpretability.
- Contrastive Language-Image Pre-training: Includes a standalone CLIP (contrastive learning) implementation for aligning paired modalities (e.g., age-gender).
- MLOps: Added pre-commit hooks, GitHub Actions for CI/CD, and support for WandB, MLflow, ZenML, and multi-GPU training (DDP).
- Deployment: FastAPI backend with Docker containerization and Docker Compose for a full-stack app (Nginx frontend + SWIN backend).
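As context for the custom training loop mentioned above, the sketch below shows the standard pure-PyTorch pattern it follows. It is illustrative only: `model`, `loader`, and the hyperparameters are placeholders, not the actual values used in SWIN.py.

```python
import torch
import torch.nn as nn

def train(model, loader, epochs=10, lr=1e-4, device="cuda"):
    """Minimal pure-PyTorch training loop (illustrative placeholder values)."""
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for volumes, targets in loader:  # volumes: batches of 4D fMRI clips
            volumes, targets = volumes.to(device), targets.to(device)
            optimizer.zero_grad()
            loss = criterion(model(volumes), targets)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"epoch {epoch}: mean loss {running_loss / len(loader):.4f}")
```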
This SwiFT model was tested on the ADNI dataset, preprocessed and aligned to MNI space. Various classification tasks were conducted to evaluate the model's performance.
- Age Group Prediction:
  - Task: Classifying individuals into "young" (< 69 years old) and "old" (> 78 years old) age groups, corresponding to the 1st and 3rd quartiles, respectively.
  - Performance: Achieved 95.24% accuracy on the validation dataset.
- Gender Prediction:
  - Task: Classifying individuals based on their gender.
  - Performance: Achieved 93.34% accuracy on the validation dataset.
- Four-Target Classification:
  - Task: Classifying individuals into four distinct groups: Young Female, Young Male, Old Female, Old Male.
  - Performance: Achieved 89.2% accuracy on the validation dataset.
- CLIP Version:
  - Task: A CLIP-inspired version where the Swin model trained for age group prediction was aligned with one-hot encoded gender information (a minimal sketch of the contrastive objective follows this list).
  - Performance: Achieved 97.46% accuracy on the validation dataset.
- AD vs. CN Classification (Alzheimer's Disease vs. Cognitively Normal):
  - Task: Classifying individuals as either having Alzheimer's Disease (AD) or being Cognitively Normal (CN).
  - Performance: Achieved 97.2% accuracy on the validation dataset.
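The CLIP version above reduces to a symmetric contrastive (InfoNCE) objective between the two embedding spaces. Below is a minimal sketch of such a loss, not the exact implementation in SWIN_CLIP.py; `fmri_emb` and `label_emb` are assumed to be matched batches of projected fMRI and label embeddings.

```python
import torch
import torch.nn.functional as F

def clip_loss(fmri_emb, label_emb, temperature=0.07):
    """Symmetric InfoNCE loss between paired fMRI and label embeddings."""
    fmri_emb = F.normalize(fmri_emb, dim=-1)
    label_emb = F.normalize(label_emb, dim=-1)
    logits = fmri_emb @ label_emb.t() / temperature       # (B, B) similarities
    targets = torch.arange(logits.size(0), device=logits.device)
    # Matching pairs sit on the diagonal; pull them together, push others apart.
    return (F.cross_entropy(logits, targets) +
            F.cross_entropy(logits.t(), targets)) / 2
```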
To run the code, please check the implementation of the ADNISwiFTDataset class in the SWIN.py file.
The dataset expects a CSV file with the following columns: ID, subject, fMRI_path, targets.
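A hypothetical example of such a file (the subject IDs, paths, and target values below are purely illustrative):

```csv
ID,subject,fMRI_path,targets
0,sub-001,/data/ADNI/sub-001/fmri_mni.nii.gz,0
1,sub-002,/data/ADNI/sub-002/fmri_mni.nii.gz,1
```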
Please create such a file first, then run the script with the generate_data config option set to True.
Once the data preprocessing is done, you can set up the environment and train the model with the following commands:

```bash
conda env create -n swin_env -f configs/env.yaml
conda activate swin_env
pip install -r requirements.txt
python SWIN.py  # choose between SWIN.py (default) and SWIN_CLIP.py (CLIP version)
```

Optional arguments for the training script:
```text
-h, --help     show this help message and exit
--task TASK    Task to run (age_group or sex)
--cuda CUDA    CUDA device to use (e.g., 0 for GPU 0)
--wandb WANDB  Enable Weights and Biases (WandB) tracking
```

Visualizations can be generated by executing the following files:
```bash
python -m results.visualization.viz_maps
python -m results.visualization.viz_gradients
```
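For context, results.visualization.viz_gradients produces integrated-gradients attribution maps. The snippet below is a generic sketch of the technique using Captum, not this repo's own implementation; `model`, `val_loader`, and `device` are assumed to come from a finished training run.

```python
import torch
from captum.attr import IntegratedGradients

model.eval()                          # trained Swin 4D classifier (assumed to exist)
ig = IntegratedGradients(model)
volume, _ = next(iter(val_loader))    # one batch of fMRI clips (assumed loader)
volume = volume[:1].to(device)        # keep a single sample
baseline = torch.zeros_like(volume)   # all-zero reference input
# Attribution of class 0 with respect to each input voxel
attributions = ig.attribute(volume, baselines=baseline, target=0, n_steps=32)
```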
If you prefer running the project using Docker, a Dockerfile is provided for launching a FastAPI backend server, accessible at http://localhost:8000. You can build and run the container with the following commands:

```bash
docker build -t swin_app .
docker run -p 8000:8000 swin_app
```

A docker-compose.yaml file is also included, which sets up both the backend (FastAPI server) and a frontend interface (served with Nginx). You can launch the full application with the following command:
```bash
docker compose up --build
```

After the services are up, you can access the application in your browser at http://localhost:80, which provides a convenient interface for interacting directly with the trained model.
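The backend can also be queried programmatically. The snippet below is a hypothetical client call: the actual route name and payload schema are defined by the FastAPI app, so check the auto-generated docs at http://localhost:8000/docs.

```python
import requests

# Hypothetical endpoint and payload; adapt to the routes the FastAPI app exposes.
with open("sub-001_fmri.nii.gz", "rb") as f:
    resp = requests.post("http://localhost:8000/predict", files={"file": f})
print(resp.json())
```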
*(Demo video: docker_demo.mp4)*
We have integrated visualization capabilities to better understand the model's performance and the underlying data structure.
These plots show the projection of the fMRI features learned by the Swin 4D model into a 2D space, helping to visualize the separation between age group classes.
*(Figure grid: t-SNE and UMAP projections for the training, validation, and test sets.)*
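For reference, projections like these can be produced with scikit-learn and umap-learn. In the sketch below, `features` (an (N, D) array of embeddings extracted from the trained model) and `labels` (the corresponding class indices) are hypothetical names, not the repo's own variables.

```python
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import umap  # from the umap-learn package

# `features` and `labels` are assumed to exist (see lead-in above).
# Project the high-dimensional embeddings into 2D with each method.
tsne_2d = TSNE(n_components=2, perplexity=30).fit_transform(features)
umap_2d = umap.UMAP(n_components=2).fit_transform(features)

fig, axes = plt.subplots(1, 2, figsize=(10, 4))
for ax, points, title in ((axes[0], tsne_2d, "t-SNE"), (axes[1], umap_2d, "UMAP")):
    ax.scatter(points[:, 0], points[:, 1], c=labels, s=5)
    ax.set_title(title)
plt.show()
```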
These plots show the projection of the fMRI features learned by the Swin 4D model into a 2D space, helping to visualize the separation between gender classes.
*(Figure grid: t-SNE and UMAP projections for the training, validation, and test sets.)*
These plots show the projection of the fMRI features learned by the Swin 4D model into a 2D space, helping to visualize the separation between the four combined classes: Young Female, Young Male, Old Female, and Old Male.
*(Figure grid: t-SNE and UMAP projections for the training, validation, and test sets.)*
These plots show the projection of the fMRI features learned by the Swin 4D model into a 2D space, helping to visualize the separation between Alzheimer's Disease (AD) and Cognitively Normal (CN) subjects.
*(Figure grid: t-SNE and UMAP projections for the training, validation, and test sets.)*