SAM-RSP: A New Few-Shot Segmentation Method Based on Segment Anything Model and Rough Segmentation Prompts
Abstract: Few-shot segmentation (FSS) aims to segment novel classes using only a few labeled images. The backbones used in existing methods are pre-trained on classification tasks on the ImageNet dataset. Although these backbones can effectively perceive the semantic categories of images, they cannot accurately perceive the regional boundaries within an image, which limits model performance. Recently, the Segment Anything Model (SAM) has achieved precise image segmentation based on point or box prompts, thanks to its excellent perception of region boundaries within an image. However, it cannot effectively provide semantic information about images. This paper proposes a new few-shot segmentation method that can effectively perceive both semantic categories and regional boundaries. The method first uses the SAM encoder to perceive regions and obtain the query embedding. Then the support and query images are fed into a backbone pre-trained on ImageNet to perceive semantics and generate a rough segmentation prompt (RSP). The query embedding is combined with the prompt to generate a pixel-level query prototype, which better matches the query embedding. Finally, the query embedding, prompt, and prototype are combined and fed into the designed multi-layer prompt transformer decoder, which is more efficient and lightweight and provides a more accurate segmentation result. In addition, other methods can be easily combined with our framework to improve their performance. Extensive experiments on PASCAL-5i and COCO-20i under 1-shot and 5-shot settings demonstrate the effectiveness of our method, which also achieves new state-of-the-art results.
- RTX 3090
- Python 3.8
- PyTorch 1.12.0
- CUDA 11.6
- torchvision 0.13.0
- tensorboardX 2.2
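
As a quick sanity check, the Python package versions listed above can be compared against the active environment. The following is a minimal sketch, not part of this repo: the helper names `strip_local` and `check_versions` are hypothetical, and it assumes the packages were installed under the distribution names `torch`, `torchvision`, and `tensorboardX`. It does not check the GPU, the Python version, or the CUDA toolkit.

```python
from importlib import metadata

# Versions listed in this README, keyed by (assumed) distribution name.
EXPECTED = {
    "torch": "1.12.0",
    "torchvision": "0.13.0",
    "tensorboardX": "2.2",
}


def strip_local(version):
    """Drop a local build suffix such as '+cu116' from a version string."""
    return version.split("+", 1)[0]


def check_versions(expected, lookup=metadata.version):
    """Return {package: (wanted, found)} for each missing or mismatched package.

    `found` is None when the package is not installed.
    """
    problems = {}
    for name, wanted in expected.items():
        try:
            found = strip_local(lookup(name))
        except metadata.PackageNotFoundError:
            found = None
        if found != wanted:
            problems[name] = (wanted, found)
    return problems


if __name__ == "__main__":
    for name, (wanted, found) in check_versions(EXPECTED).items():
        print(f"{name}: expected {wanted}, found {found or 'not installed'}")
```

The comparison is an exact string match, which is deliberately strict; a nearby patch release may work just as well, but that is not something this README states.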
- COCO-20i: COCO2014
- Put the datasets into the `data/` directory.
- Run `util/get_mulway_base_data.py` to generate base annotations and put them into the `data/base_annotation/` directory. (Only used when the coarse segmentation prompt generator is BAM.)
- Download the pre-trained VGG16 and ResNet50 encoders and put them into the `initmodel/` directory.
- Download the pre-trained SAM encoder and put it into the `initmodel/SAM_encoder/` directory.
- Download the pre-trained BAM models and put them into the `initmodel/BAM_models/` directory. (Only used when the coarse segmentation prompt generator is BAM.)
- Download the pre-trained base learners from BAM and put them under `initmodel/PSPNet/`. (Only used when the coarse segmentation prompt generator is BAM.)
- Change the configuration and add the weight paths in the `.yaml` files in the `config` directory, then run `train.sh` for training or `test.sh` for testing.
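
Before launching `train.sh` or `test.sh`, it may help to confirm that the directories named in the steps above exist. The sketch below is illustrative only and not part of this repo: the function name `missing_dirs` and the `use_bam` flag are hypothetical. It checks directories only, since this README does not state the expected file names of the downloaded weights.

```python
from pathlib import Path

# Directories mentioned in this README that are always needed.
REQUIRED = ["data", "initmodel", "initmodel/SAM_encoder", "config"]

# Directories only needed when the coarse segmentation prompt generator is BAM.
BAM_ONLY = ["data/base_annotation", "initmodel/BAM_models", "initmodel/PSPNet"]


def missing_dirs(root, use_bam=False):
    """Return the expected directories that do not exist under `root`."""
    wanted = REQUIRED + (BAM_ONLY if use_bam else [])
    return [d for d in wanted if not (Path(root) / d).is_dir()]


if __name__ == "__main__":
    # Run from the repository root; report anything that is still missing.
    for d in missing_dirs(".", use_bam=True):
        print(f"missing: {d}/")
```

Running it from the repository root prints one line per missing directory and prints nothing when the layout is complete.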
Performance comparison with state-of-the-art approaches in terms of average mIoU across all folds.
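| Backbone | Method | 1-shot | 5-shot |
| --- | --- | --- | --- |
| VGG16 | MIANet | 67.10 | 71.99 |
| VGG16 | SAM-RSP (ours) | 69.29 (+2.19) | 73.86 (+1.87) |
| ResNet50 | HDMNet | 69.40 | 71.80 |
| ResNet50 | SAM-RSP (ours) | 70.76 (+1.36) | 74.15 (+2.35) |

| Backbone | Method | 1-shot | 5-shot |
| --- | --- | --- | --- |
| VGG16 | HDMNet | 45.90 | 52.40 |
| VGG16 | SAM-RSP (ours) | 48.79 (+2.89) | 54.15 (+1.75) |
| ResNet50 | MIANet | 47.66 | 51.65 |
| ResNet50 | SAM-RSP (ours) | 49.84 (+2.18) | 55.38 (+3.73) |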
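
For readers unfamiliar with the metric, the sketch below shows one common way an "average mIoU across all folds" figure is computed in few-shot segmentation: intersection and union pixel counts are accumulated per class over all test episodes of a fold, the per-class IoUs are averaged into the fold's mIoU, and the fold mIoUs are then averaged. This is an assumption about the usual protocol, not a copy of this repo's evaluation code, and all function names are hypothetical.

```python
def intersection_union(pred, target):
    """Intersection and union pixel counts for two flat binary masks (0/1 lists)."""
    inter = sum(1 for p, t in zip(pred, target) if p and t)
    union = sum(1 for p, t in zip(pred, target) if p or t)
    return inter, union


def fold_miou(episodes):
    """Mean IoU over classes for one fold.

    `episodes` is an iterable of (class_id, pred_mask, target_mask) tuples.
    Counts are summed per class across episodes before dividing.
    """
    totals = {}
    for cls, pred, target in episodes:
        inter, union = intersection_union(pred, target)
        acc = totals.setdefault(cls, [0, 0])
        acc[0] += inter
        acc[1] += union
    ious = [inter / union for inter, union in totals.values() if union > 0]
    if not ious:
        raise ValueError("no class with a non-empty union")
    return sum(ious) / len(ious)


def average_over_folds(fold_scores):
    """Plain arithmetic mean of the per-fold mIoU values."""
    return sum(fold_scores) / len(fold_scores)


if __name__ == "__main__":
    toy_fold = [
        (1, [1, 1, 0, 0], [1, 0, 0, 0]),  # class 1: inter 1, union 2
        (1, [0, 1, 1, 0], [0, 1, 1, 0]),  # class 1: inter 2, union 2
        (2, [1, 0, 0, 0], [1, 1, 0, 0]),  # class 2: inter 1, union 2
    ]
    print(fold_miou(toy_fold))  # class 1: 3/4, class 2: 1/2, mean 0.625
```

The toy masks are made up purely to exercise the functions and are unrelated to the numbers in the tables.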
This repo is mainly built on SAM and BAM. Thanks for their great work!
This paper has been accepted by the Image and Vision Computing journal.