This is the official PyTorch implementation of our CVPR 2024 paper (Highlight) "Continual Self-supervised Learning: Towards Universal Multi-modal Medical Data Representation Learning".
CUDA 11.5
Python 3.8
PyTorch 1.11.0
CuDNN 8.3.2.44
- Pre-training data
  - Download the MIMIC-CXR dataset.
  - Download the DeepLesion dataset.
  - Download the ADNI dataset.
  - Download seven TCGA datasets (TCGA-THYM, TCGA-THCA, TCGA-BRCA, TCGA-UCEC, TCGA-UVM, TCGA-OV, and TCGA-MESO).
- Fine-tuning data
  - PudMed20k dataset: Download the PudMed20k dataset.
  - ChestXR dataset: Download the ChestXR dataset.
  - QaTav2 dataset: Download the QaTav2 dataset.
  - RICORD dataset: Download the MIDRC-RICORD-1A dataset and MIDRC-RICORD-1B dataset. The folder structure of the dataset should be:

        dataset/RICORD_nii/
        ├── MIDRC-RICORD-1A
        ├── MIDRC-RICORD-1B

  - LiTS dataset: Download the LiTS dataset.
  - VS dataset: Download the VS dataset.
  - LA dataset: Download the LA dataset.
  - NCH dataset: Download the NCT-CRC-HE-100K and CRC-VAL-HE-7K datasets.
  - GlaS dataset: Download the GlaS dataset.
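
As a quick sanity check on the RICORD layout described above, the following sketch reports which of the two expected sub-folders are absent. The function name `missing_ricord_folders` and the default root `dataset/RICORD_nii` are illustrative assumptions; this helper is not part of the repository.

```python
from pathlib import Path

# Sub-folder names taken from the RICORD layout described above.
EXPECTED_SUBFOLDERS = ("MIDRC-RICORD-1A", "MIDRC-RICORD-1B")


def missing_ricord_folders(root="dataset/RICORD_nii"):
    """Return the expected RICORD sub-folders that are absent under `root`.

    Hypothetical helper for illustration; not part of this repository.
    """
    root = Path(root)
    return [name for name in EXPECTED_SUBFOLDERS if not (root / name).is_dir()]


if __name__ == "__main__":
    missing = missing_ricord_folders()
    if missing:
        print("Missing RICORD folders:", ", ".join(missing))
    else:
        print("RICORD folder structure looks complete.")
```

Running the check before pre-processing avoids discovering a misplaced download halfway through a long job.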
- Pre-training data
  - Report: Follow MGCA's procedure to pre-process the MIMIC-CXR dataset.
  - X-ray: Use `Preprocess/MIMIC_CXR_JPG_Preprocess.py` to pre-process the MIMIC-CXR dataset.
  - CT:
    - Use `Preprocess/DL_save_nifti.py` (from the downloaded files) to convert the PNG images to nii.gz format.
    - Use `Preprocess/re_spacing_ITK.py` to resample the CT volumes.
    - Use `Preprocess/splitting_to_patches.py` to extract about 125k sub-volumes; the pre-processed dataset will be saved in `DL_patches_v2/`.
    - Use `Preprocess/DeepLesion_Resize.py` to resize the images.
  - MRI:
    - Use `Preprocess/ADNI_Resize.py` to resize the images.
    - Use `Preprocess/ADNI_split_slice.py` to extract about 59k sub-volumes.
  - Pathological imaging: Use `Preprocess/TCGA_Preprocess.py` to pre-process the seven TCGA datasets.
- Fine-tuning data
  - PudMed20k dataset: No pre-processing required.
  - ChestXR dataset: No pre-processing required.
  - QaTav2 dataset: Use `Preprocess/QaTav2.py` to pre-process.
  - RICORD dataset: Use `Preprocess/RICORD.py` to pre-process. Data splits can be obtained from `/Downstream/Dim_3/RICORD/data_split`.
  - LiTS dataset:
    - (1) Resample all data to the same spacing of 1.5mm × 0.8mm × 0.8mm;
    - (2) Use the nnUNet v1 framework to pre-process.
  - VS dataset:
    - (1) Run `Convert_VSseg_to_nnUNet_dataset.py`;
    - (2) Use the nnUNet v1 framework to pre-process;
    - (3) Use `Preprocess/VSeg.py` to pre-process.
  - LA dataset: No pre-processing required.
  - NCH dataset: No pre-processing required.
  - GlaS dataset: Use `Preprocess/GlaS.py` to pre-process.
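
Step (1) for LiTS resamples every volume to a common voxel spacing. The arithmetic behind such a step can be sketched as follows: the physical extent along each axis (voxel count times spacing) is preserved, so the new voxel count is that extent divided by the target spacing. The function name `resampled_shape`, the example volume, and the axis order (1.5 mm on the first axis) are assumptions for illustration; the resampling actually used for this dataset may differ.

```python
def resampled_shape(shape, spacing, target_spacing=(1.5, 0.8, 0.8)):
    """Compute the voxel grid size after resampling to `target_spacing`.

    shape:          number of voxels along each axis, e.g. (D, H, W)
    spacing:        current voxel size in mm along each axis
    target_spacing: desired voxel size in mm (the axis order is an assumption)

    Illustrative helper only; not taken from this repository.
    """
    if not (len(shape) == len(spacing) == len(target_spacing)):
        raise ValueError("shape, spacing and target_spacing must have equal length")
    # Physical extent (n * s) stays fixed, so the new count is n * s / t.
    return tuple(
        max(1, round(n * s / t))
        for n, s, t in zip(shape, spacing, target_spacing)
    )


if __name__ == "__main__":
    # Hypothetical CT volume: 100 slices at 3.0 mm, 512 x 512 pixels at 0.4 mm.
    print(resampled_shape((100, 512, 512), (3.0, 0.4, 0.4)))  # (200, 256, 256)
```

The computed shape is what an interpolation routine would be asked to produce; the interpolation itself (for example, linear for images and nearest-neighbour for label maps) is a separate choice.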
- Download `uni-perceiver-base-L12-H768-224size-torch-pretrained.pth`.
- Run `sh run_ssl.sh` for pre-training (4 GPUs with 24 GB of memory; before running it, you need to modify some paths).
- The pre-trained model is available in MedCoSS_Report_Xray_CT_MR_Path_Buffer0.05.
- Run `sh run_ds.sh` for fine-tuning (one GPU with 11 GB of memory; before running it, you need to modify some paths).
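
Both scripts depend on paths that must be adapted to the local machine. A small pre-flight check such as the sketch below can catch a missing file before a long job starts. The helper name `find_missing_paths` and the example locations are hypothetical; substitute whatever paths you configured for your own setup.

```python
import os


def find_missing_paths(paths):
    """Return the entries of `paths` that do not exist on disk.

    Hypothetical helper for illustration; not part of this repository.
    """
    return [p for p in paths if not os.path.exists(p)]


if __name__ == "__main__":
    # Placeholder locations (assumptions): adjust them to your own setup.
    required = [
        "uni-perceiver-base-L12-H768-224size-torch-pretrained.pth",
        "run_ssl.sh",
        "run_ds.sh",
    ]
    missing = find_missing_paths(required)
    if missing:
        print("Missing before launch:")
        for path in missing:
            print("  " + path)
    else:
        print("All required paths exist.")
```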
- Dataset Links
- Pre-processing Code
- Pre-training Code Release
- Pre-trained Model
- Fine-tuning Code Release
- Continual pre-training on new data
If this code is helpful for your study, please cite:
@inproceedings{ye2024medcoss,
title={Continual Self-supervised Learning: Towards Universal Multi-modal Medical Data Representation Learning},
author={Ye, Yiwen and Xie, Yutong and Zhang, Jianpeng and Chen, Ziyang and Wu, Qi and Xia, Yong},
booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition},
pages={11114-11124},
year={2024},
}
The whole framework is based on MAE, Uni-Perceiver, and MGCA.
Yiwen Ye (ywye@mail.nwpu.edu.cn)