-
Notifications
You must be signed in to change notification settings - Fork 242
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add benchmark code * add output format control * add ignore * add gitignore * rm unused files * lint file * delete unused files --------- Co-authored-by: Fanyi Pu <FPU001@e.ntu.edu.sg>
- Loading branch information
Showing
3 changed files
with
215 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,199 @@ | ||
import requests | ||
import torch | ||
import transformers | ||
import json | ||
from PIL import Image | ||
from otter.modeling_otter import OtterForConditionalGeneration | ||
import argparse | ||
|
||
|
||
def get_image(url: str) -> Image.Image:
    """
    Fetch an image from a URL.

    Args:
        url (str): URL of the image.

    Returns:
        Image.Image: the decoded PIL image.

    Raises:
        requests.HTTPError: if the server responds with an error status.
        requests.exceptions.Timeout: if the request exceeds the timeout.
    """
    # stream=True lets PIL read straight from the response body without
    # buffering the whole payload first; the timeout prevents the benchmark
    # from hanging forever on a dead host (the original had none).
    response = requests.get(url, stream=True, timeout=30)
    # Fail loudly on 4xx/5xx instead of handing PIL an HTML error page.
    response.raise_for_status()
    return Image.open(response.raw)
|
||
|
||
def get_formatted_prompt(prompt: str) -> str:
    """
    Wrap a raw user prompt in the chat template the Otter model expects.

    Args:
        prompt (str): the raw user question/instruction.

    Returns:
        str: the prompt surrounded by image/user/answer control tokens.
    """
    # The exact token layout and spacing matter to the model, so build the
    # template verbatim.
    return "<image> User: " + prompt + " GPT: <answer>"
|
||
|
||
def get_response(url: str, prompt: str) -> str:
    """
    Run a single image + prompt through the model and return the parsed answer.

    Args:
        url (str): URL of the image to condition on.
        prompt (str): the user prompt (raw; chat formatting is applied here).

    Returns:
        str: the model's answer with control tokens, surrounding whitespace,
        and any wrapping double quotes removed.
    """
    query_image = get_image(url)

    # Add two leading singleton dims to the preprocessed pixel tensor —
    # presumably (batch, num_media, num_frames, C, H, W) as Otter expects;
    # TODO(review): confirm against the Otter model API.
    pixel_values = image_processor.preprocess([query_image], return_tensors="pt")[
        "pixel_values"
    ]
    vision_x = pixel_values.unsqueeze(1).unsqueeze(0)

    lang_x = model.text_tokenizer(
        [get_formatted_prompt(prompt)],
        return_tensors="pt",
    )

    output_ids = model.generate(
        vision_x=vision_x.to(model.device),
        lang_x=lang_x["input_ids"].to(model.device),
        attention_mask=lang_x["attention_mask"].to(model.device),
        max_new_tokens=256,
        num_beams=3,
        no_repeat_ngram_size=3,
    )

    # Keep only the text between <answer> and <|endofchunk|>, then strip
    # whitespace and any quotes the model wrapped around the answer.
    decoded = model.text_tokenizer.decode(output_ids[0])
    answer = decoded.split("<answer>")[1].strip()
    answer = answer.split("<|endofchunk|>")[0].strip()
    return answer.strip('"')
|
||
|
||
def generate_html(output_file, model_version_or_tag):
    """
    Render a benchmark-results JSON file into a three-column HTML report.

    Args:
        output_file (str): path to a JSON file holding a list of dicts with
            "image", "prompt" and "response" keys.
        model_version_or_tag (str): heading shown at the top of the page.

    Side effects:
        Writes an HTML file next to ``output_file`` (same name, ``.html``
        extension).
    """
    # Local import of the escaper only: the variable below is named "html",
    # which would shadow the stdlib module at this scope.
    from html import escape

    # Load the benchmark results produced by the main script.
    with open(output_file, "r") as f:
        data = json.load(f)

    # Page skeleton; the doubled braces survive str.format as literal CSS braces.
    html = """
    <!DOCTYPE html>
    <html>
    <head>
    <title>Benchmarking various ver. of Otter</title>
    <style>
    .column {{
    float: left;
    width: 33.33%;
    padding: 5px;
    }}
    .row::after {{
    content: "";
    clear: both;
    display: table;
    }}
    img {{
    width: 100%;
    height: auto;
    }}
    </style>
    </head>
    <body>
    <h1>{}</h1>
    """.format(escape(model_version_or_tag))

    # One row per result. Escape everything interpolated into the page:
    # model output routinely contains "<", "&", or quotes, which would
    # otherwise break the markup (or inject script into the report).
    for item in data:
        html += """
    <div class="row">
    <div class="column">
    <img src="{image}" alt="Image">
    </div>
    <div class="column">
    {prompt}
    </div>
    <div class="column">
    {response}
    </div>
    </div>
    """.format(
            image=escape(item["image"], quote=True),
            prompt=escape(item["prompt"]),
            response=escape(item["response"]),
        )

    # Close the document.
    html += """
    </body>
    </html>
    """

    # Swap only a *trailing* ".json" for ".html" — str.replace would also
    # rewrite a ".json" appearing earlier in the path.
    if output_file.endswith(".json"):
        output_html_path = output_file[: -len(".json")] + ".html"
    else:
        output_html_path = output_file + ".html"
    with open(output_html_path, "w") as f:
        f.write(html)
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument( | ||
"--model_path_or_name", | ||
type=str, | ||
default="luodian/otter-9b-hf", | ||
help="Path or name of the model (HF format)", | ||
) | ||
parser.add_argument( | ||
"--model_version_or_tag", | ||
type=str, | ||
default="apr25_otter", | ||
help="Version or tag of the model", | ||
) | ||
parser.add_argument( | ||
"--input_file", | ||
type=str, | ||
default="evaluation/sample_questions.json", | ||
help="Path of the input file", | ||
) | ||
args = parser.parse_args() | ||
|
||
model = OtterForConditionalGeneration.from_pretrained( | ||
args.model_path_or_name, device_map="auto" | ||
) | ||
model.text_tokenizer.padding_side = "left" | ||
tokenizer = model.text_tokenizer | ||
image_processor = transformers.CLIPImageProcessor() | ||
|
||
responses = [] | ||
with open(args.input_file) as f: | ||
data = json.load(f) | ||
for item in data: | ||
print(f"Processing {item['image']} with prompt {item['prompt']}") | ||
response = get_response(item["image"], item["prompt"]) | ||
print(f"Response: {response}") | ||
responses.append( | ||
{ | ||
"image": item["image"], | ||
"prompt": item["prompt"], | ||
"response": get_response(item["image"], item["prompt"]), | ||
} | ||
) | ||
json.dump( | ||
responses, | ||
open(f"./evaluation/{args.model_version_or_tag}_outputs.json", "w"), | ||
indent=4, | ||
) | ||
|
||
generate_html( | ||
f"./evaluation/{args.model_version_or_tag}_outputs.json", | ||
args.model_version_or_tag, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
[ | ||
{ | ||
"image": "https://s2.loli.net/2023/05/17/OyhFvmaH867DwC2.jpg", | ||
"prompt": "Is this scene real?" | ||
}, | ||
{ | ||
"image": "https://s2.loli.net/2023/05/17/7OaJlqyRri2E8zY.png", | ||
"prompt": "What's written on this image?" | ||
}, | ||
{ | ||
"image": "https://s2.loli.net/2023/05/17/18egkN6aYGDAtsM.png", | ||
"prompt": "Which city is this?" | ||
} | ||
] |