Commit
* add model draft
* update docstring
* add tests
* support image and video as input
* update for better handling of mixed input and clean-up a bit
* bug when mixed inputs & add tests
* Update README.md (Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>)
* Merge remote-tracking branch 'upstream/main' into video_llava
* link to abstract of paper in README
* fix test
* fix-copies
* make tests happy
* skip docstest for now
* do not run doctest for now
* Update src/transformers/models/video_llava/processing_video_llava.py (Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>)
* Update src/transformers/models/video_llava/image_processing_video_llava.py (Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>)
* Update src/transformers/models/video_llava/image_processing_video_llava.py (Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>)
* Update src/transformers/models/video_llava/image_processing_video_llava.py (Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>)
* Update src/transformers/models/video_llava/image_processing_video_llava.py (Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>)
* Update tests/models/video_llava/test_modeling_video_llava.py (Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>)
* Update src/transformers/models/video_llava/image_processing_video_llava.py (Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>)
* address review comments
* failing tests
* Fix vocab_size in common tests for VLMs
* codestyle
* Update src/transformers/models/video_llava/configuration_video_llava.py (Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>)
* Update src/transformers/models/video_llava/configuration_video_llava.py (Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>)
* Update src/transformers/models/video_llava/modeling_video_llava.py (Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>)
* Update src/transformers/models/video_llava/modeling_video_llava.py (Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>)
* Update docs/source/en/model_doc/video_llava.md (Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>)
* Update docs/source/en/model_doc/video_llava.md (Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>)
* Update src/transformers/models/video_llava/image_processing_video_llava.py (Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>)
* Update docs/source/en/model_doc/video_llava.md (Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>)
* Update src/transformers/models/video_llava/processing_video_llava.py (Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>)
* Update tests/models/video_llava/test_modeling_video_llava.py (Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>)
* Update tests/models/video_llava/test_modeling_video_llava.py (Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>)
* Update tests/models/video_llava/test_modeling_video_llava.py (Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>)
* PR suggestions
* fix-copies
* Update src/transformers/models/video_llava/configuration_video_llava.py (Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>)
* Update src/transformers/models/video_llava/configuration_video_llava.py (Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>)
* add full example in docs
* clean-up with new model-id
* [run-slow] video_llava
* update docstring
* [run-slow] video_llava
* remove all achive maps
* fix some tests
* test was supposed to be skipped for llava :)

---------

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
1 parent b8aee2e · commit bd9f4d7 · 25 changed files with 2,637 additions and 3 deletions.
docs/source/en/model_doc/video_llava.md (new file)

@@ -0,0 +1,129 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Video-LLaVA

## Overview

Video-LLaVA is an open-source multimodal LLM trained by fine-tuning LLaMA/Vicuna on multimodal instruction-following data generated by LLaVA-1.5 and VideoChat. It is an auto-regressive language model based on the transformer architecture. Video-LLaVA unifies visual representations in the language feature space, enabling an LLM to perform visual reasoning on both images and videos simultaneously.

The Video-LLaVA model was proposed in [Video-LLaVA: Learning United Visual Representation by Alignment Before Projection](https://arxiv.org/abs/2311.10122) by Bin Lin, Yang Ye, Bin Zhu, Jiaxi Cui, Munan Ning, Peng Jin, Li Yuan.
The abstract from the paper is the following:

*The Large Vision-Language Model (LVLM) has enhanced the performance of various downstream tasks in visual-language understanding. Most existing approaches encode images and videos into separate feature spaces, which are then fed as inputs to large language models. However, due to the lack of unified tokenization for images and videos, namely misalignment before projection, it becomes challenging for a Large Language Model (LLM) to learn multi-modal interactions from several poor projection layers. In this work, we unify visual representation into the language feature space to advance the foundational LLM towards a unified LVLM. As a result, we establish a simple but robust LVLM baseline, Video-LLaVA, which learns from a mixed dataset of images and videos, mutually enhancing each other. Video-LLaVA achieves superior performances on a broad range of 9 image benchmarks across 5 image question-answering datasets and 4 image benchmark toolkits. Additionally, our Video-LLaVA also outperforms Video-ChatGPT by 5.8%, 9.9%, 18.6%, and 10.1% on MSRVTT, MSVD, TGIF, and ActivityNet, respectively. Notably, extensive experiments demonstrate that Video-LLaVA mutually benefits images and videos within a unified visual representation, outperforming models designed specifically for images or videos. We aim for this work to provide modest insights into the multi-modal inputs for the LLM.*
Tips:

- We advise users to use `padding_side="left"` when computing batched generation, as it leads to more accurate results. Simply set `processor.tokenizer.padding_side = "left"` before generating; a short batched-generation sketch is given after the examples below.

- Note that the model has not been explicitly trained to process multiple images/videos in the same prompt. Although this is technically possible, you may experience inaccurate results; a mixed image-and-video sketch follows the main example below.

- For better results, we recommend prompting the model with the correct prompt format:

```python
import av
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from transformers import VideoLlavaForConditionalGeneration, VideoLlavaProcessor


def read_video_pyav(container, indices):
    '''
    Decode the video with PyAV decoder.
    Args:
        container (`av.container.input.InputContainer`): PyAV container.
        indices (`List[int]`): List of frame indices to decode.
    Returns:
        result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
    '''
    frames = []
    container.seek(0)
    start_index = indices[0]
    end_index = indices[-1]
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            break
        if i >= start_index and i in indices:
            frames.append(frame)
    return np.stack([x.to_ndarray(format="rgb24") for x in frames])


model = VideoLlavaForConditionalGeneration.from_pretrained("RaushanTurganbay/video-llava-7b-hf", device_map="auto")
processor = VideoLlavaProcessor.from_pretrained("RaushanTurganbay/video-llava-7b-hf")

video_path = hf_hub_download(repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset")

container = av.open(video_path)

# Sample 8 frames uniformly across the clip (the model expects 8 frames).
total_frames = container.streams.video[0].frames
indices = np.arange(0, total_frames, total_frames / 8).astype(int)
video = read_video_pyav(container, indices)

prompt = "USER: <video>Why is this funny? ASSISTANT:"
inputs = processor(text=prompt, videos=video, return_tensors="pt")

out = model.generate(**inputs, max_new_tokens=40)
print(processor.batch_decode(out, skip_special_tokens=True, clean_up_tokenization_spaces=True))
```
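Video-LLaVA can also take images, or a mix of images and video clips, in one call. The snippet below is a minimal sketch that reuses `model`, `processor`, and `video` from the example above; it assumes `VideoLlavaProcessor` accepts `images=` alongside `videos=` and that an image is referenced with an `<image>` token (mirroring `<video>`), so double-check against the processor docstring.

```python
# Minimal mixed image + video sketch (assumes the processor takes both `images=` and
# `videos=`, and that images are marked with an `<image>` token in the prompt).
import requests
from PIL import Image

url = "http://images.cocodataset.org/val2017/000000039769.jpg"  # any small test image works
image = Image.open(requests.get(url, stream=True).raw)

prompts = [
    "USER: <image>How many cats are there in the image? ASSISTANT:",
    "USER: <video>Why is this funny? ASSISTANT:",
]
inputs = processor(text=prompts, images=image, videos=video, padding=True, return_tensors="pt")

out = model.generate(**inputs, max_new_tokens=40)
print(processor.batch_decode(out, skip_special_tokens=True, clean_up_tokenization_spaces=True))
```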
For a multi-turn conversation, change the prompt to:

```bash
"USER: <video>What do you see in this video? ASSISTANT: A baby reading a book. USER: Why is it funny? ASSISTANT:"
```

- Note that video inputs should have exactly 8 frames, since the models were trained in that setting.
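As noted in the tips, batched generation works best with left padding. The snippet below is a short sketch of that setup, again reusing `model`, `processor`, and `video` from the example above; it assumes `videos=` accepts a list of clips and that `padding=True` is forwarded to the tokenizer.

```python
# Short batched-generation sketch with left padding (assumes `videos=` accepts a list of clips).
processor.tokenizer.padding_side = "left"

prompts = [
    "USER: <video>Why is this funny? ASSISTANT:",
    "USER: <video>Describe the video in one sentence. ASSISTANT:",
]
inputs = processor(text=prompts, videos=[video, video], padding=True, return_tensors="pt")

out = model.generate(**inputs, max_new_tokens=40)
print(processor.batch_decode(out, skip_special_tokens=True, clean_up_tokenization_spaces=True))
```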
This model was contributed by [RaushanTurganbay](https://huggingface.co/RaushanTurganbay).
The original code can be found [here](https://github.com/PKU-YuanGroup/Video-LLaVA).
## VideoLlavaConfig

[[autodoc]] VideoLlavaConfig

## VideoLlavaImageProcessor

[[autodoc]] VideoLlavaImageProcessor

## VideoLlavaProcessor

[[autodoc]] VideoLlavaProcessor

## VideoLlavaForConditionalGeneration

[[autodoc]] VideoLlavaForConditionalGeneration
    - forward
```diff
@@ -242,6 +242,7 @@
     unispeech_sat,
     univnet,
     upernet,
+    video_llava,
     videomae,
     vilt,
     vipllava,
```
src/transformers/models/video_llava/__init__.py (new file)

@@ -0,0 +1,71 @@
```python
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available


_import_structure = {
    "configuration_video_llava": ["VideoLlavaConfig"],
    "processing_video_llava": ["VideoLlavaProcessor"],
}

try:
    if not is_vision_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["image_processing_video_llava"] = ["VideoLlavaImageProcessor"]

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_video_llava"] = [
        "VideoLlavaPreTrainedModel",
        "VideoLlavaForConditionalGeneration",
    ]

if TYPE_CHECKING:
    from .configuration_video_llava import (
        VideoLlavaConfig,
    )
    from .processing_video_llava import VideoLlavaProcessor

    try:
        if not is_vision_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .image_processing_video_llava import VideoLlavaImageProcessor

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_video_llava import (
            VideoLlavaForConditionalGeneration,
            VideoLlavaPreTrainedModel,
        )

else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
```
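With this lazy-module setup, the torch- and vision-dependent submodules are only imported the first time one of their names is accessed. A minimal usage sketch, assuming the package is installed with this commit applied:

```python
# Accessing a name listed in _import_structure triggers _LazyModule to import
# the corresponding submodule on first use.
from transformers.models.video_llava import VideoLlavaConfig

config = VideoLlavaConfig()   # default configuration
print(config.model_type)      # "video_llava"
```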