SwiFT: Swin 4D fMRI Transformer


📌 Introduction

This project is a custom implementation of SwiFT (Swin 4D fMRI Transformer), a 4D model for fMRI analysis. Based on the Swin Transformer, the model is designed to predict various biological and cognitive variables from fMRI scans. This implementation focuses on providing an easy-to-use and flexible framework for fMRI analysis, removing the PyTorch Lightning dependency in favor of a lighter, PyTorch-only codebase.

You can find the original research paper on SwiFT here. Feel free to contact the authors regarding the original project.


🚀 Key Features and Improvements

This custom implementation of SwiFT offers the following:

  • Pure PyTorch Implementation: Removed all dependencies on PyTorch Lightning, resulting in a cleaner and simpler codebase.
  • Custom Training: The PyTorch Lightning-based training has been replaced with a simplified, custom training loop, providing a more direct and understandable approach to model training (a minimal sketch is shown after this list).
  • Streamlined Data Handling: Simplified data preprocessing and loading pipelines for 4D fMRI datasets.
  • Optimized for CUDA Devices: Easy and efficient utilization of CUDA GPUs for faster training and inference.
  • Advanced Visualization Tools:
    • t-SNE & UMAP Visualization: Integrated t-Distributed Stochastic Neighbor Embedding (t-SNE) and Uniform Manifold Approximation and Projection (UMAP) for visualizing high-dimensional model predictions.
    • Integrated Gradients: Simplified implementation compared to the original project, making it easier to apply for model interpretability.
  • Contrastive Language–Image Pre-training: Includes a standalone CLIP (contrastive learning) implementation for aligning paired modalities (e.g., age-group features with one-hot gender labels).
  • MLOps: Added pre-commit hooks, GitHub Actions for CI/CD, and support for WandB, MLflow, ZenML, and multi-GPU training (DDP).
  • Deployment: FastAPI backend with Docker containerization and Docker Compose for full-stack app (Nginx frontend + SWIN backend).
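
For illustration, the sketch below shows the general shape of a plain PyTorch training loop like the one referenced in the Custom Training bullet above. It is a minimal sketch only: the model, dataloader, and hyperparameter names are placeholders, not the actual objects defined in SWIN.py.

import torch

# Minimal PyTorch-only epoch loop. Illustrative only: the model, dataloader,
# and hyperparameters are placeholders, not the actual objects in SWIN.py.
def train_one_epoch(model, loader, optimizer, criterion, device):
    model.train()
    running_loss = 0.0
    for volumes, targets in loader:                    # volumes: a batch of 4D fMRI clips
        volumes, targets = volumes.to(device), targets.to(device)
        optimizer.zero_grad()
        loss = criterion(model(volumes), targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * volumes.size(0)
    return running_loss / len(loader.dataset)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Typical wiring (names are assumptions):
#   model = SwiFTClassifier(...).to(device)
#   optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
#   criterion = torch.nn.CrossEntropyLoss()
#   for epoch in range(num_epochs):
#       train_one_epoch(model, train_loader, optimizer, criterion, device)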

Experiments and Results

This SwiFT model was tested on the ADNI dataset, preprocessed and aligned to MNI space. Various classification tasks were conducted to evaluate the model's performance.

  1. Age Group Prediction:

    • Task: Classifying individuals into "young" (< 69 years old) and "old" (> 78 years old) age groups, with thresholds corresponding to the 1st and 3rd age quartiles, respectively.
    • Performance: Achieved 95.24% accuracy on the validation dataset.
  2. Gender Prediction:

    • Task: Classifying individuals based on their gender.
    • Performance: Achieved 93.34% accuracy on the validation dataset.
  3. Four-Target Classification:

    • Task: Classifying individuals into four distinct groups: Young Female, Young Male, Old Female, Old Male.
    • Performance: Achieved 89.2% accuracy on the validation dataset.
  4. CLIP Version:

    • Task: A CLIP-inspired version in which the Swin model trained for age group prediction was aligned with one-hot encoded gender information (a generic sketch of this kind of alignment is shown after this list).
    • Performance: Achieved 97.46% accuracy on the validation dataset.
  5. AD vs. CN Classification (Alzheimer's Disease vs. Cognitively Normal):

    • Task: Classifying individuals as either having Alzheimer's Disease (AD) or being Cognitively Normal (CN).
    • Performance: Achieved 97.2% accuracy on the validation dataset.
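
The CLIP-inspired setup in experiment 4 is, at its core, a symmetric contrastive alignment between two embedding spaces. Below is a minimal, generic sketch of such an alignment; the projection layers, dimensions, and temperature are assumptions for illustration and do not reflect the actual SWIN_CLIP.py implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Generic CLIP-style contrastive alignment between two embedding spaces.
# Feature dimensions, projection size, and temperature are assumptions.
class ContrastiveAlignment(nn.Module):
    def __init__(self, fmri_dim=768, label_dim=2, proj_dim=128, temperature=0.07):
        super().__init__()
        self.fmri_proj = nn.Linear(fmri_dim, proj_dim)    # projects Swin fMRI features
        self.label_proj = nn.Linear(label_dim, proj_dim)  # projects one-hot labels
        self.temperature = temperature

    def forward(self, fmri_features, one_hot_labels):
        f = F.normalize(self.fmri_proj(fmri_features), dim=-1)
        g = F.normalize(self.label_proj(one_hot_labels), dim=-1)
        logits = f @ g.t() / self.temperature                  # (B, B) similarity matrix
        targets = torch.arange(f.size(0), device=f.device)     # positives on the diagonal
        # Symmetric cross-entropy: match each scan to its label and each label to its scan.
        return (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets)) / 2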

💻 Getting Started

To run the code, please check the implementation of the ADNISwiFTDataset class in the SWIN.py file. The dataset expects a CSV file with the following columns: ID, subject, fMRI_path, targets. Please create such a file first, then run the script with the generate_data config option set to True.
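
As a concrete illustration of the expected CSV layout, the snippet below builds such a file with pandas. Only the column names (ID, subject, fMRI_path, targets) come from the description above; the paths and target values are invented for the example.

import pandas as pd

# Illustrative only: the paths and target values are invented; only the column
# names (ID, subject, fMRI_path, targets) come from the description above.
rows = [
    {"ID": 0, "subject": "sub-001", "fMRI_path": "/data/ADNI/sub-001_bold.nii.gz", "targets": 0},
    {"ID": 1, "subject": "sub-002", "fMRI_path": "/data/ADNI/sub-002_bold.nii.gz", "targets": 1},
]
pd.DataFrame(rows).to_csv("adni_dataset.csv", index=False)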

Once the data preprocessing is done, you can train the model using the following command:

conda env create -n swin_env -f configs/env.yaml
conda activate swin_env
pip install -r requirements.txt
python SWIN.py # You can choose between SWIN.py (default) or SWIN_CLIP.py (CLIP version)

Optional arguments for the training script:

  -h, --help     show this help message and exit
  --task TASK    Task to run (age_group or sex)
  --cuda CUDA    CUDA device to use (e.g., 0 for GPU 0)
  --wandb WANDB  Enable Weights and Biases (WandB) tracking

Visualizations can be generated by executing the following files:

python -m results.visualization.viz_maps
python -m results.visualization.viz_gradients
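
The gradient maps produced by viz_gradients are based on Integrated Gradients (see the feature list above). As a rough sketch of how such attributions can be computed, the example below uses the Captum library with a dummy model; the repository's own, simplified implementation may differ.

import torch
import torch.nn as nn
from captum.attr import IntegratedGradients

# Illustrative only: a dummy classifier stands in for the trained SwiFT model,
# and the clip shape is an assumption, not the repository's actual input shape.
model = nn.Sequential(nn.Flatten(), nn.Linear(8 * 8 * 8 * 4, 2)).eval()
volume = torch.randn(1, 1, 8, 8, 8, 4)                       # (batch, channel, H, W, D, time)

ig = IntegratedGradients(model)
attributions = ig.attribute(volume, target=1, n_steps=32)    # attribution per voxel/timepoint
print(attributions.shape)                                    # same shape as the input clip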

๐Ÿณ Docker

If you prefer running the project using Docker, a Dockerfile is provided for launching a FastAPI backend server, accessible at http://localhost:8000. You can build and run the container with the following commands:

docker build -t swin_app .
docker run -p 8000:8000 swin_app
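
To give an idea of what the containerized backend exposes, here is a minimal sketch of a FastAPI upload endpoint. The route name, payload handling, and model loading are assumptions for illustration, not the actual API implemented in this repository.

from fastapi import FastAPI, File, UploadFile

# Illustrative only: the route, payload handling, and (commented) model loading
# are assumptions, not the actual FastAPI app shipped in this repository.
app = FastAPI()
# model = load_trained_swift("checkpoints/model.pt")   # hypothetical helper

@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    data = await file.read()                 # raw fMRI file uploaded by the frontend
    # volume = preprocess(data)              # hypothetical preprocessing into a tensor
    # probs = model(volume).softmax(dim=-1)  # hypothetical inference step
    return {"filename": file.filename, "size_bytes": len(data), "prediction": "placeholder"}

# Serve with: uvicorn <module>:app --host 0.0.0.0 --port 8000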

🧩 Docker Compose (Frontend + Backend)

A docker-compose.yaml file is also included, which sets up both the backend (FastAPI server) and a frontend interface (served with Nginx). You can launch the full application with the following command:

docker compose up --build

Once the services are up, you can access the application in your browser at http://localhost:80. This provides a simple web interface for interacting with the trained model directly.

Demo video: docker_demo.mp4

📊 Visualizations

We have integrated visualization capabilities to better understand the model's performance and the underlying data structure.
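
As a hedged sketch of how such 2D projections can be produced from pooled model features, the snippet below runs t-SNE (scikit-learn) and UMAP (umap-learn) on random placeholder embeddings; the repository's own code lives in results/visualization/viz_maps.py.

import numpy as np
import matplotlib.pyplot as plt
import umap
from sklearn.manifold import TSNE

# Illustrative only: random features stand in for the pooled SwiFT embeddings,
# and the feature dimension and labels are assumptions.
features = np.random.rand(200, 768)            # (n_samples, feature_dim)
labels = np.random.randint(0, 2, size=200)     # e.g., young (0) vs. old (1)

tsne_2d = TSNE(n_components=2, perplexity=30).fit_transform(features)
umap_2d = umap.UMAP(n_components=2).fit_transform(features)

fig, axes = plt.subplots(1, 2, figsize=(10, 4))
axes[0].scatter(tsne_2d[:, 0], tsne_2d[:, 1], c=labels, cmap="coolwarm", s=10)
axes[0].set_title("t-SNE")
axes[1].scatter(umap_2d[:, 0], umap_2d[:, 1], c=labels, cmap="coolwarm", s=10)
axes[1].set_title("UMAP")
plt.show()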

1. Age group

These plots show the projection of the fMRI features learned by the Swin 4D model into a 2D space, helping to visualize the separation between age group classes.

t-SNE (age group): training, validation, and test sets
UMAP (age group): training, validation, and test sets

2. Sex

These plots show the projection of the fMRI features learned by the Swin 4D model into a 2D space, helping to visualize the separation between gender classes.

t-SNE (sex): training, validation, and test sets
UMAP (sex): training, validation, and test sets

3. Four-Target Classification (Age and Sex Combined)

These plots show the projection of the fMRI features learned by the Swin 4D model into a 2D space, helping to visualize the separation between the four combined classes: Young Female, Young Male, Old Female, and Old Male.

t-SNE (4-target): training, validation, and test sets
UMAP (4-target): training, validation, and test sets

4. Alzheimer's Disease

These plots show the projection of the fMRI features learned by the Swin 4D model into a 2D space, helping to visualize the separation between Alzheimer's Disease (AD) and Cognitively Normal (CN) subjects.

t-SNE (AD vs. CN): training, validation, and test sets
UMAP (AD vs. CN): training, validation, and test sets
