This repository has been archived by the owner on Dec 1, 2023. It is now read-only.

Commit 1c60134: Change VideoBlip2 to VideoBlip
yukw777 committed May 17, 2023 · 1 parent 22dde1f
Showing 8 changed files with 23 additions and 23 deletions.
8 changes: 4 additions & 4 deletions demo/app.py
@@ -8,12 +8,12 @@
from pytorchvideo.data.video import VideoPathHandler
from transformers import Blip2Processor

-from video_blip.model import VideoBlip2ForConditionalGeneration
+from video_blip.model import VideoBlipForConditionalGeneration


@torch.no_grad()
def respond(
-model: VideoBlip2ForConditionalGeneration,
+model: VideoBlipForConditionalGeneration,
processor: Blip2Processor,
video_path_handler: VideoPathHandler,
video_path: str,
@@ -60,7 +60,7 @@ def respond(


def construct_demo(
-model: VideoBlip2ForConditionalGeneration,
+model: VideoBlipForConditionalGeneration,
processor: Blip2Processor,
video_path_handler: VideoPathHandler,
) -> gr.Blocks:
@@ -141,7 +141,7 @@ def construct_demo(
args = parser.parse_args()

processor = Blip2Processor.from_pretrained(args.model_name_or_path)
-model = VideoBlip2ForConditionalGeneration.from_pretrained(
+model = VideoBlipForConditionalGeneration.from_pretrained(
args.model_name_or_path
).to(args.device)
demo = construct_demo(model, processor, VideoPathHandler())
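For downstream users, only the import and the class name change. A minimal loading sketch under the new name, mirroring the demo code above (the checkpoint path is a placeholder, not a published model):

    import torch
    from transformers import Blip2Processor

    from video_blip.model import VideoBlipForConditionalGeneration

    device = "cuda" if torch.cuda.is_available() else "cpu"
    checkpoint = "path/to/video-blip-checkpoint"  # placeholder path
    processor = Blip2Processor.from_pretrained(checkpoint)
    model = VideoBlipForConditionalGeneration.from_pretrained(checkpoint).to(device)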
6 changes: 3 additions & 3 deletions examples/notebooks/eval_ego_vid_blip2_milly_coffee.ipynb
@@ -5,7 +5,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Evaluate EgoVideoBLIP2 on MILLY step detection data"
"# Evaluate VideoBLIP on MILLY step detection data"
]
},
{
@@ -113,12 +113,12 @@
"import torch\n",
"from transformers import Blip2Processor\n",
"\n",
"from video_blip.model import VideoBlip2ForConditionalGeneration\n",
"from video_blip.model import VideoBlipForConditionalGeneration\n",
"\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"pretrained = \"../../checkpoints/ego-video-blip2/ego-video-blip2-opt-2.7b-subsample-8\"\n",
"processor = Blip2Processor.from_pretrained(pretrained)\n",
"model = VideoBlip2ForConditionalGeneration.from_pretrained(pretrained).to(device)"
"model = VideoBlipForConditionalGeneration.from_pretrained(pretrained).to(device)"
]
},
{
6 changes: 3 additions & 3 deletions examples/notebooks/eval_ego_vid_blip2_milly_pinwheel.ipynb
@@ -5,7 +5,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Evaluate EgoVideoBLIP2 on MILLY step detection data"
"# Evaluate VideoBLIP on MILLY step detection data"
]
},
{
@@ -116,12 +116,12 @@
"import torch\n",
"from transformers import Blip2Processor\n",
"\n",
"from video_blip.model import VideoBlip2ForConditionalGeneration\n",
"from video_blip.model import VideoBlipForConditionalGeneration\n",
"\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"pretrained = \"../../checkpoints/ego-video-blip2/ego-video-blip2-opt-2.7b-subsample-8\"\n",
"processor = Blip2Processor.from_pretrained(pretrained)\n",
"model = VideoBlip2ForConditionalGeneration.from_pretrained(pretrained).to(device)"
"model = VideoBlipForConditionalGeneration.from_pretrained(pretrained).to(device)"
]
},
{
4 changes: 2 additions & 2 deletions examples/notebooks/flan_guide_generation.ipynb
@@ -48,11 +48,11 @@
"import torch\n",
"from transformers import Blip2Processor\n",
"\n",
"from video_blip.model import VideoBlip2ForConditionalGeneration\n",
"from video_blip.model import VideoBlipForConditionalGeneration\n",
"\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"processor = Blip2Processor.from_pretrained(\"Salesforce/blip2-flan-t5-xxl\")\n",
"model = VideoBlip2ForConditionalGeneration.from_pretrained(\n",
"model = VideoBlipForConditionalGeneration.from_pretrained(\n",
" \"Salesforce/blip2-flan-t5-xxl\", torch_dtype=torch.float16\n",
").to(device)\n",
"print(model.generation_config)\n",
8 changes: 4 additions & 4 deletions examples/notebooks/…
@@ -5,7 +5,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Video QA with Video Blip2"
"# Video QA with VideoBLIP"
]
},
{
@@ -62,12 +62,12 @@
"import torch\n",
"from transformers import Blip2Processor\n",
"\n",
"from video_blip.model import VideoBlip2ForConditionalGeneration\n",
"from video_blip.model import VideoBlipForConditionalGeneration\n",
"\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"pretrained = \"../../checkpoints/ego-video-blip2/ego-video-blip2-opt-2.7b-subsample-8\"\n",
"processor = Blip2Processor.from_pretrained(pretrained)\n",
"model = VideoBlip2ForConditionalGeneration.from_pretrained(pretrained).to(device)"
"model = VideoBlipForConditionalGeneration.from_pretrained(pretrained).to(device)"
]
},
{
@@ -146,7 +146,7 @@
"source": [
"pretrained = \"../../checkpoints/ego-video-blip2/ego-video-blip2-flan-t5-xl-subsample-8\"\n",
"processor = Blip2Processor.from_pretrained(pretrained)\n",
"model = VideoBlip2ForConditionalGeneration.from_pretrained(pretrained).to(device)"
"model = VideoBlipForConditionalGeneration.from_pretrained(pretrained).to(device)"
]
},
{
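A rough end-to-end sketch of the video QA flow this notebook exercises, using the renamed class. Assumptions: pixel_values carries an extra time dimension, shaped (batch, channels, time, height, width); the checkpoint path is a placeholder; the prompt mirrors the one used in scripts/ego4d/train.py below.

    import torch
    from transformers import Blip2Processor

    from video_blip.model import VideoBlipForConditionalGeneration

    device = "cuda" if torch.cuda.is_available() else "cpu"
    pretrained = "path/to/video-blip-checkpoint"  # placeholder path
    processor = Blip2Processor.from_pretrained(pretrained)
    model = VideoBlipForConditionalGeneration.from_pretrained(pretrained).to(device)

    # Random frames stand in for a real, preprocessed clip:
    # (channels, time, height, width); add a batch dimension before calling the model.
    frames = torch.rand(3, 8, 224, 224)
    prompt = "Question: What is the camera wearer doing? Answer:"
    inputs = processor(text=prompt, return_tensors="pt").to(device)
    generated_ids = model.generate(
        pixel_values=frames.unsqueeze(0).to(device),  # assumed shape (1, 3, 8, 224, 224)
        input_ids=inputs.input_ids,
        max_new_tokens=32,
    )
    print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])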
4 changes: 2 additions & 2 deletions scripts/ego4d/train.py
@@ -16,7 +16,7 @@
clean_narration_text,
generate_input_ids_and_labels,
)
-from video_blip.model import VideoBlip2ForConditionalGeneration
+from video_blip.model import VideoBlipForConditionalGeneration

PROMPT = "Question: What is the camera wearer doing? Answer:"

@@ -73,7 +73,7 @@ def train() -> None:
processor = transformers.Blip2Processor.from_pretrained(
model_args.model_name_or_path
)
-model = VideoBlip2ForConditionalGeneration.from_pretrained(
+model = VideoBlipForConditionalGeneration.from_pretrained(
model_args.model_name_or_path,
low_cpu_mem_usage=False if is_deepspeed_zero3_enabled() else True,
)
4 changes: 2 additions & 2 deletions tests/test_model.py
@@ -2,7 +2,7 @@
import torch
from transformers import Blip2VisionConfig

-from video_blip.model import VideoBlip2VisionModel
+from video_blip.model import VideoBlipVisionModel


@pytest.mark.parametrize("output_hidden_states", [True, False])
@@ -37,7 +37,7 @@ def test_video_blip_vision_model_forward(
output_attentions: bool,
output_hidden_states: bool,
) -> None:
-model = VideoBlip2VisionModel(config)
+model = VideoBlipVisionModel(config)
outputs = model(
pixel_values=torch.rand(
# channel is pretty much always 3
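The test constructs the vision model directly from a Blip2VisionConfig and feeds it random pixel values. A small standalone sketch of the same kind of call, assuming the video-augmented forward accepts pixel_values shaped (batch, channels, time, height, width); the tiny config values are illustrative and not the test's fixture:

    import torch
    from transformers import Blip2VisionConfig

    from video_blip.model import VideoBlipVisionModel

    # A deliberately small config so the sketch runs quickly; not the test fixture.
    config = Blip2VisionConfig(
        hidden_size=32,
        intermediate_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        image_size=56,
        patch_size=14,
    )
    model = VideoBlipVisionModel(config)
    # 2 clips, 3 channels, 4 frames each (assumed layout: batch, channels, time, height, width)
    pixel_values = torch.rand(2, 3, 4, config.image_size, config.image_size)
    outputs = model(pixel_values=pixel_values, return_dict=True)
    print(outputs.last_hidden_state.shape)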
6 changes: 3 additions & 3 deletions video_blip/model.py
@@ -11,7 +11,7 @@
from transformers.modeling_outputs import BaseModelOutputWithPooling


-class VideoBlip2VisionModel(Blip2VisionModel):
+class VideoBlipVisionModel(Blip2VisionModel):
"""A simple, augmented version of Blip2VisionModel to handle videos."""

def forward(
@@ -92,14 +92,14 @@ def forward(
return (last_hidden_state, pooler_output, hidden_states, attentions)


-class VideoBlip2ForConditionalGeneration(Blip2ForConditionalGeneration):
+class VideoBlipForConditionalGeneration(Blip2ForConditionalGeneration):
def __init__(self, config: Blip2Config) -> None:
# HACK: we call the grandparent super().__init__() to bypass
# Blip2ForConditionalGeneration.__init__() so we can replace
# self.vision_model
super(Blip2ForConditionalGeneration, self).__init__(config)

-self.vision_model = VideoBlip2VisionModel(config.vision_config)
+self.vision_model = VideoBlipVisionModel(config.vision_config)

self.query_tokens = nn.Parameter(
torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)
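The HACK above relies on a standard Python mechanism: calling super() with an explicit class skips that class in the method resolution order, so super(Blip2ForConditionalGeneration, self).__init__(config) runs the grandparent's __init__ without executing Blip2ForConditionalGeneration.__init__ (which would have built the image-only vision model that this subclass replaces). A self-contained illustration of the pattern, with hypothetical class names:

    class Base:
        def __init__(self) -> None:
            self.built_by = "Base"


    class Parent(Base):
        def __init__(self) -> None:
            super().__init__()
            self.built_by = "Parent"  # the initialization step we want to skip


    class Child(Parent):
        def __init__(self) -> None:
            # Skip Parent.__init__ and run Base.__init__ directly, analogous to
            # super(Blip2ForConditionalGeneration, self).__init__(config) above.
            super(Parent, self).__init__()


    print(Child().built_by)  # prints "Base"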
