benchmark (#113)

* add benchmark code

* add output format control

* add ignore

* add gitignore

* rm unused files

* lint file

* delete unused files

---------

Co-authored-by: Fanyi Pu <FPU001@e.ntu.edu.sg>
Luodian and pufanyi authored May 18, 2023
1 parent fc7d83d commit 49fe2e2
Showing 3 changed files with 215 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -186,6 +186,8 @@ ofa_compress/*
xformers_model/*
train_*.sh
gpt_playground/*
evaluation/*.json
evaluation/*.html

#yuanhan
tools
199 changes: 199 additions & 0 deletions evaluation/benchmark.py
@@ -0,0 +1,199 @@
import requests
import torch
import transformers
import json
from PIL import Image
from otter.modeling_otter import OtterForConditionalGeneration
import argparse


def get_image(url: str) -> Image.Image:
"""
Get image from url
Args:
url (str): url of the image
Returns:
Image.Image: PIL Image
"""
return Image.open(requests.get(url, stream=True).raw)


def get_formatted_prompt(prompt: str) -> str:
"""
Format prompt for GPT
Args:
prompt (str): prompt to be formatted
Returns:
str: formatted prompt
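    Example:
        get_formatted_prompt("Which city is this?") returns
        "<image> User: Which city is this? GPT: <answer>"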
"""
return f"<image> User: {prompt} GPT: <answer>"


def get_response(url: str, prompt: str) -> str:
"""
Get the response of single image and prompt from the model
Args:
url (str): url of the image
prompt (str): the prompt (no need to be formatted)
Returns:
str: response of the model
"""
query_image = get_image(url)
    vision_x = (
        image_processor.preprocess([query_image], return_tensors="pt")["pixel_values"]
        .unsqueeze(1)
        .unsqueeze(0)
    )  # reshape to (batch, num_media, num_frames, C, H, W)
lang_x = model.text_tokenizer(
[
get_formatted_prompt(prompt),
],
return_tensors="pt",
)
    generated_text = model.generate(
        vision_x=vision_x.to(model.device),
        lang_x=lang_x["input_ids"].to(model.device),
        attention_mask=lang_x["attention_mask"].to(model.device),
        max_new_tokens=256,
        num_beams=3,  # beam search over 3 hypotheses
        no_repeat_ngram_size=3,  # block verbatim 3-gram repetition
    )
    # Keep only the text between <answer> and <|endofchunk|>,
    # then drop surrounding whitespace and quotes.
    parsed_output = (
        model.text_tokenizer.decode(generated_text[0])
        .split("<answer>")[1]
        .split("<|endofchunk|>")[0]
        .strip()
        .strip('"')
    )
return parsed_output


def generate_html(output_file, model_version_or_tag):
# Load the data from the JSON file
with open(output_file, "r") as f:
data = json.load(f)

# Start the HTML file
html = """
<!DOCTYPE html>
<html>
<head>
<title>Benchmarking various ver. of Otter</title>
<style>
.column {{
float: left;
width: 33.33%;
padding: 5px;
}}
.row::after {{
content: "";
clear: both;
display: table;
}}
img {{
width: 100%;
height: auto;
}}
</style>
</head>
<body>
<h1>{}</h1>
"""

    # Doubled braces in the CSS survive str.format; the single {} above
    # receives the page title.
    html = html.format(model_version_or_tag)

# Add the data to the HTML
for item in data:
html += """
<div class="row">
<div class="column">
<img src="{image}" alt="Image">
</div>
<div class="column">
{prompt}
</div>
<div class="column">
{response}
</div>
</div>
""".format(
**item
)

# Close the HTML tags
html += """
</body>
</html>
"""

# Write the HTML string to a file
output_html_path = output_file.replace(".json", ".html")
with open(output_html_path, "w") as f:
f.write(html)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_path_or_name",
type=str,
default="luodian/otter-9b-hf",
help="Path or name of the model (HF format)",
)
parser.add_argument(
"--model_version_or_tag",
type=str,
default="apr25_otter",
help="Version or tag of the model",
)
parser.add_argument(
"--input_file",
type=str,
default="evaluation/sample_questions.json",
help="Path of the input file",
)
args = parser.parse_args()

model = OtterForConditionalGeneration.from_pretrained(
args.model_path_or_name, device_map="auto"
)
model.text_tokenizer.padding_side = "left"
image_processor = transformers.CLIPImageProcessor()

responses = []
with open(args.input_file) as f:
data = json.load(f)
    for item in data:
        print(f"Processing {item['image']} with prompt {item['prompt']}")
        response = get_response(item["image"], item["prompt"])
        print(f"Response: {response}")
        responses.append(
            {
                "image": item["image"],
                "prompt": item["prompt"],
                # reuse the response computed above instead of running
                # generation a second time per item
                "response": response,
            }
        )
    output_json_path = f"./evaluation/{args.model_version_or_tag}_outputs.json"
    with open(output_json_path, "w") as f:
        json.dump(responses, f, indent=4)

    generate_html(output_json_path, args.model_version_or_tag)
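
For reference, the script writes ./evaluation/<model_version_or_tag>_outputs.json as a list that mirrors the input items with a "response" field added, then renders the same data as a three-column (image, prompt, response) HTML page alongside it. A hypothetical output entry (the response text depends on the checkpoint):

[
    {
        "image": "https://s2.loli.net/2023/05/17/18egkN6aYGDAtsM.png",
        "prompt": "Which city is this?",
        "response": "(model answer here)"
    }
]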
14 changes: 14 additions & 0 deletions evaluation/sample_questions.json
@@ -0,0 +1,14 @@
[
{
"image": "https://s2.loli.net/2023/05/17/OyhFvmaH867DwC2.jpg",
"prompt": "Is this scene real?"
},
{
"image": "https://s2.loli.net/2023/05/17/7OaJlqyRri2E8zY.png",
"prompt": "What's written on this image?"
},
{
"image": "https://s2.loli.net/2023/05/17/18egkN6aYGDAtsM.png",
"prompt": "Which city is this?"
}
]
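
New benchmark questions can be appended as objects with the same two keys. A minimal sanity check for an edited file (a sketch; the path is the script's default input):

import json

with open("evaluation/sample_questions.json") as f:
    questions = json.load(f)

# benchmark.py reads exactly these two keys from every item
assert all({"image", "prompt"} <= set(item) for item in questions)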
