Cross-Modal Translation and Alignment for Survival Analysis, ICCV 2023.
Fengtao ZHOU, Hao CHEN
@inproceedings{zhou2023cross,
title ={Cross-Modal Translation and Alignment for Survival Analysis},
author ={Zhou, Fengtao and Chen, Hao},
booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
pages ={21485--21494},
year ={2023}
}
Summary: This is the official implementation of the paper "Cross-Modal Translation and Alignment for Survival Analysis".
Please follow this GitHub repository for updates.
- Address OOM issues (sampling a certain number of patches for specific patients)
torch 1.12.0+cu116
scikit-survival 0.19.0
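A quick way to compare the local environment with the versions listed above is sketched below. It is a minimal illustration and not part of the repository: the helper names `installed_version` and `compare_versions` are hypothetical, and a differing version is only reported, since other versions may also work.

```python
from importlib import metadata

# Versions listed in this README, treated here as the tested configuration.
TESTED_VERSIONS = {"torch": "1.12.0+cu116", "scikit-survival": "0.19.0"}


def installed_version(package):
    """Return the installed version string of `package`, or None if it is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def compare_versions(tested=TESTED_VERSIONS):
    """Map each package name to (tested_version, installed_version_or_None)."""
    return {name: (want, installed_version(name)) for name, want in tested.items()}


if __name__ == "__main__":
    for name, (want, have) in compare_versions().items():
        status = "ok" if have == want else "differs"
        print(f"{name}: tested {want}, installed {have} ({status})")
```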
- Download diagnostic WSIs from TCGA
- Use the WSI processing tool provided by CLAM to extract a 1024-dim feature with a pretrained ResNet-50 for each 256 $\times$ 256 patch (20x), and save the features as .pt files, one per WSI. This yields one pt_files folder storing the .pt files for all WSIs of one study.
The final structure of the datasets should be as follows:
DATA_ROOT_DIR/
    └──pt_files/
        ├── slide_1.pt
        ├── slide_2.pt
        └── ...
DATA_ROOT_DIR is the base directory of one cancer type (e.g. the directory of TCGA_BLCA), which should be passed to main.py with the argument --data_root_dir, as shown in run.sh.
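Before training, it can help to confirm that DATA_ROOT_DIR matches the layout above. The following is a minimal sketch using only the standard library; the function name `list_feature_files` is hypothetical and not part of the repository.

```python
from pathlib import Path


def list_feature_files(data_root_dir):
    """Return the sorted .pt feature files found under <data_root_dir>/pt_files."""
    pt_dir = Path(data_root_dir) / "pt_files"
    if not pt_dir.is_dir():
        # The layout described above requires a pt_files folder.
        raise FileNotFoundError(f"expected a pt_files folder at {pt_dir}")
    # Only files ending in .pt are treated as slide features.
    return sorted(pt_dir.glob("*.pt"))
```

Calling `list_feature_files("<DATA_ROOT_DIR>")` returns one path per WSI. Each file can presumably be read with `torch.load`, although the exact contents depend on how the features were saved.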
In this work, we directly use the preprocessed genomic data provided by MCAT, stored in the csv folder.
Splits for each cancer type are found in the splits/5foldcv folder. Each dataset is randomly partitioned using 5-fold cross-validation, and each cancer type's folder contains splits_{k}.csv for k = 1 to 5. For a fair comparison with MCAT, we use the same splits as MCAT.
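As a sanity check on the split files described above, the sketch below reports which splits_{k}.csv files are missing from a given folder. It assumes the folder passed in is the per-cancer-type folder inside splits/5foldcv; the function name `missing_split_files` is hypothetical.

```python
from pathlib import Path


def missing_split_files(split_dir, n_folds=5):
    """Return the names of splits_{k}.csv (k = 1..n_folds) absent from split_dir."""
    split_dir = Path(split_dir)
    expected = [f"splits_{k}.csv" for k in range(1, n_folds + 1)]
    return [name for name in expected if not (split_dir / name).is_file()]
```

An empty list means all five folds are present.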
To train CMTA, you can specify the arguments in run.sh and run the command:
bash run.sh
or use the following generic command line and specify the arguments:
CUDA_VISIBLE_DEVICES=<DEVICE_ID> python main.py \
--which_splits 5foldcv \
--dataset <CANCER_TYPE> \
--data_root_dir <DATA_ROOT_DIR> \
--modal coattn \
--model cmta \
--num_epoch 30 \
--batch_size 1 \
--loss nll_surv_l1 \
--lr 0.001 \
--optimizer SGD \
--scheduler None \
--alpha 1.0
Commands for all experiments of CMTA can be found in the run.sh file.
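When launching several runs from Python rather than from run.sh, the generic command above can be assembled programmatically. The sketch below only mirrors the flags shown in this README. The helper names `build_train_command` and `build_env` are hypothetical, and the placeholder values must be replaced with a real cancer type and data directory.

```python
import os
import shlex


def build_train_command(dataset, data_root_dir):
    """Return the argument list for one CMTA training run (flags from this README)."""
    return [
        "python", "main.py",
        "--which_splits", "5foldcv",
        "--dataset", dataset,
        "--data_root_dir", str(data_root_dir),
        "--modal", "coattn",
        "--model", "cmta",
        "--num_epoch", "30",
        "--batch_size", "1",
        "--loss", "nll_surv_l1",
        "--lr", "0.001",
        "--optimizer", "SGD",
        "--scheduler", "None",
        "--alpha", "1.0",
    ]


def build_env(device_id):
    """Copy the current environment and select the GPU via CUDA_VISIBLE_DEVICES."""
    env = dict(os.environ)
    env["CUDA_VISIBLE_DEVICES"] = str(device_id)
    return env


# Placeholders, exactly as in the generic command above.
cmd = build_train_command("<CANCER_TYPE>", "<DATA_ROOT_DIR>")
print(shlex.join(cmd))
# To launch a run: subprocess.run(cmd, env=build_env(0), check=True)
```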
Tips: some patients have multiple WSIs, especially in TCGA-GBMLGG, which can cause out-of-memory (OOM) errors. In such cases, we can randomly sample a certain number of patches for these patients via the --OOM argument to reduce the computational requirements, which does not significantly affect overall performance:
CUDA_VISIBLE_DEVICES=<DEVICE_ID> python main.py \
--which_splits 5foldcv \
--dataset <CANCER_TYPE> \
--data_root_dir <DATA_ROOT_DIR> \
--modal coattn \
--model cmta \
--num_epoch 30 \
--batch_size 1 \
--loss nll_surv_l1 \
--lr 0.001 \
--optimizer SGD \
--scheduler None \
--alpha 1.0 \
--OOM 4096
With --OOM 4096, if a patient has more than 4096 patches, 4096 patches are randomly sampled. If the OOM issue persists, you can further reduce the number of sampled patches.
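The sampling rule can be illustrated with a short, standard-library sketch. This is not the repository's implementation: the function name `sample_patch_indices` and its `seed` parameter are assumptions made for illustration, and sampling without replacement is assumed.

```python
import random


def sample_patch_indices(num_patches, max_patches=4096, seed=None):
    """Return sorted indices of the patches to keep, at most `max_patches` of them."""
    if num_patches <= max_patches:
        # Small bags are kept whole.
        return list(range(num_patches))
    rng = random.Random(seed)
    # Sample without replacement, then sort to preserve the original patch order.
    return sorted(rng.sample(range(num_patches), max_patches))
```

The returned indices can then be used to select rows from a patient's feature matrix, for example `features[torch.as_tensor(indices)]` when the features are a PyTorch tensor. Passing a fixed `seed` makes the subset reproducible across runs.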
Huge thanks to the authors of the following open-source projects:
If you find our work useful in your research, please consider citing our paper using the BibTeX entry given at the top of this README.
This code is available for non-commercial academic purposes. If you have any questions, feel free to email Fengtao ZHOU.