
Add initial version of MILLY eval Jupyter notebook
yukw777 committed May 9, 2023
1 parent 151cbee commit 9f4e045
Showing 1 changed file with 258 additions and 0 deletions.
258 changes: 258 additions & 0 deletions examples/notebooks/eval_ego_vid_blip2_milly.ipynb
@@ -0,0 +1,258 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Evaluate EgoVideoBLIP2 on MILLY step detection data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import glob\n",
"import json\n",
"import os\n",
"\n",
"import imageio.v3 as iio\n",
"import numpy as np\n",
"from IPython.display import HTML, display\n",
"from pytorchvideo.data.video import VideoPathHandler\n",
"\n",
"\n",
"def display_gif(video_tensor, gif_file_name):\n",
" \"\"\"Prepares and displays a GIF from a video tensor.\n",
"\n",
" The video tensor is expected to have the following shape:\n",
" (num_frames, num_channels, height, width).\n",
" \"\"\"\n",
" iio.imwrite(\n",
" gif_file_name,\n",
" video_tensor.permute(0, 2, 3, 1).numpy().astype(np.uint8),\n",
" extension=\".gif\",\n",
" # infinite loop\n",
" loop=0,\n",
" )\n",
" html = f'<img src=\"{gif_file_name}\" />'\n",
" display(HTML(html))\n",
"\n",
"\n",
"def load_milly_video_steps(video_dir_path):\n",
" json_files = glob.glob(os.path.join(video_dir_path, \"*.json\"))\n",
" assert len(json_files) == 1\n",
" with open(json_files[0]) as f:\n",
" annotation = json.load(f)\n",
"\n",
" video_path_handler = VideoPathHandler()\n",
" video = video_path_handler.video_from_path(os.path.join(video_dir_path, \"pv.mp4\"))\n",
"\n",
" step_dict = {\n",
" int(k): v for k, v in annotation[\"attribute\"][\"1\"][\"options\"].items() if v != \"\"\n",
" }\n",
" step_list = [step_dict[i] for i in range(len(step_dict))]\n",
" return video, step_list, annotation"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Load an arbitrary pinwheel video from MILLY step detection data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"video, step_list, annotation = load_milly_video_steps(\n",
" \"../../MILLYCookbook_media_v007/z/dat/CookBook/\"\n",
" \"MILLYCookbook_media_v007/A_pin/mevo/18/video-0001\"\n",
")\n",
"\n",
"# clean up step_list\n",
"# not sure what the best format would be, so let's just use the one that's\n",
"# closest to training\n",
"# step_list = [re.match(r'(\\d+: )?(.+)', step).group(2) for step in step_list]\n",
"step_list = [\n",
" \"the camera wearer starts\",\n",
" \"the camera wearer places tortilla\",\n",
" \"the camera wearer scoops butter and spread butter\",\n",
" \"the camera wearer cleans knife\",\n",
" \"the camera wearer scoops jelly and spread jelly\",\n",
" \"the camera wearer cleans knife\",\n",
" \"the camera wearer rolls tortilla\",\n",
" \"the camera wearer inserts toothpick\",\n",
" \"the camera wearer trims tortilla\",\n",
" \"the camera wearer slides floss\",\n",
" \"the camera wearer slices tortilla\",\n",
" \"the camera wearer continues slicing\",\n",
" \"the camera wearer places pinwheels\",\n",
" \"the camera wearer ends\",\n",
"]\n",
"print(step_list)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Load `ego-video-blip2-opt-2.7b-subsample-8`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from transformers import Blip2Processor\n",
"\n",
"from video_blip2.model import VideoBlip2ForConditionalGeneration\n",
"\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"pretrained = \"../../checkpoints/ego-video-blip2/ego-video-blip2-opt-2.7b-subsample-8\"\n",
"processor = Blip2Processor.from_pretrained(pretrained)\n",
"model = VideoBlip2ForConditionalGeneration.from_pretrained(pretrained).to(device)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Perform Video QA as a sanity check."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"prompt = \"Question: what is the camera wearer doing? Answer:\"\n",
"\n",
"for clip_id, annotated_clip in annotation[\"metadata\"].items():\n",
" start, end = annotated_clip[\"z\"]\n",
" # sample a frame every 30 frames, i.e., 1 FPS\n",
" # (channel, time, height, width)\n",
" frames = video.get_clip(start, end)[\"video\"][:, ::30, ...]\n",
" display_gif(frames.permute(1, 0, 2, 3), f\"{clip_id}.gif\")\n",
" inputs = processor(\n",
" images=frames.permute(1, 0, 2, 3), text=prompt, return_tensors=\"pt\"\n",
" ).to(device)\n",
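"    # reshape pixel_values from (time, channel, height, width) to\n",
"    # (batch, channel, time, height, width)\n",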
" inputs[\"pixel_values\"] = inputs[\"pixel_values\"].permute(1, 0, 2, 3).unsqueeze(0)\n",
" print(f\"inputs: {({k: v.size() for k, v in inputs.items()})}\")\n",
" with torch.no_grad():\n",
" generated_ids = model.generate(**inputs)\n",
" generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[\n",
" 0\n",
" ].strip()\n",
" print(f\"generated_text: {generated_text}\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Run evaluation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch.nn.functional as F\n",
"\n",
"from video_blip2.data.utils import (\n",
" DataCollatorForVideoSeq2Seq,\n",
" generate_input_ids_and_labels,\n",
")\n",
"\n",
"collator = DataCollatorForVideoSeq2Seq(processor.tokenizer)\n",
"input_list = [\n",
" generate_input_ids_and_labels(\n",
" processor.tokenizer,\n",
" \"Question: What is the camera wearer doing? Answer:\",\n",
" text,\n",
" model.config.use_decoder_only_language_model,\n",
" )\n",
" for text in step_list\n",
"]\n",
"\n",
"for clip_id, annotated_clip in annotation[\"metadata\"].items():\n",
" start, end = annotated_clip[\"z\"]\n",
" # sample a frame every 30 frames, i.e., 1 FPS\n",
" # (channel, time, height, width)\n",
" clip = video.get_clip(start, end)[\"video\"][:, ::30, ...]\n",
" display_gif(clip.permute(1, 0, 2, 3), f\"{clip_id}.gif\")\n",
" # process the clip\n",
" clip = processor.image_processor(\n",
" clip.permute(1, 0, 2, 3), return_tensors=\"pt\"\n",
" ).pixel_values.permute(1, 0, 2, 3)\n",
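"    # attach the same processed clip to every candidate step so the collated\n",
"    # batch scores all steps against this clip in a single forward pass\n",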
" for item in input_list:\n",
" item[\"pixel_values\"] = clip\n",
" inputs = collator(input_list)\n",
" inputs.to(device)\n",
"\n",
" # calculate lengths of generated texts\n",
" gen_lengths = torch.sum(inputs.labels != -100, dim=-1)\n",
"\n",
" # ignore eos token when calculating log probs\n",
" inputs[\"labels\"][inputs[\"labels\"] == processor.tokenizer.eos_token_id] = -100\n",
"\n",
" with torch.no_grad():\n",
" output = model(**inputs)\n",
" log_probs = F.cross_entropy(\n",
" output.logits.flatten(end_dim=1),\n",
" inputs.labels.flatten(end_dim=1),\n",
" reduction=\"none\",\n",
" )\n",
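"    # sum the per-token losses for each candidate step, negate, and divide by the\n",
"    # generated length to get a length-normalized log probability per step\n",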
" normalized_log_probs = -log_probs.view(len(step_list), -1).sum(dim=-1) / gen_lengths\n",
" ground_truth_step = int(annotated_clip[\"av\"][\"1\"])\n",
" print(f\"Ground-truth step: {ground_truth_step} - {step_list[ground_truth_step]}\")\n",
" predicted_step = normalized_log_probs.argmax().item()\n",
" print(f\"Predicted step: {predicted_step} - {step_list[predicted_step]}\")\n",
" # with the current format for the generated text,\n",
" # this doesn't seem to matter as much.\n",
" # predicted_step_wout_start_end = normalized_log_probs[1:-2].argmax().item() + 1\n",
" # print(f'Predicted step w/out start, end: {predicted_step_wout_start_end}'\n",
" # f' - {step_list[predicted_step_wout_start_end]}')\n",
" for i, log_prob in enumerate(normalized_log_probs.tolist()):\n",
" print(f\"{log_prob:.2f}: {step_list[i]}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "video-blip2-jEv4LXUZ-py3.10",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
