This is the official implementation of SLM-SAM 2 (IEEE TMI).
By Yuwen Chen, Zafer Yildiz, Qihang Li, Yaqian Chen, Haoyu Dong, Hanxue Gu, Nicholas Konz, Maciej A. Mazurowski
🎉 Our paper has been accepted to IEEE Transactions on Medical Imaging (IEEE TMI)!
SLM-SAM 2 is a novel video object segmentation method that accelerates volumetric medical image annotation by propagating annotations from a single slice to the remaining slices within a volume. By introducing a dynamic short-long memory module, SLM-SAM 2 achieves better segmentation performance than SAM 2 on organs, bones and muscles across different imaging modalities (MRI, CT, ultrasound videos).
First, install the PyTorch and TorchVision dependencies following their official installation instructions. SLM-SAM 2 can then be installed using:
cd SLM-SAM 2
pip install -e .
Before finetuning, download the SAM 2 pretrained checkpoints using the following commands:
cd checkpoints && \
./download_ckpts.sh && \
cd ..
Open ./sam2/configs/sam2.1_training/slm_sam2_hiera_t_finetune.yaml and add the paths to the image folder, the mask folder, and the text file describing the volumes used for training. The dataset format is the same as that of SAM 2:
DATA_DIRECTORY
├── images
│ ├── volume1
│ │ ├── 00000.jpg
│ │ ├── 00001.jpg
│ │ └── ...
│ └── ...
├── masks
│ ├── volume1
│ │ ├── 00000.png
│ │ ├── 00001.png
│ │ └── ...
│ └── ...
├── train.txt
├── test.txt
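Before launching a finetuning run, it can help to confirm that the folder layout above is consistent. The following is a minimal sketch, not part of the repository: the helper name `check_dataset` is hypothetical, and it assumes that the split file lists one volume name per line and that every image slice is expected to have a mask with the same file stem. Adjust it if your data deviates from these assumptions.

```python
from pathlib import Path


def check_dataset(data_dir, split_file="train.txt"):
    """Return a list of human-readable problems found in the dataset layout.

    Hypothetical helper. Assumes one volume name per line in the split file
    and a matching NNNNN.png mask for every NNNNN.jpg image.
    """
    root = Path(data_dir)
    problems = []
    lines = (root / split_file).read_text().splitlines()
    volumes = [ln.strip() for ln in lines if ln.strip()]
    for vol in volumes:
        img_dir = root / "images" / vol
        mask_dir = root / "masks" / vol
        if not img_dir.is_dir():
            problems.append(f"{vol}: missing image folder")
            continue
        if not mask_dir.is_dir():
            problems.append(f"{vol}: missing mask folder")
            continue
        # Compare slices by file stem, e.g. "00000" for 00000.jpg / 00000.png.
        img_ids = {p.stem for p in img_dir.glob("*.jpg")}
        mask_ids = {p.stem for p in mask_dir.glob("*.png")}
        for stem in sorted(img_ids - mask_ids):
            problems.append(f"{vol}: slice {stem} has no matching mask")
        for stem in sorted(mask_ids - img_ids):
            problems.append(f"{vol}: mask {stem} has no matching image")
    return problems
```

Calling `check_dataset("DATA_DIRECTORY")` returns an empty list when every listed volume has paired image and mask slices; otherwise each entry names the volume and the slice that needs attention.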
Start finetuning by running:
CUDA_VISIBLE_DEVICES=[GPU_ID] python3 training/train.py \
-c configs/sam2.1_training/slm_sam2_hiera_t_finetune.yaml \
--use-cluster 0 \
--num-gpus 1
Propagate annotations by running:
CUDA_VISIBLE_DEVICES=[GPU_ID] python3 inference.py \
--test_img_folder [test image folder path] \
--test_mask_folder [test mask folder path] \
--checkpoint_folder [checkpoint path] \
--checkpoint_name [checkpoint file name] \
--cfg_name slm_sam2_hiera_t.yaml \
--test_txt_file [test text file path] \
--mask_prompt_dict [path to mask prompt dictionary] \
--output_folder [path of output folder, to save predictions]
- checkpoint_folder: directory that contains the .pt file
- checkpoint_name: name of the .pt file
- mask_prompt_dict: dictionary mapping each volume ID to the slice index used as the mask prompt (e.g., mask_prompt_dict[volume_id] = slice_index)
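To illustrate the mapping that mask_prompt_dict describes, here is a minimal sketch that builds such a dictionary in memory. It is not taken from the repository: the helper name `build_mask_prompt_dict` is hypothetical, picking the middle annotated slice is only an illustrative heuristic, and the sketch assumes the test text file lists one volume ID per line and that mask file names encode the slice index (for example 00012.png). The on-disk format that inference.py expects for this dictionary is not specified here, so check the script before saving the result.

```python
from pathlib import Path


def build_mask_prompt_dict(mask_folder, test_txt_file):
    """Map each volume ID to the index of its middle mask slice.

    Hypothetical helper. The middle-slice choice is an illustrative
    heuristic, not a selection rule prescribed by SLM-SAM 2.
    """
    lines = Path(test_txt_file).read_text().splitlines()
    volume_ids = [ln.strip() for ln in lines if ln.strip()]
    prompt_dict = {}
    for vol in volume_ids:
        # Slice indices are parsed from file names such as 00012.png.
        indices = sorted(
            int(p.stem) for p in (Path(mask_folder) / vol).glob("*.png")
        )
        if not indices:
            raise FileNotFoundError(f"no mask slices found for volume {vol}")
        prompt_dict[vol] = indices[len(indices) // 2]
    return prompt_dict
```

In practice the prompt slice is usually the one a human annotator has labeled, so the value for each volume can simply be set by hand, as in mask_prompt_dict[volume_id] = slice_index.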
All code in this repository is released under the GPLv3 license.


