For driver version 580.95.05 and CUDA version 13.0, install our requirements:
pip install -r requirements.txt
The training data consists of 4 types of data:
- Semantic segmentation datasets: ADE20K, COCO-Stuff, Mapillary, PACO-LVIS, PASCAL-Part, COCO Images for PASCAL-Part, VOCtrainval_03-May-2010.tar
Note: Add a link after the author's statement.
- Referring segmentation datasets: refCOCO, refCOCO+, refCOCOg, refCLEF (saiapr_tc-12)
You also need to generate the train.json file for PASCAL-Part:
python utils/pascal_part_mat2json.py --split_path='Path_to_VOCdevkit/VOC2010/ImageSets/Main' --split train.txt --ann_out Path_to_pascal_part/train.json --ann_path Path_to_pascal_part/Annotations_Part
Note: the original links for the refCOCO series data are down, and we have replaced them with new ones. If the download is very slow or unstable, we also provide a [OneDrive link](https://mycuhk-my.sharepoint.com/:f:/g/personal/1155154502_link_cuhk_edu_hk/Em5yELVBvfREodKC94nOFLoBLro_LPxsOxNV44PHRWgLcA?e=zQPjsc) as an alternative. **You must also follow the rules that the original datasets require.**
- Visual Question Answering dataset: LLaVA-Instruct-150k
- Reasoning segmentation dataset: ReasonSeg
Download them from the above links, and organize them as follows.
├── dataset
│   ├── ade20k
│   │   ├── annotations
│   │   └── images
│   ├── coco
│   │   └── train2017
│   │       ├── 000000000009.jpg
│   │       └── ...
│   ├── cocostuff
│   │   └── train2017
│   │       ├── 000000000009.png
│   │       └── ...
│   ├── llava_dataset
│   │   └── llava_instruct_150k.json
│   ├── mapillary
│   │   ├── config_v2.0.json
│   │   ├── testing
│   │   ├── training
│   │   └── validation
│   ├── reason_seg
│   │   └── ReasonSeg
│   │       ├── train
│   │       ├── val
│   │       └── explanatory
│   ├── refer_seg
│   │   ├── images
│   │   │   ├── saiapr_tc-12
│   │   │   └── mscoco
│   │   │       └── images
│   │   │           └── train2014
│   │   ├── refclef
│   │   ├── refcoco
│   │   ├── refcoco+
│   │   └── refcocog
│   └── vlpart
│       ├── paco
│       │   └── annotations
│       └── pascal_part
│           ├── train.json
│           └── VOCdevkit
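Before launching training, it can help to confirm that the layout above is in place. The following is a minimal sketch, not part of the repository: the helper name `find_missing` and the `EXPECTED_PATHS` list are hypothetical and simply mirror the entries of the tree.

```python
from pathlib import Path

# Relative paths copied from the directory tree above (hypothetical helper list).
EXPECTED_PATHS = [
    "ade20k/annotations",
    "ade20k/images",
    "coco/train2017",
    "cocostuff/train2017",
    "llava_dataset/llava_instruct_150k.json",
    "mapillary/config_v2.0.json",
    "mapillary/testing",
    "mapillary/training",
    "mapillary/validation",
    "reason_seg/ReasonSeg/train",
    "reason_seg/ReasonSeg/val",
    "reason_seg/ReasonSeg/explanatory",
    "refer_seg/images/saiapr_tc-12",
    "refer_seg/images/mscoco/images/train2014",
    "refer_seg/refclef",
    "refer_seg/refcoco",
    "refer_seg/refcoco+",
    "refer_seg/refcocog",
    "vlpart/paco/annotations",
    "vlpart/pascal_part/train.json",
    "vlpart/pascal_part/VOCdevkit",
]


def find_missing(dataset_dir):
    """Return the expected entries that do not exist under dataset_dir."""
    root = Path(dataset_dir)
    return [p for p in EXPECTED_PATHS if not (root / p).exists()]


if __name__ == "__main__":
    # "./dataset" matches the --dataset_dir value used in the original command.
    for path in find_missing("./dataset"):
        print("missing:", path)
```

If you train on a subset only (for example with --sem_seg_data="ade20k||cocostuff"), some of the reported entries may not be needed.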
To train LISA-7B or 13B, you need to follow the instructions to merge the LLaVA delta weights. Typically, we use the final weights LLaVA-Lightning-7B-v1-1 and LLaVA-13B-v1-1, merged from liuhaotian/LLaVA-Lightning-7B-delta-v1-1 and liuhaotian/LLaVA-13b-delta-v1-1, respectively. For Llama2, we can directly use the LLaVA full weights liuhaotian/llava-llama-2-13b-chat-lightning-preview.
Download SAM ViT-H pre-trained weights from the link.
Original:
deepspeed --master_port=24999 train_ds.py \
--version="PATH_TO_LLaVA" \
--dataset_dir='./dataset' \
--vision_pretrained="PATH_TO_SAM" \
--dataset="sem_seg||refer_seg||vqa||reason_seg" \
--sample_rates="9,3,3,1" \
--exp_name="lisa-7b"
Ours:
deepspeed --include localhost:6,7 --master_port=24999 train_ds.py \
--version="/data/lhy_data/LISA/LLaVA/llava-v1.5-7b" \
--load_in_8bit \
--epochs 10 \
--steps_per_epoch 500 \
--dataset="sem_seg||refer_seg||vqa||reason_seg" \
--sem_seg_data="ade20k||cocostuff" \
--exp_name="lisa-7b-lhy" \
--vis_save_path="/data/lhy_data/LISA/vis_output" \
--vision-tower="/data/lhy_data/LISA/clip-vit-large-patch14" \
--dataset_dir="/data/lhy_data/LISA/datasets" \
--log_base_dir="/data/lhy_data/LISA/runs" \
--vision_pretrained="/data/lhy_data/LISA/SAM/sam_vit_h_4b8939.pth" \
--sample_rates="9,3,3,1"
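The --dataset and --sample_rates arguments are positional pairs: the i-th rate belongs to the i-th dataset. The sketch below illustrates the arithmetic under the assumption that the rates are relative weights normalized into sampling probabilities. The function names `parse_sample_rates` and `pick_dataset` are hypothetical and not taken from train_ds.py.

```python
import random


def parse_sample_rates(dataset_arg, rates_arg):
    """Pair "a||b||c" with "x,y,z" and normalize the rates so they sum to 1."""
    names = dataset_arg.split("||")
    rates = [float(r) for r in rates_arg.split(",")]
    if len(names) != len(rates):
        raise ValueError("need exactly one sample rate per dataset")
    total = sum(rates)
    return {name: rate / total for name, rate in zip(names, rates)}


def pick_dataset(probs, rng=random):
    """Draw one dataset name according to the normalized probabilities."""
    names = list(probs)
    return rng.choices(names, weights=[probs[n] for n in names], k=1)[0]


probs = parse_sample_rates("sem_seg||refer_seg||vqa||reason_seg", "9,3,3,1")
print(probs)
# {'sem_seg': 0.5625, 'refer_seg': 0.1875, 'vqa': 0.1875, 'reason_seg': 0.0625}
```

Under this reading, "9,3,3,1" means that roughly 9 of every 16 training samples come from the semantic segmentation data and 1 of every 16 from ReasonSeg.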
When training is finished, to get the full model weight:
cd ./runs/lisa-7b/ckpt_model && python zero_to_fp32.py . ../pytorch_model.bin
Merge the LoRA weights of pytorch_model.bin and save the resulting model to your desired path in the Hugging Face format:
CUDA_VISIBLE_DEVICES="" python merge_lora_weights_and_save_hf_model.py \
--version="PATH_TO_LLaVA" \
--weight="PATH_TO_pytorch_model.bin" \
--save_path="PATH_TO_SAVED_MODEL"
For example:
CUDA_VISIBLE_DEVICES="" python3 merge_lora_weights_and_save_hf_model.py \
--version="./LLaVA/LLaVA-Lightning-7B-v1-1" \
--weight="lisa-7b/pytorch_model.bin" \
--save_path="./LISA-7B"
deepspeed --master_port=24999 train_ds.py \
--version="PATH_TO_LISA_HF_Model_Directory" \
--dataset_dir='./dataset' \
--vision_pretrained="PATH_TO_SAM" \
--exp_name="lisa-7b" \
--eval_only
Note: the v1 model is trained on both the train and val sets, so please use the v0 model to reproduce the validation results. (To use the v0 models, first check out the legacy version of the repo with git checkout 0e26916.)
To chat with LISA-13B-llama2-v1 or LISA-13B-llama2-v1-explanatory:
(Note that chat.py currently does not support the v0 models, i.e., LISA-13B-llama2-v0 and LISA-13B-llama2-v0-explanatory. To use the v0 models, first check out the legacy version of the repo with git checkout 0e26916.)
CUDA_VISIBLE_DEVICES=0 python chat.py --version='xinlai/LISA-13B-llama2-v1'
CUDA_VISIBLE_DEVICES=0 python chat.py --version='xinlai/LISA-13B-llama2-v1-explanatory'
To use bf16 or fp16 data type for inference:
CUDA_VISIBLE_DEVICES=0 python chat.py --version='xinlai/LISA-13B-llama2-v1' --precision='bf16'
To use 8-bit or 4-bit data type for inference (this enables running the 13B model on a single 24G or 12G GPU at some cost to generation quality):
CUDA_VISIBLE_DEVICES=0 python chat.py --version='xinlai/LISA-13B-llama2-v1' --precision='fp16' --load_in_8bit
CUDA_VISIBLE_DEVICES=0 python chat.py --version='xinlai/LISA-13B-llama2-v1' --precision='fp16' --load_in_4bit
Hint: for the 13B model, 16-bit inference consumes 30G of VRAM on a single GPU, 8-bit inference consumes 16G, and 4-bit inference consumes 9G.
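The memory figures in the hint can be turned into a small lookup for choosing chat.py flags. This is only a sketch: `suggest_chat_flags` is a hypothetical helper, the thresholds are the figures quoted above with no extra headroom, and the flags are the ones shown in the commands above.

```python
def suggest_chat_flags(vram_gb):
    """Map available GPU memory (in GB) to chat.py flags for the 13B model.

    Thresholds come from the hint above (30G / 16G / 9G) and leave no
    headroom for activations or other processes, so treat them as lower bounds.
    """
    if vram_gb >= 30:
        return ["--precision=bf16"]
    if vram_gb >= 16:
        return ["--precision=fp16", "--load_in_8bit"]
    if vram_gb >= 9:
        return ["--precision=fp16", "--load_in_4bit"]
    raise ValueError("less than 9G of VRAM: the 13B model is unlikely to fit")


# A 24G GPU falls in the 8-bit range, a 12G GPU in the 4-bit range.
print(suggest_chat_flags(24))
print(suggest_chat_flags(12))
```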
After that, input the text prompt and then the image path. For example,
- Please input your prompt: Where can the driver see the car speed in this image? Please output segmentation mask.
- Please input the image path: imgs/example1.jpg
- Please input your prompt: Can you segment the food that tastes spicy and hot?
- Please input the image path: imgs/example2.jpg
CUDA_VISIBLE_DEVICES=0 python app.py --version='xinlai/LISA-13B-llama2-v1' --load_in_4bit
CUDA_VISIBLE_DEVICES=0 python app.py --version='xinlai/LISA-13B-llama2-v1-explanatory' --load_in_4bit
By default, we use 4-bit quantization. Feel free to delete the --load_in_4bit argument for 16-bit inference, or replace it with the --load_in_8bit argument for 8-bit inference.
In ReasonSeg, we have collected 1218 images (239 train, 200 val, and 779 test). The training and validation sets can be downloaded from this link.
Each image is provided with an annotation JSON file:
image_1.jpg, image_1.json
image_2.jpg, image_2.json
...
image_n.jpg, image_n.json
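Since each image sits next to a JSON file with the same stem, the pairs can be collected with a few lines of Python. This is a minimal sketch, assuming the images use the .jpg extension as in the listing above; `pair_images_with_annotations` is a hypothetical helper name.

```python
from pathlib import Path


def pair_images_with_annotations(split_dir):
    """Return (image_path, json_path) pairs for every .jpg that has a matching .json."""
    pairs = []
    for image_path in sorted(Path(split_dir).glob("*.jpg")):
        json_path = image_path.with_suffix(".json")
        if json_path.exists():
            pairs.append((image_path, json_path))
    return pairs


if __name__ == "__main__":
    # Path taken from the dataset layout; adjust to your own dataset directory.
    for image_path, json_path in pair_images_with_annotations("dataset/reason_seg/ReasonSeg/train"):
        print(image_path.name, json_path.name)
```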
Important keys contained in JSON files:
- "text": text instructions.
- "is_sentence": whether the text instructions are long sentences.
- "shapes": target polygons.
The elements of "shapes" fall into two categories, "target" and "ignore". The former is indispensable for evaluation, while the latter denotes an ambiguous region and is therefore disregarded during evaluation.
We provide a script that demonstrates how to process the annotations:
python3 utils/data_processing.py
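As a rough illustration of the "target" / "ignore" split described above, the sketch below separates the polygons of one annotation. It is not the provided script: the per-shape field names "label" and "points" are assumptions (a common polygon-annotation layout) and should be checked against an actual JSON file, and `split_shapes` is a hypothetical helper name.

```python
import json


def split_shapes(annotation):
    """Split the "shapes" list into target polygons and ignore polygons.

    Assumes each shape is a dict with a "label" string and a "points" list.
    """
    targets, ignores = [], []
    for shape in annotation.get("shapes", []):
        label = str(shape.get("label", "")).lower()
        points = shape.get("points", [])
        if "ignore" in label:
            ignores.append(points)
        elif "target" in label:
            targets.append(points)
    return targets, ignores


# Made-up annotation for illustration only; real files may contain more keys.
example = json.loads("""
{
  "text": ["segment the object the person is holding"],
  "is_sentence": true,
  "shapes": [
    {"label": "target", "points": [[10, 10], [50, 10], [50, 40]]},
    {"label": "ignore", "points": [[60, 60], [80, 60], [80, 90]]}
  ]
}
""")
targets, ignores = split_shapes(example)
print(len(targets), len(ignores))  # 1 1
```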
In addition, we used GPT-3.5 to rephrase the instructions, so images in the training set may have more than one instruction (but fewer than six) in the "text" field. During training, users may randomly select one as the text query to obtain a better model.
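Randomly selecting one instruction per training step can be sketched as follows. This assumes the "text" field holds either a single string or a list of strings; `sample_instruction` is a hypothetical helper name, not a function from the repository.

```python
import random


def sample_instruction(annotation, rng=random):
    """Pick one text query at random from the "text" field of an annotation."""
    text = annotation["text"]
    # Accept both a single string and a list of rephrased instructions.
    candidates = [text] if isinstance(text, str) else list(text)
    if not candidates:
        raise ValueError('annotation has an empty "text" field')
    return rng.choice(candidates)


# Made-up annotation for illustration only.
annotation = {"text": ["the spicy food", "food that tastes hot", "the dish with chili"]}
print(sample_instruction(annotation))
```

Passing a seeded `random.Random` instance as `rng` makes the selection reproducible across runs.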
If you find this project useful in your research, please consider citing:
@article{lai2023lisa,
title={LISA: Reasoning Segmentation via Large Language Model},
author={Lai, Xin and Tian, Zhuotao and Chen, Yukang and Li, Yanwei and Yuan, Yuhui and Liu, Shu and Jia, Jiaya},
journal={arXiv preprint arXiv:2308.00692},
year={2023}
}
@article{yang2023improved,
title={An Improved Baseline for Reasoning Segmentation with Large Language Model},
author={Yang, Senqiao and Qu, Tianyuan and Lai, Xin and Tian, Zhuotao and Peng, Bohao and Liu, Shu and Jia, Jiaya},
journal={arXiv preprint arXiv:2312.17240},
year={2023}
}