Official implementation of the paper "Soft Prompt Generation for Domain Generalization".
Authors: Shuanghao Bai*, Yuedi Zhang*, Wanqi Zhou, Zhirong Luan, Badong Chen.
Abstract: Large pre-trained vision language models (VLMs) have shown impressive zero-shot ability on downstream tasks with manually designed prompts. To further adapt VLMs to downstream tasks, soft prompts have been proposed to replace manually designed prompts; they are fine-tuned on data from specific domains. Prior prompt learning methods primarily learn a fixed prompt or a residual prompt from training samples. However, the learned prompts lack diversity and ignore information about unseen domains. In this paper, we reframe the prompt learning framework from a generative perspective and propose a simple yet efficient method for the Domain Generalization (DG) task, namely Soft Prompt Generation (SPG). Specifically, SPG consists of a two-stage training phase and an inference phase. During the training phase, we introduce a soft prompt label for each domain, aiming to incorporate domain knowledge into the generative model. During the inference phase, the generator of the generative model is employed to obtain instance-specific soft prompts for the unseen target domain. Extensive experiments on five domain generalization benchmarks across three DG tasks demonstrate that SPG achieves state-of-the-art performance.
Main Contributions
- To the best of our knowledge, we are the first to introduce the generative model into prompt learning in VLMs. Then, we propose a new paradigm of prompt tuning, namely Soft Prompt Generation (SPG).
- We design a two-stage training phase to align the generative model with domain prompt labels. It incorporates domain knowledge into the generated prompts, enhancing the transferability across unseen domains.
- Extensive experiments on five datasets for three DG tasks demonstrate that the proposed SPG achieves state-of-the-art performance.
For installation and package requirements, follow the steps below to create the environment and install the dependencies. This codebase is tested on Ubuntu 20.04 LTS with Python 3.8.
- Setup conda environment.
# Create a conda environment
conda create -y -n spg python=3.8
# Activate the environment
conda activate spg
# Install torch (requires version >= 1.8.1) and torchvision
# Please refer to https://pytorch.org/get-started/previous-versions/ if your cuda version is different
conda install pytorch==2.0.0 torchvision==0.15.0 torchaudio==2.0.0 pytorch-cuda=11.8 -c pytorch -c nvidia
- Install dassl library.
# Instructions borrowed from https://github.com/KaiyangZhou/Dassl.pytorch#installation
# Clone this repo
git clone https://github.com/KaiyangZhou/Dassl.pytorch.git
cd Dassl.pytorch
# Install dependencies
pip install -r requirements.txt
# Install this library (no need to re-build if the source code is modified)
python setup.py develop
cd ..
- Clone SPG code repository and install requirements.
# Clone SPG code base
git clone https://github.com/renytek13/Soft-Prompt-Generation.git
cd Soft-Prompt-Generation
# Install requirements
pip install -r requirements.txt
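After the steps above, a quick import check can confirm that the environment is usable. The following is a minimal sketch, not part of the repository: check_module is a hypothetical helper, and it assumes the spg environment is active so that python resolves to its interpreter (set PYTHON_BIN to override).

```shell
# Hypothetical helper: report whether a Python module can be imported.
# PYTHON_BIN is an assumed override; it defaults to "python".
check_module() {
    if "${PYTHON_BIN:-python}" -c "import $1" >/dev/null 2>&1; then
        echo "$1: ok"
    else
        echo "$1: MISSING"
    fi
}

# Packages installed by the steps above.
for mod in torch torchvision dassl; do
    check_module "$mod"
done
```

Any line ending in MISSING points to the installation step that should be repeated.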
Please download the datasets PACS, VLCS, Office-Home, TerraIncognita, and DomainNet.
Follow DATASETS.md to install the datasets.
We provide the running scripts in the scripts folder, which allow you to reproduce the results in the paper.
Make sure you modify the path in $DATA!
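One way to catch a wrong $DATA path before launching a script is to check that each dataset folder exists. The sketch below is only illustrative: missing_datasets is a hypothetical helper, and the folder names passed to it are assumptions; DATASETS.md remains the reference for the actual layout.

```shell
# Hypothetical helper: print every dataset folder that is absent under a root.
# Usage: missing_datasets ROOT NAME [NAME ...]
missing_datasets() {
    root="$1"
    shift
    for name in "$@"; do
        [ -d "$root/$name" ] || echo "$name"
    done
}

# Placeholder path and assumed folder names; adjust both to your setup.
DATA="${DATA:-/path/to/datasets}"
missing_datasets "$DATA" pacs vlcs office_home terra_incognita domainnet
```

An empty output means every listed folder was found under $DATA.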
We provide our dataset split files spg_coop_splits and the trained domain prompt labels.
If you want to use our data splits and domain prompt labels, put the split files into the dataset directory (see DATASETS.md) and put the domain prompt labels into the domain prompt labels folder. Then go to Training Stage II: Generative Model Pre-training.
If you want to produce the data splits and domain prompt labels yourself, follow the instructions below.
To obtain data splits and domain prompt labels, please run the bash file in the scripts folder as follows.
# Example: trains on PACS dataset with ResNet50 as the backbone, and the gpu id is 0.
bash scripts/spg_coop/spg_coop.sh pacs RN50 0
Please refer to DATASETS.md and make sure that the data splits are in your data path. The scripts folder contains bash files for the three types of DG tasks.
For multi-source Domain Generalization
# Example: trains on PACS dataset with ResNet50 as the backbone, and the gpu id is 0.
bash scripts/spg_cgan/spg_cgan.sh pacs spg RN50 0
For single-source Domain Generalization
# Example: trains on VLCS dataset with ResNet50 as the backbone, and the gpu id is 1.
bash scripts/spg_cgan/single.sh vlcs spg RN50 1
For cross-dataset Domain Generalization
# Example: trains on DomainNet dataset with ViT-B/16 as the backbone, and the gpu id is 2.
bash scripts/spg_cgan/cross.sh spg ViT-B/16 2
To evaluate a trained model for multi-source Domain Generalization
# Example: test PACS dataset with ResNet50 as the backbone, and the gpu id is 0.
bash scripts/test_all.sh pacs spg RN50 0
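The multi-source commands shown above can be chained for a single run. The sketch below only prints the three commands (Stage I, Stage II, evaluation) in order so they can be reviewed first. spg_pipeline is a hypothetical helper, not a script shipped with this repository, and the argument order is copied from the examples above.

```shell
# Hypothetical helper: print the multi-source commands for one dataset.
# Usage: spg_pipeline DATASET BACKBONE GPU_ID
spg_pipeline() {
    dataset="$1"
    backbone="$2"
    gpu="$3"
    # Stage I: data splits and domain prompt labels
    echo "bash scripts/spg_coop/spg_coop.sh $dataset $backbone $gpu"
    # Stage II: generative model pre-training
    echo "bash scripts/spg_cgan/spg_cgan.sh $dataset spg $backbone $gpu"
    # Evaluation
    echo "bash scripts/test_all.sh $dataset spg $backbone $gpu"
}

spg_pipeline pacs RN50 0
```

Run the printed lines one at a time from the repository root, since each stage depends on the output of the previous one.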
Supported methods in this codebase are as follows:
Method | Paper | Code |
---|---|---|
CoOp | IJCV 2022 | link |
CoCoOp | CVPR 2022 | link |
VP | - | link |
VPT | ECCV 2022 | link |
MaPLe | CVPR 2023 | link |
DPL | TJSAI 2023 | link |
Also, for our SPG method, we provide our pre-trained models on five DG datasets and you can directly evaluate on those models.
If our code is helpful to your research or projects, please consider citing our work!
@inproceedings{bai2024soft,
title={Soft Prompt Generation for Domain Generalization},
author={Bai, Shuanghao and Zhang, Yuedi and Zhou, Wanqi and Luan, Zhirong and Chen, Badong},
booktitle={European Conference on Computer Vision},
year={2024}
}
If you have any questions, please create an issue on this repository or contact us at zyd993@stu.xjtu.edu.cn or baishuanghao@stu.xjtu.edu.cn.
Our code is based on the CoOp and CoCoOp, MaPLe, and PDA repositories. We thank the authors for releasing their code. If you use their code, please consider citing these works as well.