Commit
Add initial version of milly eval jupyter notebook
Showing 1 changed file with 258 additions and 0 deletions.
{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluate EgoVideoBLIP2 on MILLY step detection data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "import json\n",
    "import os\n",
    "\n",
    "import imageio.v3 as iio\n",
    "import numpy as np\n",
    "from IPython.display import HTML, display\n",
    "from pytorchvideo.data.video import VideoPathHandler\n",
    "\n",
    "\n",
    "def display_gif(video_tensor, gif_file_name):\n",
    "    \"\"\"Prepares and displays a GIF from a video tensor.\n",
    "\n",
    "    The video tensor is expected to have the following shape:\n",
    "    (num_frames, num_channels, height, width).\n",
    "    \"\"\"\n",
    "    iio.imwrite(\n",
    "        gif_file_name,\n",
    "        video_tensor.permute(0, 2, 3, 1).numpy().astype(np.uint8),\n",
    "        extension=\".gif\",\n",
    "        # infinite loop\n",
    "        loop=0,\n",
    "    )\n",
    "    html = f'<img src=\"{gif_file_name}\" />'\n",
    "    display(HTML(html))\n",
    "\n",
    "\n",
    "def load_milly_video_steps(video_dir_path):\n",
" json_files = glob.glob(os.path.join(video_dir_path, \"*.json\"))\n", | ||
" assert len(json_files) == 1\n", | ||
" with open(json_files[0]) as f:\n", | ||
" annotation = json.load(f)\n", | ||
"\n", | ||
" video_path_handler = VideoPathHandler()\n", | ||
" video = video_path_handler.video_from_path(os.path.join(video_dir_path, \"pv.mp4\"))\n", | ||
"\n", | ||
" step_dict = {\n", | ||
" int(k): v for k, v in annotation[\"attribute\"][\"1\"][\"options\"].items() if v != \"\"\n", | ||
" }\n", | ||
" step_list = [step_dict[i] for i in range(len(step_dict))]\n", | ||
" return video, step_list, annotation" | ||
] | ||
}, | ||
{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Load an arbitrary pinwheel video from MILLY step detection data." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"video, step_list, annotation = load_milly_video_steps(\n", | ||
" \"../../MILLYCookbook_media_v007/z/dat/CookBook/\"\n", | ||
" \"MILLYCookbook_media_v007/A_pin/mevo/18/video-0001\"\n", | ||
")\n", | ||
"\n", | ||
"# clean up step_list\n", | ||
"# not sure what the best format would be, so let's just use the one that's\n", | ||
"# closest to training\n", | ||
"# step_list = [re.match(r'(\\d+: )?(.+)', step).group(2) for step in step_list]\n", | ||
"step_list = [\n", | ||
" \"the camera wearer starts\",\n", | ||
" \"the camera wearer places tortilla\",\n", | ||
" \"the camera wearer scoops butter and spread butter\",\n", | ||
" \"the camera wearer cleans knife\",\n", | ||
" \"the camera wearer scoops jelly and spread jelly\",\n", | ||
" \"the camera wearer cleans knife\",\n", | ||
" \"the camera wearer rolls tortilla\",\n", | ||
" \"the camera wearer inserts toothpick\",\n", | ||
" \"the camera wearer trims tortilla\",\n", | ||
" \"the camera wearer slides floss\",\n", | ||
" \"the camera wearer slices tortilla\",\n", | ||
" \"the camera wearer continues slicing\",\n", | ||
" \"the camera wearer places pinwheels\",\n", | ||
" \"the camera wearer ends\",\n", | ||
"]\n", | ||
"print(step_list)" | ||
] | ||
}, | ||
{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Load `ego-video-blip2-opt-2.7b-subsample-8`." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import torch\n", | ||
"from transformers import Blip2Processor\n", | ||
"\n", | ||
"from video_blip2.model import VideoBlip2ForConditionalGeneration\n", | ||
"\n", | ||
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", | ||
"pretrained = \"../../checkpoints/ego-video-blip2/ego-video-blip2-opt-2.7b-subsample-8\"\n", | ||
"processor = Blip2Processor.from_pretrained(pretrained)\n", | ||
"model = VideoBlip2ForConditionalGeneration.from_pretrained(pretrained).to(device)" | ||
] | ||
}, | ||
{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Perform Video QA as a sanity check." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"prompt = \"Question: what is the camera wearer doing? Answer:\"\n", | ||
"\n", | ||
"for clip_id, annotated_clip in annotation[\"metadata\"].items():\n", | ||
" start, end = annotated_clip[\"z\"]\n", | ||
" # sample a frame every 30 frames, i.e., 1 FPS\n", | ||
" # (channel, time, height, width)\n", | ||
" frames = video.get_clip(start, end)[\"video\"][:, ::30, ...]\n", | ||
" display_gif(frames.permute(1, 0, 2, 3), f\"{clip_id}.gif\")\n", | ||
" inputs = processor(\n", | ||
" images=frames.permute(1, 0, 2, 3), text=prompt, return_tensors=\"pt\"\n", | ||
" ).to(device)\n", | ||
" inputs[\"pixel_values\"] = inputs[\"pixel_values\"].permute(1, 0, 2, 3).unsqueeze(0)\n", | ||
" print(f\"inputs: {({k: v.size() for k, v in inputs.items()})}\")\n", | ||
" with torch.no_grad():\n", | ||
" generated_ids = model.generate(**inputs)\n", | ||
" generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[\n", | ||
" 0\n", | ||
" ].strip()\n", | ||
" print(f\"generated_text: {generated_text}\")" | ||
] | ||
}, | ||
{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Run evaluation." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import torch.nn.functional as F\n", | ||
"\n", | ||
"from video_blip2.data.utils import (\n", | ||
" DataCollatorForVideoSeq2Seq,\n", | ||
" generate_input_ids_and_labels,\n", | ||
")\n", | ||
"\n", | ||
"collator = DataCollatorForVideoSeq2Seq(processor.tokenizer)\n", | ||
"input_list = [\n", | ||
" generate_input_ids_and_labels(\n", | ||
" processor.tokenizer,\n", | ||
" \"Question: What is the camera wearer doing? Answer:\",\n", | ||
" text,\n", | ||
" model.config.use_decoder_only_language_model,\n", | ||
" )\n", | ||
" for text in step_list\n", | ||
"]\n", | ||
"\n", | ||
"for clip_id, annotated_clip in annotation[\"metadata\"].items():\n", | ||
" start, end = annotated_clip[\"z\"]\n", | ||
" # sample a frame every 30 frames, i.e., 1 FPS\n", | ||
" # (channel, time, height, width)\n", | ||
" clip = video.get_clip(start, end)[\"video\"][:, ::30, ...]\n", | ||
" display_gif(clip.permute(1, 0, 2, 3), f\"{clip_id}.gif\")\n", | ||
" # process the clip\n", | ||
" clip = processor.image_processor(\n", | ||
" clip.permute(1, 0, 2, 3), return_tensors=\"pt\"\n", | ||
" ).pixel_values.permute(1, 0, 2, 3)\n", | ||
" for item in input_list:\n", | ||
" item[\"pixel_values\"] = clip\n", | ||
" inputs = collator(input_list)\n", | ||
" inputs.to(device)\n", | ||
"\n", | ||
" # calculate lengths of generated texts\n", | ||
" gen_lengths = torch.sum(inputs.labels != -100, dim=-1)\n", | ||
"\n", | ||
" # ignore eos token when calculating log probs\n", | ||
" inputs[\"labels\"][inputs[\"labels\"] == processor.tokenizer.eos_token_id] = -100\n", | ||
"\n", | ||
" with torch.no_grad():\n", | ||
" output = model(**inputs)\n", | ||
" log_probs = F.cross_entropy(\n", | ||
" output.logits.flatten(end_dim=1),\n", | ||
" inputs.labels.flatten(end_dim=1),\n", | ||
" reduction=\"none\",\n", | ||
" )\n", | ||
" normalized_log_probs = -log_probs.view(len(step_list), -1).sum(dim=-1) / gen_lengths\n", | ||
" ground_truth_step = int(annotated_clip[\"av\"][\"1\"])\n", | ||
" print(f\"Ground-truth step: {ground_truth_step} - {step_list[ground_truth_step]}\")\n", | ||
" predicted_step = normalized_log_probs.argmax().item()\n", | ||
" print(f\"Predicted step: {predicted_step} - {step_list[predicted_step]}\")\n", | ||
" # with the current format for the generated text,\n", | ||
" # this doesn't seem to matter as much.\n", | ||
" # predicted_step_wout_start_end = normalized_log_probs[1:-2].argmax().item() + 1\n", | ||
" # print(f'Predicted step w/out start, end: {predicted_step_wout_start_end}'\n", | ||
" # f' - {step_list[predicted_step_wout_start_end]}')\n", | ||
" for i, log_prob in enumerate(normalized_log_probs.tolist()):\n", | ||
" print(f\"{log_prob:.2f}: {step_list[i]}\")" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "video-blip2-jEv4LXUZ-py3.10", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.10" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |