From a7d634fcc1dd0af2512f607e2bae416e3c0b0a2f Mon Sep 17 00:00:00 2001
From: "Teo (Timothy) Wu Haoning" <38696372+teowu@users.noreply.github.com>
Date: Sun, 5 Nov 2023 03:34:27 +0800
Subject: [PATCH] Q-Bench Evaluation Scripts for LLaVA-v1.5. (#581)

* add evaluation instructions

* add evaluation instructions

* evaluation scripts for Q-Bench

* evaluation scripts for Q-Bench

* evaluation scripts for Q-Bench

* evaluation code for qbench

* Add more benchmarks section

---------

Co-authored-by: Haotian Liu <6631389+haotian-liu@users.noreply.github.com>
---
 docs/Evaluation.md             |  25 +++++++
 llava/eval/model_vqa_qbench.py | 122 +++++++++++++++++++++++++++++++++
 scripts/v1_5/eval/qbench.sh    |  18 +++++
 scripts/v1_5/eval/qbench_zh.sh |  20 ++++++
 4 files changed, 185 insertions(+)
 create mode 100644 llava/eval/model_vqa_qbench.py
 create mode 100644 scripts/v1_5/eval/qbench.sh
 create mode 100644 scripts/v1_5/eval/qbench_zh.sh

diff --git a/docs/Evaluation.md b/docs/Evaluation.md
index 95ff557f3..3e46a98ce 100644
--- a/docs/Evaluation.md
+++ b/docs/Evaluation.md
@@ -114,6 +114,7 @@ CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/mmbench_cn.sh
 ```
 3. Submit the results to the evaluation server: `./playground/data/eval/mmbench/answers_upload/mmbench_dev_cn_20231003`.
 
+
 ### SEED-Bench
 
 1. Following the official [instructions](https://github.com/AILab-CVC/SEED-Bench/blob/main/DATASET.md) to download the images and the videos. Put images under `./playground/data/eval/seed_bench/SEED-Bench-image`.
@@ -140,3 +141,27 @@ CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/llavabench.sh
 CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/mmvet.sh
 ```
 3. Evaluate the predictions in `./playground/data/eval/mmvet/results` using the official jupyter notebook.
+
+## More Benchmarks
+
+Below are awesome benchmarks for multimodal understanding from the research community, which were not included in the original LLaVA-1.5 release.
+
+### Q-Bench
+
+1. Download [`llvisionqa_dev.json`](https://huggingface.co/datasets/nanyangtu/LLVisionQA-QBench/resolve/main/llvisionqa_dev.json) (for the `dev` subset) and [`llvisionqa_test.json`](https://huggingface.co/datasets/nanyangtu/LLVisionQA-QBench/resolve/main/llvisionqa_test.json) (for the `test` subset). Put them under `./playground/data/eval/qbench`.
+2. Download and extract [images](https://huggingface.co/datasets/nanyangtu/LLVisionQA-QBench/resolve/main/images_llvisionqa.tar) and put all the images directly under `./playground/data/eval/qbench/images_llvisionqa`.
+3. Single-GPU inference (change `dev` to `test` for evaluation on the test set).
+```Shell
+CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/qbench.sh dev
+```
+4. Submit the results following the instructions [here](https://github.com/VQAssessment/Q-Bench#option-1-submit-results): `./playground/data/eval/qbench/llvisionqa_dev_answers.jsonl`.
+
+### Chinese-Q-Bench
+
+1. Download [`质衡-问答-验证集.json`](https://huggingface.co/datasets/nanyangtu/LLVisionQA-QBench/resolve/main/%E8%B4%A8%E8%A1%A1-%E9%97%AE%E7%AD%94-%E9%AA%8C%E8%AF%81%E9%9B%86.json) (for the `dev` subset) and [`质衡-问答-测试集.json`](https://huggingface.co/datasets/nanyangtu/LLVisionQA-QBench/resolve/main/%E8%B4%A8%E8%A1%A1-%E9%97%AE%E7%AD%94-%E6%B5%8B%E8%AF%95%E9%9B%86.json) (for the `test` subset). Put them under `./playground/data/eval/qbench`.
+2. Download and extract [images](https://huggingface.co/datasets/nanyangtu/LLVisionQA-QBench/resolve/main/images_llvisionqa.tar) and put all the images directly under `./playground/data/eval/qbench/images_llvisionqa`.
+3. Single-GPU inference (change `dev` to `test` for evaluation on the test set).
+```Shell
+CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/qbench_zh.sh dev
+```
+4. Submit the results following the instructions [here](https://github.com/VQAssessment/Q-Bench#option-1-submit-results): `./playground/data/eval/qbench/llvisionqa_zh_dev_answers.jsonl`.
diff --git a/llava/eval/model_vqa_qbench.py b/llava/eval/model_vqa_qbench.py
new file mode 100644
index 000000000..f3ca8177c
--- /dev/null
+++ b/llava/eval/model_vqa_qbench.py
@@ -0,0 +1,122 @@
+import argparse
+import torch
+from tqdm import tqdm
+import json
+
+from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+from llava.conversation import conv_templates, SeparatorStyle
+from llava.model.builder import load_pretrained_model
+from llava.utils import disable_torch_init
+from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
+
+from PIL import Image
+
+import requests
+from PIL import Image
+from io import BytesIO
+
+
+def load_image(image_file):
+    if image_file.startswith('http') or image_file.startswith('https'):
+        response = requests.get(image_file)
+        image = Image.open(BytesIO(response.content)).convert('RGB')
+    else:
+        image = Image.open(image_file).convert('RGB')
+    return image
+
+
+def eval_model(args):
+    # Model
+    disable_torch_init()
+
+    model_name = get_model_name_from_path(args.model_path)
+    tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, True)
+
+
+
+
+    with open(args.questions_file) as f:
+        llvqa_data = json.load(f)
+
+    for i, llddata in enumerate(tqdm(llvqa_data)):
+        filename = llddata["img_path"]
+        if args.lang == "en":
+            message = llddata["question"] + "\nChoose between one of the options as follows:\n"
+        elif args.lang == "zh":
+            message = llddata["question"] + "\n在下列选项中选择一个:\n"
+        else:
+            raise NotImplementedError("Q-Bench does not support languages other than English (en) and Chinese (zh) yet. Contact us (https://github.com/VQAssessment/Q-Bench/) to convert Q-Bench into more languages.")
+        for choice, ans in zip(["A.", "B.", "C.", "D."], llddata["candidates"]):
+            message += f"{choice} {ans}\n"
+        qs = message
+
+        if model.config.mm_use_im_start_end:
+            qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
+        else:
+            qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
+
+        if 'llama-2' in model_name.lower():
+            conv_mode = "llava_llama_2"
+        elif "v1" in model_name.lower():
+            conv_mode = "llava_v1"
+        elif "mpt" in model_name.lower():
+            conv_mode = "mpt"
+        else:
+            conv_mode = "llava_v0"
+
+        if args.conv_mode is not None and conv_mode != args.conv_mode:
+            print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
+        else:
+            args.conv_mode = conv_mode
+
+        conv = conv_templates[args.conv_mode].copy()
+        conv.append_message(conv.roles[0], qs)
+        conv.append_message(conv.roles[1], None)
+        prompt = conv.get_prompt()
+
+        image = load_image(args.image_folder + filename)
+        image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda()
+
+        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
+
+        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
+        keywords = [stop_str]
+        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
+
+
+        with torch.inference_mode():
+            output_ids = model.generate(
+                input_ids,
+                images=image_tensor,
+                num_beams=1,
+                do_sample=False,
+                temperature=0,
+                max_new_tokens=1024,
+                use_cache=True,
+                stopping_criteria=[stopping_criteria])
+
+        input_token_len = input_ids.shape[1]
+        n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
+        if n_diff_input_output > 0:
+            print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
+        outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
+        outputs = outputs.strip()
+        if outputs.endswith(stop_str):
+            outputs = outputs[:-len(stop_str)]
+        outputs = outputs.strip()
+        llddata["response"] = outputs
+        with open(args.answers_file, "a") as wf:
+            json.dump(llddata, wf)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model-path", type=str, default="llava-v1.5")
+    parser.add_argument("--model-base", type=str, default=None)
+    parser.add_argument("--image-folder", type=str, default="./playground/data/qbench/images_llvisionqa")
+    parser.add_argument("--questions-file", type=str, default="./playground/data/qbench/llvisionqa_dev.json")
+    parser.add_argument("--answers-file", type=str, default="answer.jsonl")
+    parser.add_argument("--conv-mode", type=str, default="llava_v1")
+    parser.add_argument("--lang", type=str, default="en")
+    args = parser.parse_args()
+
+    eval_model(args)
diff --git a/scripts/v1_5/eval/qbench.sh b/scripts/v1_5/eval/qbench.sh
new file mode 100644
index 000000000..46b8e029b
--- /dev/null
+++ b/scripts/v1_5/eval/qbench.sh
@@ -0,0 +1,18 @@
+#!/bin/bash
+
+if [ "$1" = "dev" ]; then
+    echo "Evaluating in 'dev' split."
+elif [ "$1" = "test" ]; then
+    echo "Evaluating in 'test' split."
+else
+    echo "Unknown split, please choose between 'dev' and 'test'."
+    exit 1
+fi
+
+python -m llava.eval.model_vqa_qbench \
+    --model-path liuhaotian/llava-v1.5-13b \
+    --image-folder ./playground/data/eval/qbench/images_llvisionqa/ \
+    --questions-file ./playground/data/eval/qbench/llvisionqa_$1.json \
+    --answers-file ./playground/data/eval/qbench/llvisionqa_$1_answers.jsonl \
+    --conv-mode llava_v1 \
+    --lang en
diff --git a/scripts/v1_5/eval/qbench_zh.sh b/scripts/v1_5/eval/qbench_zh.sh
new file mode 100644
index 000000000..7bfc17088
--- /dev/null
+++ b/scripts/v1_5/eval/qbench_zh.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+
+if [ "$1" = "dev" ]; then
+    ZH_SPLIT="验证集"
+    echo "Evaluating in 'dev' split."
+elif [ "$1" = "test" ]; then
+    ZH_SPLIT="测试集"
+    echo "Evaluating in 'test' split."
+else
+    echo "Unknown split, please choose between 'dev' and 'test'."
+    exit 1
+fi
+
+python -m llava.eval.model_vqa_qbench \
+    --model-path liuhaotian/llava-v1.5-13b \
+    --image-folder ./playground/data/eval/qbench/images_llvisionqa/ \
+    --questions-file ./playground/data/eval/qbench/质衡-问答-$ZH_SPLIT.json \
+    --answers-file ./playground/data/eval/qbench/llvisionqa_zh_$1_answers.jsonl \
+    --conv-mode llava_v1 \
+    --lang zh
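
For reference, below is a minimal sketch (not part of the patch above) of how the generated answers file could be checked locally before submission. It assumes the `dev` split was run via `scripts/v1_5/eval/qbench.sh`, so the file sits at `./playground/data/eval/qbench/llvisionqa_dev_answers.jsonl`. Note that `model_vqa_qbench.py` appends each record with `json.dump` and no separator, so the file holds concatenated JSON objects rather than strict one-object-per-line JSONL; `raw_decode` is used here to walk through them.

```python
import json

# Assumed path: qbench.sh run on the dev split.
ANSWERS = "./playground/data/eval/qbench/llvisionqa_dev_answers.jsonl"


def iter_answers(path=ANSWERS):
    """Yield records appended by model_vqa_qbench.py (concatenated JSON objects)."""
    decoder = json.JSONDecoder()
    with open(path) as f:
        buf = f.read()
    pos = 0
    while pos < len(buf):
        record, pos = decoder.raw_decode(buf, pos)
        # Tolerate any whitespace between records, if present.
        while pos < len(buf) and buf[pos].isspace():
            pos += 1
        yield record


if __name__ == "__main__":
    answers = list(iter_answers())
    missing = [a for a in answers if not a.get("response")]
    print(f"{len(answers)} records, {len(missing)} with an empty response")
```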