PreMix: Label-Efficient Multiple Instance Learning via Non-Contrastive Pre-training and Feature Mixing
Under review at Computer Methods and Programs in Biomedicine
- Non-contrastive MIL pre-training for WSI classification: We introduce PreMix, a WSI-level self-supervised framework based on Barlow Twins that avoids reliance on negative pairs and effectively addresses class imbalance in WSI datasets.
- Intra-batch slide mixing for semantic enrichment: PreMix enhances representation learning through Barlow Twins Slide Mixing, which generates additional positive pairs by interpolating features across slides within the same batch (see the sketch after this list).
- Dual-stage mixing strategy for improved generalization: We combine unsupervised pre-training with supervised fine-tuning using Mixup and Manifold Mixup, demonstrating robust performance across diverse WSI datasets and label budgets.
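As an illustration of the slide-mixing idea, the sketch below combines intra-batch feature interpolation with a Barlow Twins-style cross-correlation loss on pre-extracted slide-level feature vectors. The tensor shapes, projector size, and mixing details are assumptions made for illustration and do not reproduce the repository's exact implementation.

import torch
import torch.nn as nn


def barlow_twins_loss(z1, z2, lambd=5e-3):
    """Barlow Twins objective: push the cross-correlation of the two
    embeddings toward the identity matrix."""
    n, d = z1.shape
    z1 = (z1 - z1.mean(0)) / (z1.std(0) + 1e-6)
    z2 = (z2 - z2.mean(0)) / (z2.std(0) + 1e-6)
    c = (z1.T @ z2) / n                      # d x d cross-correlation
    on_diag = (torch.diagonal(c) - 1).pow(2).sum()
    off_diag = (c - torch.diag(torch.diagonal(c))).pow(2).sum()
    return on_diag + lambd * off_diag


def slide_mixing(features, alpha=1.0):
    """Create an extra positive view by interpolating each slide's features
    with another slide from the same batch (illustrative only)."""
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    perm = torch.randperm(features.size(0))
    return lam * features + (1 - lam) * features[perm]


# Hypothetical usage with pre-extracted slide-level features (batch x dim).
feat = torch.randn(8, 192)                   # toy slide embeddings (dimension chosen arbitrarily)
projector = nn.Sequential(nn.Linear(192, 512), nn.ReLU(), nn.Linear(512, 512))
view1, view2 = slide_mixing(feat), slide_mixing(feat)
loss = barlow_twins_loss(projector(view1), projector(view2))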
To assess the robustness of both the original MIL framework and the proposed PreMix framework, we conducted experiments under random sampling (traditional fully supervised fine-tuning) and active learning settings with a limited budget of labeled WSIs.
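For context, the snippet below sketches the two acquisition settings in their simplest form: random sampling versus an uncertainty-based active learning query under a fixed labeling budget. It is only a schematic of the experimental protocol; the function name and the scoring rule are illustrative, not the repository's API.

import numpy as np

def select_slides(pool_probs, budget, strategy="random", rng=None):
    """Pick `budget` unlabeled slides to annotate next.

    pool_probs: (num_slides, num_classes) predicted probabilities on the pool.
    strategy:   "random" mirrors the fully supervised random-sampling baseline;
                "entropy" is a simple uncertainty-based AL query.
    """
    rng = rng or np.random.default_rng(0)
    if strategy == "random":
        return rng.choice(len(pool_probs), size=budget, replace=False)
    # Entropy-based uncertainty: higher entropy means a less confident prediction.
    entropy = -(pool_probs * np.log(pool_probs + 1e-12)).sum(axis=1)
    return np.argsort(entropy)[-budget:]

# Hypothetical usage: 10 slides labeled per round out of a 200-slide pool.
probs = np.random.dirichlet([1, 1], size=200)
queried = select_slides(probs, budget=10, strategy="entropy")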
- Windows 10 Enterprise
- 2 NVIDIA RTX 2080 Ti GPUs (11GB each)
- CUDA version: 11.7
- Python version: 3.8.16
Install Anaconda
Create a new environment and activate it
conda create --name premix python=3.8.16
conda activate premix
Install all required packages
pip install -r requirements.txt
pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117
Extract square regions from each WSI to create a pre-training dataset using HS2P, which tiles tissue and extracts relevant regions at a given pixel spacing (an illustrative tiling sketch follows the folder structure below)
The results from HS2P should be structured as follows:
Folder structure
<data_dir>/
├── pretrain/
│   ├── hs2p_20x_4096/
│   │   ├── debug/
│   │   │   ├── patches/
│   │   │   │   ├── slide_1/
│   │   │   │   │   ├── 4096/
│   │   │   │   │   │   ├── jpg/
│   │   │   │   │   │   │   ├── slide_1_x1_y1.png
│   │   │   │   │   │   │   ├── slide_1_x2_y2.png
│   │   │   │   │   │   │   ├── ...
│   │   │   │   ├── slide_2/
│   │   │   │   │   ├── 4096/
│   │   │   │   │   │   ├── jpg/
│   │   │   │   │   │   │   ├── slide_2_x1_y1.png
│   │   │   │   │   │   │   ├── slide_2_x2_y2.png
│   │   │   │   │   │   │   ├── ...
│   │   │   │   ├── ...
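For intuition, the following is a rough sketch of the kind of region extraction HS2P performs: reading fixed-size tiles from a WSI with OpenSlide and keeping only those with enough tissue. It is not HS2P's actual code; the paths, tile size, magnification level, and tissue threshold are placeholders.

import numpy as np
import openslide
from pathlib import Path

slide_path = Path("slides/slide_1.svs")        # placeholder input slide
out_dir = Path("pretrain/hs2p_20x_4096/debug/patches/slide_1/4096/jpg")
out_dir.mkdir(parents=True, exist_ok=True)

wsi = openslide.OpenSlide(str(slide_path))
width, height = wsi.dimensions
tile = 4096                                    # region size, assumed to correspond to ~20x

for x in range(0, width - tile + 1, tile):
    for y in range(0, height - tile + 1, tile):
        region = wsi.read_region((x, y), 0, (tile, tile)).convert("RGB")
        arr = np.asarray(region)
        # Crude tissue check: keep regions that are not mostly white background.
        if (arr.mean(axis=2) < 220).mean() > 0.1:
            region.save(out_dir / f"slide_1_{x}_{y}.png")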
Download HIPT pre-trained weights using the following commands:
Download commands
mkdir checkpoints
cd checkpoints
gdown 1Qm-_XrTMYhu9Hl-4FClaOMuroyWlOAxw
gdown 1A2eHTT0dedHgdCvy6t3d9HwluF8p5yjz
Create a configuration file under config/feature_extraction/ inspired by existing files
To extract region-level features, set level: 'global' in the config (refer to config/feature_extraction/global.yaml)
Ensure that slides_list.txt lists all slide names, one per line:
slide_1
slide_2
...
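If helpful, slides_list.txt can be generated directly from the slide directory; the snippet below assumes .svs files in a flat folder (adjust the extension and paths to your data).

from pathlib import Path

slide_dir = Path("slides")                    # placeholder directory of WSIs
names = sorted(p.stem for p in slide_dir.glob("*.svs"))

with open("slides_list.txt", "w") as f:
    f.write("\n".join(names) + "\n")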
Run the following command to initiate feature extraction:
python extract_features.py --config-name global
The results should be structured as follows:
Folder structure
outputs/
├── pretrain/
│   ├── features/
│   │   ├── hipt/
│   │   │   ├── global/
│   │   │   │   ├── region/
│   │   │   │   │   ├── slide_1_x1_y1.pt
│   │   │   │   │   ├── slide_2_x2_y2.pt
│   │   │   │   │   ├── ...
│   │   │   │   ├── slide/
│   │   │   │   │   ├── slide_1.pt
│   │   │   │   │   ├── slide_2.pt
│   │   │   │   │   ├── ...
1. Prepare a csv file inside data/pretrain/ (refer to data/pretrain/camelyon16_cptac_ucec.csv)
This csv file lists all the slides used for pre-training (a minimal loading sketch follows the example):
slide_id
slide_1
slide_2
...
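To make the data flow concrete, the sketch below shows one way the pre-training csv could be paired with the slide-level features extracted earlier (outputs/pretrain/features/hipt/global/slide/<slide_id>.pt). The Dataset class is illustrative and not the repository's own loader.

import pandas as pd
import torch
from pathlib import Path
from torch.utils.data import Dataset


class SlideFeatureDataset(Dataset):
    """Loads one pre-extracted slide-level feature tensor per csv row."""

    def __init__(self, csv_path, feature_dir):
        self.slide_ids = pd.read_csv(csv_path)["slide_id"].tolist()
        self.feature_dir = Path(feature_dir)

    def __len__(self):
        return len(self.slide_ids)

    def __getitem__(self, idx):
        slide_id = self.slide_ids[idx]
        return torch.load(self.feature_dir / f"{slide_id}.pt"), slide_id


# Hypothetical usage with the paths produced in the previous steps.
dataset = SlideFeatureDataset(
    "data/pretrain/camelyon16_cptac_ucec.csv",
    "outputs/pretrain/features/hipt/global/slide",
)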
2. Create a configuration file under config/training/ inspired by existing files
Refer to config/training/pretrain.yaml for inspiration
3. Pre-train with Barlow Twins Slide Mixing
Run the following command to initiate Barlow Twins Slide Mixing:
python barlow_twins_slide_mixing.py --config-name pretrain
The results should be structured as follows:
Folder structure
outputs/
├── pretrain/
│   ├── checkpoints/
│   │   ├── global/
│   │   │   ├── <model_name>_<epoch>.pth
│   │   │   ├── <model_name>_<epoch>.pth
│   │   │   ├── ...
1. Prepare pool and test csv files for downstream classification
Refer to data/camelyon16/pool.csv and data/camelyon16/test.csv for inspiration
These two csv files list all the slides for downstream classification
slide_id,label
slide_1,0
slide_2,1
...
2. Create a configuration file under config/training/ inspired by existing files
Refer to config/training/global.yaml for inspiration
Note that the <model_name> in the config file should be the full name <model_name>_<epoch>
Make sure to include the following to enable slide mixing strategies during fine-tuning; comment them out if they are not needed (a feature-level sketch of these options follows the snippet):
mixing:
  mixup: True
  manifold_mixup: True
  manifold_mixup_transformer: True
  mixup_alpha: 1
  mixup_alpha_per_sample: False
  mixup_type: random # [random, cosine_sim, class_aware]
...
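As a reference for what these options control, the sketch below applies Mixup at the input-feature level and Manifold Mixup at a hidden layer of a toy aggregator. The alpha value, layer choice, and model are illustrative assumptions, not the repository's implementation (which also supports cosine-similarity and class-aware pairing).

import numpy as np
import torch
import torch.nn as nn


def mixup(x, y_onehot, alpha=1.0):
    """Interpolate features and one-hot labels with a random partner
    from the same batch (mixup_type: random)."""
    lam = np.random.beta(alpha, alpha)
    perm = torch.randperm(x.size(0))
    return lam * x + (1 - lam) * x[perm], lam * y_onehot + (1 - lam) * y_onehot[perm]


class ToyAggregator(nn.Module):
    def __init__(self, dim=192, n_classes=2):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(dim, 128), nn.ReLU())
        self.head = nn.Linear(128, n_classes)

    def forward(self, x, y_onehot, manifold_mixup=False, alpha=1.0):
        if manifold_mixup:
            # Manifold Mixup: mix in the hidden representation instead of the input.
            h = self.encoder(x)
            h, y_onehot = mixup(h, y_onehot, alpha)
            return self.head(h), y_onehot
        x, y_onehot = mixup(x, y_onehot, alpha)       # plain Mixup on slide features
        return self.head(self.encoder(x)), y_onehot


# Hypothetical usage: soft cross-entropy on the mixed targets.
model = ToyAggregator()
feats, labels = torch.randn(8, 192), torch.eye(2)[torch.randint(0, 2, (8,))]
logits, soft_targets = model(feats, labels, manifold_mixup=True)
loss = -(soft_targets * torch.log_softmax(logits, dim=1)).sum(1).mean()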
3. Fine-tune the MIL aggregator with Mixup and Manifold Mixup
Run the following command to initiate the fine-tuning process:
python main.py --config-name global
The results should be structured as follows:
Folder structure
outputs/
├── <downstream_dataset>/
│   ├── checkpoints/
│   │   ├── <all_settings_folder>/
│   │   │   ├── <AL_strategy>/
│   │   │   │   ├── best_model.pth
│   ├── results/
│   │   ├── <all_settings_folder>/
│   │   │   ├── <AL_strategy>/
│   │   │   │   ├── train_0.csv
│   │   │   │   ├── train_1.csv
│   │   │   │   ├── ...
│   │   │   │   ├── test.csv
│   │   │   │   ├── test_results.csv
│   │   │   │   ├── roc_auc_curve.png
│   ├── scripts/
│   │   ├── <all_settings_folder>/
│   │   │   ├── log.txt
best_model.pth is the best-performing model for the specified settings and AL strategy
train_0.csv, train_1.csv, etc., contain the predicted probabilities over classes
log.txt contains all of the model's performance metrics (ACC, AUC, Precision, Recall) across all AL strategies and training label budgets
This codebase builds upon HIPT and Re-Implementation HIPT



