[ECCV 2024] Soft Prompt Generation for Domain Generalization

renytek13/Soft-Prompt-Generation

Soft Prompt Generation [ECCV 2024]

Official implementation of the paper "Soft Prompt Generation for Domain Generalization".

Authors: Shuanghao Bai*, Yuedi Zhang*, Wanqi Zhou, Zhirong Luan, Badong Chen.


🎉 Highlights

Abstract: Large pre-trained vision language models (VLMs) have shown impressive zero-shot ability on downstream tasks with manually designed prompts. To further adapt VLMs to downstream tasks, soft prompts have been proposed to replace manually designed prompts; they are fine-tuned on domain-specific data. Prior prompt learning methods primarily learn a fixed prompt or a residual prompt from training samples. However, the learned prompts lack diversity and ignore information about unseen domains. In this paper, we reframe the prompt learning framework from a generative perspective and propose a simple yet efficient method for the Domain Generalization (DG) task, namely Soft Prompt Generation (SPG). Specifically, SPG consists of a two-stage training phase and an inference phase. During the training phase, we introduce a soft prompt label for each domain, aiming to incorporate domain knowledge into the generative model. During the inference phase, the generator of the generative model is employed to obtain instance-specific soft prompts for the unseen target domain. Extensive experiments on five domain generalization benchmarks across three DG tasks demonstrate that SPG achieves state-of-the-art performance.

Main Contributions
  1. To the best of our knowledge, we are the first to introduce the generative model into prompt learning in VLMs. Then, we propose a new paradigm of prompt tuning, namely Soft Prompt Generation (SPG).
  2. We design a two-stage training phase to align the generative model with domain prompt labels. It incorporates domain knowledge into the generated prompts, enhancing the transferability across unseen domains.
  3. Extensive experiments on five datasets for three DG tasks demonstrate that the proposed SPG achieves state-of-the-art performance.
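The inference idea described above (a generator that turns each image into its own soft prompt) can be illustrated with a toy sketch. This is not the repository's implementation: the linear "generator", the feature sizes, the helper names (`make_generator`, `generate_prompt`, `classify`), and the use of simple addition in place of the CLIP text encoder are all assumptions made for illustration only.

```python
import math
import random

DIM = 4        # toy feature size (assumption; real CLIP features are much larger)
NOISE_DIM = 2  # toy noise size (assumption)


def make_generator(seed=0):
    """Return random weights standing in for a trained generator (hypothetical)."""
    rng = random.Random(seed)
    return [[rng.uniform(-1, 1) for _ in range(DIM + NOISE_DIM)] for _ in range(DIM)]


def generate_prompt(weights, image_feature, noise):
    """Map (image feature, noise) to one prompt vector with a single linear layer."""
    conditioned = list(image_feature) + list(noise)
    return [sum(w * x for w, x in zip(row, conditioned)) for row in weights]


def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def classify(weights, image_feature, class_embeddings, noise):
    """Generate a prompt for this image, then pick the most similar class."""
    prompt = generate_prompt(weights, image_feature, noise)
    scores = {}
    for name, embedding in class_embeddings.items():
        # Stand-in for the text encoder: add the prompt to the class embedding.
        text_feature = [p + e for p, e in zip(prompt, embedding)]
        scores[name] = cosine(image_feature, text_feature)
    return max(scores, key=scores.get), prompt


if __name__ == "__main__":
    weights = make_generator()
    classes = {"dog": [1.0, 0.0, 0.0, 0.0], "house": [0.0, 1.0, 0.0, 0.0]}
    label, prompt = classify(weights, [0.9, 0.1, 0.0, 0.2], classes, noise=[0.0, 0.0])
    print(label, prompt)
```

The point of the sketch is only the data flow: the prompt depends on the input image, so two different images receive two different prompts, in contrast to a single fixed learned prompt.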

🛠️ Installation

This codebase is tested on Ubuntu 20.04 LTS with Python 3.8. Follow the steps below to create the environment and install the dependencies.

  • Set up the conda environment.
# Create a conda environment
conda create -y -n spg python=3.8

# Activate the environment
conda activate spg

# Install torch (requires version >= 1.8.1) and torchvision
# Please refer to https://pytorch.org/get-started/previous-versions/ if your CUDA version is different
conda install pytorch==2.0.0 torchvision==0.15.0 torchaudio==2.0.0 pytorch-cuda=11.8 -c pytorch -c nvidia
  • Install the Dassl library.
# Instructions borrowed from https://github.com/KaiyangZhou/Dassl.pytorch#installation

# Clone this repo
git clone https://github.com/KaiyangZhou/Dassl.pytorch.git
cd Dassl.pytorch

# Install dependencies
pip install -r requirements.txt

# Install this library (no need to re-build if the source code is modified)
python setup.py develop
cd ..
  • Clone the SPG code repository and install the requirements.
# Clone SPG code base
git clone https://github.com/renytek13/Soft-Prompt-Generation.git
cd Soft-Prompt-Generation

# Install requirements
pip install -r requirements.txt
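
After the steps above, a quick sanity check of the environment can save a failed training run. The sketch below compares the installed torch version against the minimum stated above (1.8.1). The helper names `parse_version` and `meets_minimum` are hypothetical and not part of this repository.

```python
import re


def parse_version(text):
    """Turn a string such as '2.0.0+cu118' into (2, 0, 0), dropping the build suffix."""
    numbers = re.findall(r"\d+", text.split("+")[0])
    return tuple(int(n) for n in numbers[:3])


def meets_minimum(installed, minimum="1.8.1"):
    """Compare versions numerically, so that 1.10.0 counts as newer than 1.8.1."""
    return parse_version(installed) >= parse_version(minimum)


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        print("torch is not installed in this environment")
    else:
        print("torch", torch.__version__, "ok:", meets_minimum(torch.__version__))
        print("CUDA available:", torch.cuda.is_available())
```

Run it inside the activated `spg` environment. If CUDA is reported as unavailable, the installed build probably does not match the local driver, and the previous-versions page linked above lists alternatives.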

📁 Data Preparation

Please download the datasets PACS, VLCS, Office-Home, TerraIncognita, and DomainNet.

Follow DATASETS.md to install the datasets.

📈 Training and Evaluation

We provide the running scripts in scripts, which allow you to reproduce the results in the paper. Make sure you modify the path in $DATA!
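
Before launching a script, it can help to confirm that the directory `$DATA` points to actually contains the datasets. The following is a minimal sketch: the function `missing_datasets` is hypothetical, and the folder names in the usage example are placeholders, since the real layout is defined in DATASETS.md.

```python
from pathlib import Path


def missing_datasets(data_root, folder_names):
    """Return the expected dataset folders that do not exist under data_root."""
    root = Path(data_root).expanduser()
    return [name for name in folder_names if not (root / name).is_dir()]


if __name__ == "__main__":
    # Placeholder root and folder names (assumptions); use the ones from DATASETS.md.
    missing = missing_datasets("~/data", ["pacs", "vlcs"])
    if missing:
        print("Missing dataset folders:", ", ".join(missing))
    else:
        print("All expected dataset folders were found.")
```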

Training Stage I: Domain Prompt Labels Learning (Optional)

We provide our dataset split files spg_coop_splits and the trained domain prompt labels.

If you want to use our data splits and domain prompt labels, put the split files into the dataset directory (following DATASETS.md) and put the domain prompt labels in domain prompt labels. Then go to Training Stage II: Generative Model Pre-training.

If you want to produce the data splits and domain prompt labels yourself, follow the instructions below.

To obtain the data splits and domain prompt labels, run the bash file in the scripts folder as follows.

# Example: trains on PACS dataset with ResNet50 as the backbone, and the gpu id is 0. 
bash scripts/spg_coop/spg_coop.sh pacs RN50 0

Training Stage II: Generative Model Pre-training

Please refer to DATASETS.md, and make sure that our data splits are in your data path. The bash files for the three types of DG tasks are in the scripts folder.

For multi-source Domain Generalization

# Example: trains on PACS dataset with ResNet50 as the backbone, and the gpu id is 0. 
bash scripts/spg_cgan/spg_cgan.sh pacs spg RN50 0
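
In multi-source DG, the usual protocol is leave-one-domain-out: each domain in turn is held out as the unseen target while the remaining domains are used for training. The sketch below enumerates those splits for PACS. Whether the script above iterates over target domains in exactly this way is an assumption, and `leave_one_domain_out` is a hypothetical helper.

```python
PACS_DOMAINS = ["art_painting", "cartoon", "photo", "sketch"]


def leave_one_domain_out(domains):
    """Yield (source_domains, target_domain) pairs, one per held-out domain."""
    for target in domains:
        sources = [d for d in domains if d != target]
        yield sources, target


if __name__ == "__main__":
    for sources, target in leave_one_domain_out(PACS_DOMAINS):
        print("train on", sources, "-> test on", target)
```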

For Single-source Domain Generalization

# Example: trains on VLCS dataset with ResNet50 as the backbone, and the gpu id is 1. 
bash scripts/spg_cgan/single.sh vlcs spg RN50 1

For Cross-dataset Domain Generalization

# Example: trains on DomainNet dataset with ViT-B/16 as the backbone, and the gpu id is 2. 
bash scripts/spg_cgan/cross.sh spg ViT-B/16 2

Evaluation

For multi-source Domain Generalization

# Example: tests on the PACS dataset with ResNet50 as the backbone, and the gpu id is 0. 
bash scripts/test_all.sh pacs spg RN50 0
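
To chain Stage II training and evaluation for the multi-source setting, a small driver can assemble the same commands shown above. This is a sketch: the argument order is copied from the examples in this README, the functions `stage2_command`, `eval_command`, and `plan` are hypothetical, and only the `pacs` identifier with `RN50` is confirmed by the examples for both scripts.

```python
import shlex


def stage2_command(dataset, backbone, gpu, method="spg"):
    """Build the Stage II training command for the multi-source setting."""
    return ["bash", "scripts/spg_cgan/spg_cgan.sh", dataset, method, backbone, str(gpu)]


def eval_command(dataset, backbone, gpu, method="spg"):
    """Build the evaluation command for the multi-source setting."""
    return ["bash", "scripts/test_all.sh", dataset, method, backbone, str(gpu)]


def plan(datasets, backbone="RN50", gpu=0):
    """Return train-then-evaluate commands for every dataset, in order."""
    commands = []
    for dataset in datasets:
        commands.append(stage2_command(dataset, backbone, gpu))
        commands.append(eval_command(dataset, backbone, gpu))
    return commands


if __name__ == "__main__":
    for command in plan(["pacs"]):
        # Dry run: print each command. To execute from the repository root,
        # import subprocess and call subprocess.run(command, check=True).
        print(shlex.join(command))
```

Printing the commands first makes it easy to check the dataset, backbone, and GPU id before committing to a long run.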

📊 Supported Methods

The methods supported in this codebase are as follows:

| Method | Paper     | Code |
|--------|-----------|------|
| CoOp   | IJCV 2022 | link |
| CoCoOp | CVPR 2022 | link |
| VP     | -         | link |
| VPT    | ECCV 2022 | link |
| MaPLe  | CVPR 2023 | link |
| DPL    | TJSAI 2023 | link |

For our SPG method, we also provide pre-trained models on the five DG datasets, which you can evaluate directly.

📝 Citation

If our code is helpful to your research or projects, please consider citing our work!

@inproceedings{bai2024soft,
  title={Soft Prompt Generation for Domain Generalization},
  author={Bai, Shuanghao and Zhang, Yuedi and Zhou, Wanqi and Luan, Zhirong and Chen, Badong},
  booktitle={European Conference on Computer Vision},
  year={2024}
}

📨 Contact

If you have any questions, please create an issue on this repository or contact us at zyd993@stu.xjtu.edu.cn or baishuanghao@stu.xjtu.edu.cn.

🙏 Acknowledgements

Our code is based on the CoOp and CoCoOp, MaPLe, and PDA repositories. We thank the authors for releasing their code. If you use their code, please consider citing these works as well.