[Model] Update Molmo Eval to Match Official Implementation #648

Merged · 2 commits · Dec 30, 2024
146 changes: 141 additions & 5 deletions vlmeval/vlm/molmo.py
@@ -1,11 +1,33 @@
import torch
from PIL import Image
import os.path as osp
import sys
from .base import BaseModel
from ..smp import *
from ..dataset import DATASET_TYPE

TYPE_PROMPTS = {
    'Y/N': 'vqa2:',
    'VQA': 'vqa2:',
    'MCQ': 'a_okvqa_mc:',
}

DATASET_PROMPTS = {
    'AI2D_TEST': 'ai2_diagram:',
    'AI2D_TEST_NO_MASK': 'ai2_diagram:',
    'COCO_VAL': 'coco_captioning:',
    'ChartQA_TEST': 'chart_qa:',
    'ChartQA_VAL': 'chart_qa:',
    'DocVQA_VAL': 'doc_qa:',
    'DocVQA_TEST': 'doc_qa:',
    'InfoVQA_TEST': 'info_qa:',
    'InfoVQA_VAL': 'info_qa:',
    'OCRVQA_TEST': 'ocr_vqa:',
    'OCRVQA_TESTCORE': 'ocr_vqa:',
    'ScienceQA_VAL': 'science_qa:',
    'ScienceQA_TEST': 'science_qa:',
    'TableVQABench': 'tabwmp_da:',
    'TextVQA_VAL': 'text_vqa:'
}
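# Illustration: with the mapping above, a ChartQA_TEST question such as
# 'What is the highest value?' is rendered by build_prompt_vqa below as
# 'chart_qa: What is the highest value?'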


class molmo(BaseModel):

@@ -36,6 +58,106 @@ def __init__(self, model_path='allenai/Molmo-7B-D-0924', **kwargs):
        self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16)
        self.kwargs = kwargs
        self.model_name = model_path
        # set default maximum number of crops to 36
        self.max_crops = kwargs.get('max_crops', 36)

    def use_custom_prompt(self, dataset):
        if DATASET_TYPE(dataset) in ['Y/N', 'MCQ', 'VQA']:
            return True
        return False

    def build_prompt(self, line, dataset=None):
        assert self.use_custom_prompt(dataset)
        assert dataset is None or isinstance(dataset, str)
        tgt_path = self.dump_image(line, dataset)
        prefix = None
        if dataset in ['MMMU_DEV_VAL', 'MMMU_TEST']:
            prompt = self.build_prompt_mcq_vqa(line)
        elif dataset in ['MathVista_MINI']:
            prompt = self.build_prompt_mathvista(line)
        elif dataset in ['AI2D_TEST', 'AI2D_TEST_NO_MASK']:
            prompt = self.build_prompt_ai2d(line)
        elif dataset is not None and listinstr(list(DATASET_PROMPTS.keys()), dataset):
            prefix = DATASET_PROMPTS[dataset]  # the remaining supervised datasets use the VQA format
            prompt = self.build_prompt_vqa(line, prefix)
        elif dataset is not None and listinstr(['MCQ'], DATASET_TYPE(dataset)):
            prompt = self.build_prompt_multiple_choice(line)
        else:
            prompt = self.build_prompt_vqa(line)

        message = [dict(type='text', value=prompt)]
        message.extend([dict(type='image', value=s) for s in tgt_path])

        # MMMU interleaves images with text; split the message accordingly
        if dataset.startswith('MMMU_'):
            from .. import MMMUDataset
            message = MMMUDataset.split_MMMU(message)
        return message
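    # Illustration: for a plain VQA line, build_prompt returns something like
    # [{'type': 'text', 'value': 'vqa2: <question>'},
    #  {'type': 'image', 'value': '/path/to/dumped/image.jpg'}]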

    def build_prompt_mathvista(self, line):
        if line['question_type'] == 'multi_choice':
            prompt = self.build_prompt_multiple_choice(line)
        else:
            prompt = self.build_prompt_vqa(line)
        return prompt

    def build_prompt_ai2d(self, line):
        def option_is_abc(line):
            for cand in string.ascii_uppercase:
                if cand in line and not pd.isna(line[cand]):
                    # check that the option text is a single letter
                    if not line[cand].strip().isalpha() or len(line[cand].strip()) > 1:
                        return False
            return True

        if line['abcLabel'] and option_is_abc(line):
            prompt = line['question']
            options = {
                cand: line[cand]
                for cand in string.ascii_uppercase
                if cand in line and not pd.isna(line[cand])
            }
            for item in options.values():
                prompt += f'\n{item}'
            prompt = f"ai2_diagram_no_letter: {prompt}"
            # prompt = self.build_prompt_multiple_choice(line, prefix='ai2_diagram_no_letter:')
        else:
            prompt = self.build_prompt_multiple_choice(line, prefix='ai2_diagram:')
        return prompt
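    # Illustration: letter-only options yield, e.g.,
    #   'ai2_diagram_no_letter: Which part is the stem?\nK\nB\nC\nH'
    # while labelled options yield
    #   'ai2_diagram: Which part is the stem?\nA: root\nB: stem\nC: leaf'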

    def build_prompt_mcq_vqa(self, line):
        if line['question_type'] == 'multiple-choice':
            prompt = self.build_prompt_multiple_choice(line)
        else:
            prompt = self.build_prompt_vqa(line)
        return prompt

    def build_prompt_multiple_choice(self, line, prefix=None):
        question = line['question']
        hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
        if hint is not None:
            question = hint + '\n' + question
        options = {
            cand: line[cand]
            for cand in string.ascii_uppercase
            if cand in line and not pd.isna(line[cand])
        }
        for key, item in options.items():
            question += f'\n{key}: {item}'
        if prefix is None:
            prompt = f"{TYPE_PROMPTS['MCQ']} {question}"
        else:
            prompt = f"{prefix} {question}"

        return prompt
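    # Illustration: a typical result looks like
    # 'a_okvqa_mc: <hint>\n<question>\nA: <option A>\nB: <option B>'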

    def build_prompt_vqa(self, line, prefix=None):
        question = line['question']
        if prefix is None:
            prompt = f"{TYPE_PROMPTS['VQA']} {question}"
        else:
            prompt = f"{prefix} {question}"
        return prompt

    def generate_inner(self, message, dataset=None):
        from transformers import GenerationConfig
@@ -44,10 +166,15 @@ def generate_inner(self, message, dataset=None):
        image = Image.open(image_path)
        if image.mode != "RGB":
            image = image.convert("RGB")

        # process the image and text
        max_crops = self.max_crops
        inputs = self.processor.process(
            images=[image],
            text=prompt,
            images_kwargs={
                "max_crops": max_crops
            }
        )
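        # max_crops bounds how many crops the Molmo processor tiles the image into;
        # the default of 36 set in __init__ follows the official Molmo evaluation setup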

        # move inputs to the correct device and make a batch of size 1
@@ -63,7 +190,16 @@ def generate_inner(self, message, dataset=None):

        # only get generated tokens; decode them to text
        generated_tokens = output[0, inputs['input_ids'].size(1):]
        generated_text = self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

        # AI2D: map direct answer to letter option
        if dataset in ['AI2D_TEST', 'AI2D_TEST_NO_MASK']:
            # 'ai2_diagram_no_letter: Which of the following is the magma chamber?\nK\nB\nC\nH'
            if 'ai2_diagram_no_letter' in prompt:
                options = prompt.split('\n')[1:]
                answer = options.index(generated_text)
                generated_text = chr(answer + ord('A'))
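                # e.g. options ['K', 'B', 'C', 'H'] with output 'H' give index 3, i.e. letter 'D';
                # this assumes the model echoes one option verbatim, otherwise .index raises ValueError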

        # print(dataset, prompt, generated_text, inputs['images'].size())  # uncomment to debug

        # return the decoded text
        return generated_text
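
A minimal sketch of how the new prompt helpers behave (the sample record and question are hypothetical; instantiating molmo loads the model weights):

    from vlmeval.vlm.molmo import molmo, DATASET_PROMPTS

    model = molmo(model_path='allenai/Molmo-7B-D-0924', max_crops=36)
    line = {'question': 'What is the highest value in the chart?'}  # hypothetical row
    prompt = model.build_prompt_vqa(line, prefix=DATASET_PROMPTS['ChartQA_TEST'])
    print(prompt)  # -> 'chart_qa: What is the highest value in the chart?'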