A PyTorch implementation of ClipPrompt, based on the CVPR 2023 paper *CLIP for All Things Zero-Shot Sketch-Based Image Retrieval, Fine-Grained or Not*.
conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
conda install -c conda-forge torchmetrics
pip install git+https://github.com/openai/CLIP.git
This repo uses the Sketchy Extended and TU-Berlin Extended datasets. You can download them from their official websites or from Google Drive. The data directory is structured as follows:
├──sketchy
   ├── train
       ├── sketch
           ├── airplane
               ├── n02691156_58-1.jpg
               └── ...
           ...
       ├── photo
           same structure as sketch
   ├── val
       same structure as train
       ...
├──tuberlin
   same structure as sketchy
   ...
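The layout above maps directly onto (image path, class name) pairs. Below is a minimal sketch of how one split and modality could be indexed, assuming only the directory structure shown. `list_split` is a hypothetical helper for illustration, not a function from this repo.

```python
from pathlib import Path


def list_split(data_root, data_name, split, modality):
    """Collect (image_path, class_name) pairs for one split and modality.

    Hypothetical helper: assumes the layout
    <data_root>/<data_name>/<split>/<modality>/<class>/<image>.
    """
    base = Path(data_root) / data_name / split / modality
    samples = []
    # Sort class folders and files so the order is deterministic across runs.
    for class_dir in sorted(p for p in base.iterdir() if p.is_dir()):
        for image in sorted(class_dir.iterdir()):
            if image.is_file():
                samples.append((str(image), class_dir.name))
    return samples
```

For example, `list_split('/home/data', 'sketchy', 'train', 'sketch')` would return one entry per sketch in the training split, labeled with its class folder name.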
To train a model on the Sketchy Extended dataset, run:
python main.py --mode train --data_name sketchy
To test a model on the Sketchy Extended dataset, run:
python main.py --mode test --data_name sketchy --query_name <query image path>
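In test mode the query sketch is compared against the gallery, and the `--retrieval_num` closest images are returned. The plain-Python sketch below shows that ranking step. It assumes the query and gallery images have already been encoded into feature vectors and that images are ranked by cosine distance (1.0 minus cosine similarity). `retrieve`, `cosine_distance` and the path-to-feature dictionary are hypothetical and not part of `main.py`.

```python
import math


def cosine_distance(x, y, eps=1e-8):
    """1.0 - cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(x, y))
    norm = math.sqrt(sum(a * a for a in x)) * math.sqrt(sum(b * b for b in y))
    return 1.0 - dot / max(norm, eps)  # eps avoids division by zero


def retrieve(query_feat, gallery, retrieval_num=8):
    """Return the retrieval_num gallery paths closest to the query.

    gallery: hypothetical dict mapping an image path to its feature vector.
    """
    ranked = sorted(gallery, key=lambda path: cosine_distance(query_feat, gallery[path]))
    return ranked[:retrieval_num]
```

Sorting the whole gallery is fine for a sketch; with large galleries, `heapq.nsmallest` avoids a full sort.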
common arguments:
--data_root Datasets root path [default value is '/home/data']
--data_name Dataset name [default value is 'sketchy'](choices=['sketchy', 'tuberlin'])
--prompt_num Number of prompt embeddings [default value is 3]
--save_root Result saved root path [default value is 'result']
--mode Mode of the script [default value is 'train'](choices=['train', 'test'])
train arguments:
--batch_size Number of images in each mini-batch [default value is 64]
--epochs Number of training epochs [default value is 60]
--triplet_margin Margin of triplet loss [default value is 0.3]
--encoder_lr Learning rate of encoder [default value is 1e-4]
--prompt_lr Learning rate of prompt embedding [default value is 1e-3]
--cls_weight Weight of classification loss [default value is 0.5]
--seed Random seed (-1 for no manual seed) [default value is -1]
test arguments:
--query_name Query image path [default value is '/home/data/sketchy/val/sketch/cow/n01887787_591-14.jpg']
--retrieval_num Number of retrieved images [default value is 8]
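For reference, the flags listed above could be declared with Python's `argparse` roughly as follows. This is a reconstruction from the documented names, defaults and choices only. The actual parser in `main.py` may use different types or help text, and `build_parser` is a hypothetical name.

```python
import argparse


def build_parser():
    # Hypothetical reconstruction; names, defaults and choices copied from the list above.
    parser = argparse.ArgumentParser(description="Train or test ClipPrompt")
    # common arguments
    parser.add_argument("--data_root", type=str, default="/home/data", help="Datasets root path")
    parser.add_argument("--data_name", type=str, default="sketchy", choices=["sketchy", "tuberlin"],
                        help="Dataset name")
    parser.add_argument("--prompt_num", type=int, default=3, help="Number of prompt embeddings")
    parser.add_argument("--save_root", type=str, default="result", help="Result saved root path")
    parser.add_argument("--mode", type=str, default="train", choices=["train", "test"],
                        help="Mode of the script")
    # train arguments
    parser.add_argument("--batch_size", type=int, default=64, help="Number of images in each mini-batch")
    parser.add_argument("--epochs", type=int, default=60, help="Number of training epochs")
    parser.add_argument("--triplet_margin", type=float, default=0.3, help="Margin of triplet loss")
    parser.add_argument("--encoder_lr", type=float, default=1e-4, help="Learning rate of encoder")
    parser.add_argument("--prompt_lr", type=float, default=1e-3, help="Learning rate of prompt embedding")
    parser.add_argument("--cls_weight", type=float, default=0.5, help="Weight of classification loss")
    parser.add_argument("--seed", type=int, default=-1, help="Random seed (-1 for no manual seed)")
    # test arguments
    parser.add_argument("--query_name", type=str,
                        default="/home/data/sketchy/val/sketch/cow/n01887787_591-14.jpg",
                        help="Query image path")
    parser.add_argument("--retrieval_num", type=int, default=8, help="Number of retrieved images")
    return parser
```

Because `--data_name` and `--mode` use `choices`, an unsupported value such as `--data_name coco` makes `argparse` print an error and exit instead of failing later.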
The models are trained on one NVIDIA GeForce RTX 3090 (24 GB) GPU with `seed` set to `42`, `prompt_lr` set to `1e-3`, and the distance function `1.0 - F.cosine_similarity(x, y)`. All other hyperparameters use their default values.
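The sketch below shows, in plain Python for a single triplet, how the distance `1.0 - F.cosine_similarity(x, y)` plugs into a standard hinge triplet loss with margin `--triplet_margin` (0.3). It also shows one way to combine that loss with a classification loss weighted by `--cls_weight` (0.5). The additive combination `triplet + cls_weight * cls` is an assumption, not confirmed by this README, and all function names are hypothetical. In PyTorch, `torch.nn.TripletMarginWithDistanceLoss(distance_function=..., margin=0.3)` computes the same hinge on batches, but whether `main.py` uses it is not stated here.

```python
import math


def cosine_distance(x, y, eps=1e-8):
    """1.0 - cosine similarity for one pair of vectors."""
    dot = sum(a * b for a, b in zip(x, y))
    norm = math.sqrt(sum(a * a for a in x)) * math.sqrt(sum(b * b for b in y))
    return 1.0 - dot / max(norm, eps)


def triplet_loss(anchor, positive, negative, margin=0.3):
    """Hinge triplet loss: max(0, d(a, p) - d(a, n) + margin)."""
    return max(0.0, cosine_distance(anchor, positive) - cosine_distance(anchor, negative) + margin)


def total_loss(triplet, cls, cls_weight=0.5):
    # Assumed combination: triplet term plus a weighted classification term.
    return triplet + cls_weight * cls
```

With a sketch anchor, a same-class photo as positive and a different-class photo as negative, the loss is zero once the negative is at least `margin` farther from the anchor than the positive.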
Dataset | Prompt Num | mAP@200 | mAP@all | P@100 | P@200 | Download |
---|---|---|---|---|---|---|
Sketchy Extended | 3 | 71.9 | 64.3 | 70.8 | 68.1 | MEGA |
TU-Berlin Extended | 3 | 75.3 | 66.0 | 73.9 | 69.7 | MEGA |
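The table reports precision at a cutoff (P@K) and mean average precision over the top K results (mAP@200) or over the full ranking (mAP@all). Below is a minimal sketch of one common convention for these metrics, computed from per-query 0/1 relevance lists (1 means the retrieved photo shares the query's class). Normalization conventions for AP@K vary, and the exact definition used in `main.py` is not stated here. The repo installs `torchmetrics`, which provides retrieval metrics, but this sketch does not reproduce its implementation.

```python
def precision_at_k(relevance, k):
    """P@k: fraction of the top-k results that share the query's class."""
    return sum(relevance[:k]) / k


def average_precision(relevance, k=None):
    """AP over the top-k results; k=None uses the whole ranking (as in mAP@all).

    Normalized by the number of relevant items found in the top k; some
    evaluation scripts divide by min(k, total relevant) instead.
    """
    top = relevance if k is None else relevance[:k]
    hits, precision_sum = 0, 0.0
    for rank, rel in enumerate(top, start=1):
        if rel:
            hits += 1
            precision_sum += hits / rank  # precision at this relevant position
    return precision_sum / hits if hits else 0.0


def mean_average_precision(rankings, k=None):
    """mAP: mean of per-query AP values."""
    return sum(average_precision(r, k) for r in rankings) / len(rankings)
```

Multiplying these values by 100 gives the scale used in the table.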