Video-LLaVa now available in the Transformers library! #156

Open
zucchini-nlp opened this issue May 15, 2024 · 55 comments

Comments

@zucchini-nlp

zucchini-nlp commented May 15, 2024

Hey!

Video-LLaVa is now available in the Transformers library! Feel free to check it out here. Thanks to @LinB203 for helping to ship the model 🤗

To get the model, update transformers by running: !pip install --upgrade git+https://github.com/huggingface/transformers.git. Inference with videos can be done as follows:

import av
import numpy as np
from transformers import VideoLlavaProcessor, VideoLlavaForConditionalGeneration

def read_video_pyav(container, indices):
    frames = []
    container.seek(0)
    start_index = indices[0]
    end_index = indices[-1]
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            break
        if i >= start_index and i in indices:
            frames.append(frame)
    return np.stack([x.to_ndarray(format="rgb24") for x in frames])


model = VideoLlavaForConditionalGeneration.from_pretrained("LanguageBind/Video-LLaVA-7B-hf")
processor = VideoLlavaProcessor.from_pretrained("LanguageBind/Video-LLaVA-7B-hf")

prompt = "USER: <video>Why is this video funny? ASSISTANT:"
video_path = "YOUR-LOCAL-VIDEO-PATH"
container = av.open(video_path)

# sample uniformly 8 frames from the video
total_frames = container.streams.video[0].frames
indices = np.arange(0, total_frames, total_frames / 8).astype(int)
clip = read_video_pyav(container, indices)

inputs = processor(text=prompt, videos=clip, return_tensors="pt")

# Generate
generate_ids = model.generate(**inputs, max_length=80)
print(processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0])
>>> 'USER:  Why is this video funny? ASSISTANT: The video is funny because the baby is sitting on the bed and reading a book, which is an unusual and amusing sight.'


@LinB203
Member

LinB203 commented May 15, 2024

It's a great feat. Thank you for your generous help!

@rhelck

rhelck commented May 16, 2024

@zucchini-nlp I'm seeing the following problem

File "/home/rhelck/videotest.py", line 3, in <module>
from transformers import VideoLlavaProcessor, VideoLlavaForConditionalGeneration
ImportError: cannot import name 'VideoLlavaProcessor' from 'transformers' (/home/rhelck/videovenv/lib/python3.10/site-packages/transformers/__init__.py)

The older example works fine for me, though. I reinstalled transformers in a new venv for this, by the way.

@zucchini-nlp
Author

@rhelck hey! Did you install transformers from main as follows? Video-LLaVa will be included in the next release, which I believe will be in a few days. For now you can get it from main 🤗

!pip install --upgrade git+https://github.com/huggingface/transformers.git

@darshana1406

@zucchini-nlp I want to distribute the model across multiple GPUs, but I get the following error:

raise ValueError(
ValueError: VideoLlavaForConditionalGeneration does not support device_map='auto'. To implement support, the model class needs to implement the _no_split_modules attribute.

@zucchini-nlp
Author

@darshana1406 could you open this as an issue in transformers and tag me there? I will add "device_map" support roughly by the end of this week.

You are also welcome to open a PR if you're willing; we are always happy to see community contributions 🤗
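
For reference, once _no_split_modules support lands, multi-GPU loading should just be the standard from_pretrained pattern. A minimal sketch, assuming an updated transformers install:

import torch
from transformers import VideoLlavaForConditionalGeneration, VideoLlavaProcessor

# sketch only: needs a transformers version where VideoLlava defines _no_split_modules
model = VideoLlavaForConditionalGeneration.from_pretrained(
    "LanguageBind/Video-LLaVA-7B-hf",
    torch_dtype=torch.float16,  # half precision so the shards fit on smaller cards
    device_map="auto",          # let accelerate place the layers on the available GPUs
)
processor = VideoLlavaProcessor.from_pretrained("LanguageBind/Video-LLaVA-7B-hf")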

@rhelck

rhelck commented May 17, 2024

@zucchini-nlp That worked perfectly, thanks!

@IsabelJimenez99

Can it also be used with images as before or only for videos?

@zucchini-nlp
Author

@IsabelJimenez99, yes, the model can be used with images, videos, or a mix of images and videos. Check out the Colab notebook for inference examples with different input modalities.

@IsabelJimenez99

Ah, ok. Sorry, I hadn't seen the Colab. Thank you very much and excellent work. Congratulations!

@BalloutAI

Can we use this library for fine-tuning as well or only for inference? If we can, is there documentation on how to use it properly?
thanks!

@zucchini-nlp
Author

@BalloutAI Yes, we can. I am preparing a tutorial notebook for fine-tuning and will add it here when it's done.

@BalloutAI

Thank you so much! Any expected timeline for that?

@zucchini-nlp
Author

@BalloutAI I made a short notebook for finetuning on a small dataset; you can find it here.

@IsabelJimenez99

IsabelJimenez99 commented May 24, 2024

I am testing the model 'LanguageBind/Video-LLaVA-7B-hf', and every time I run it on an image I get a different answer. I would also like to know how much confidence the model has in each response; is there a way to get that?

@zucchini-nlp
Author

@IsabelJimenez99 You mean the model gives a different generation every time, even if you keep the same image and prompt? That shouldn't be the case; can you share minimal reproducible code?

Regarding the model's confidence in each response, have a look at this thread which shows how to get probability of each generated token :)

@IsabelJimenez99

Yes, it's the same image and same prompt but different answers. The code I used is the same as the one shown in your Colab.

This is the code:

import torch
from transformers import VideoLlavaProcessor, VideoLlavaForConditionalGeneration, BitsAndBytesConfig
import requests
from PIL import Image

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16
)

url = "../frames_testeo/00006.jpg"
image = Image.open(url)

model_id = "LanguageBind/Video-LLaVA-7B-hf"
processor = VideoLlavaProcessor.from_pretrained(model_id)
model = VideoLlavaForConditionalGeneration.from_pretrained(model_id, quantization_config=quantization_config)

# This time we will use a special "<image>" token instead of "<video>"
prompt = "USER: <image>\nWhich types of physical contact between people do you see in this image? Select all that you see from the following list: hand-hand, hand-shoulder, hand-elbow, hand-torso, elbow-shoulder, shoulder-shoulder, or none if there is no contact. Note: physical contact means that the mentioned body parts of different people are directly touching each other, not objects. ASSISTANT:"
inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)

# Generate
generate_kwargs = {"max_new_tokens":100, "do_sample":True, "top_p":0.9, "top_k":2}
generate_ids = model.generate(**inputs, **generate_kwargs)
generated_text = processor.batch_decode(generate_ids, skip_special_tokens=True)

print(generated_text[0])

On the other hand, I looked up what is happening to me and found the following suggestion:
outputs = model.generate(inputs, max_new_tokens=5, return_dict_in_generate=True, output_scores=True)
transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, normalize_logits=True)

However, when I apply that to my code, I get the following error:
generate_kwargs = {"max_new_tokens":100, "do_sample":True, "top_p":0.9, "top_k":2}
outputs = model.generate(**inputs, **generate_kwargs, output_scores=True)
transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, normalize_logits=True )

AttributeError: 'Tensor' object has no attribute 'sequences'

@zucchini-nlp
Author

@IsabelJimenez99 Ah I see now, the different outputs each time are expected in this case because you have set do_sample=True, which samples the next token randomly from the logits distribution instead of taking the most likely token. To get a deterministic output, please use generate_kwargs = {"max_new_tokens":100} only.

And for the second issue, you need to set "return_dict_in_generate=True, output_scores=True" in the generate kwargs to get scores in the output. Otherwise we only return the generated text. For more details of which arguments you can pass in kwargs and what they mean, see the docs 🤗
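
Putting both points together, here is a minimal sketch of getting per-token probabilities with the processor, model and inputs from your snippet above (greedy decoding, so the output is also deterministic):

outputs = model.generate(
    **inputs,
    max_new_tokens=100,            # no do_sample/top_p/top_k -> greedy and deterministic
    return_dict_in_generate=True,  # makes generate return an object with .sequences and .scores
    output_scores=True,
)
transition_scores = model.compute_transition_scores(
    outputs.sequences, outputs.scores, normalize_logits=True
)
token_probs = transition_scores.exp()  # log-probabilities -> probability of each generated token
print(processor.batch_decode(outputs.sequences, skip_special_tokens=True)[0])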

@IsabelJimenez99

Oh! I understand now, thank you very much! And sorry for the inconvenience

@orrzohar

@zucchini-nlp
Does this support batch inferencing for faster evaluations?

@zucchini-nlp
Author

@orrzohar yes, the model supports batching. For that you just have to pass the prompts as a list of strings, along with a list of visuals. You can also batch different visual modalities: for example, one prompt with only an image and another with only a video.

prompts = ["<video>USER: What do you see in the video? ASSISTANT:", "<image>USER: What do you see in the image? ASSISTANT:", "<video>USER: more video instructions..."]
inputs = processor(text=prompts, images=image, videos=[clip, clip_2], padding=True, return_tensors="pt")

@n2nco

n2nco commented May 26, 2024

clip = read_video_pyav(container, indices)
prompts = ["<video>USER: What do you see in the video? ASSISTANT:", "<video>USER: Describe the man in this video's clothing ASSISTANT:"]
inputs = processor(text=prompts, videos=[clip, clip], return_tensors="pt", padding=True, truncation=True)

How might one most efficiently batch multiple prompts with 1 single clip/video?

e.g. to achieve batched prompts applied to 1 single video

Passing in videos=[clip, clip] seems to ~double the inference time


btw in case it helps anyone reading:

i had to add padding & truncation args
inputs = processor(text=prompts, videos=[clip, clip2], return_tensors="pt", padding=True, truncation=True)

@zucchini-nlp
Author

@n2nco in that case you have to pass the clip multiple times, since you have two separate prompts, each with a special "video" token. Transformers cannot align one clip to several prompts, as we don't know for sure whether that was intentional or a mistake in the code, so the safe way is to pass in as many clips as there are special "video" tokens :)

@WeizhenWang-1210

Just a side note: could you move the fine-tuning notebook to the main page Markdown? It'll be much easier to spot. Much appreciated!

@zucchini-nlp
Author

zucchini-nlp commented May 27, 2024

@WeizhenWang-1210 hey! We don't usually add these notebooks in Transformers docs, but you can find this one and many more in our tutorials repo 🤗

@BalloutAI

Hey, thanks for the awesome work.
I am trying to use it almost exactly as you do, but for some reason I am getting 100% accuracy even before training (in the sanity check I increased it to 20), which is impossible because I checked your demo and the performance was really bad before training. I was wondering if I am doing something wrong in my data handling:
def read_video_pyav(video_path, start, end):
    """Reads a video for given start-end timestamps interval and uniformly samples 8 frames of it"""
    container = av.open(video_path)
    video = container.streams.get(0)[0]

    av_timestamps = [
        int(packet.pts * video.time_base) for packet in container.demux(video) if packet.pts is not None
    ]

    av_timestamps.sort()
    start_id = bisect.bisect_left(av_timestamps, start)
    end_id = bisect.bisect_left(av_timestamps, end)

    # in case it is a very short video, lets take a longer duration and sample
    if end_id - start_id < 10:
        end_id += 10
        start_id -= 10

    end_id = min(len(av_timestamps) - 1, end_id)
    start_id = max(1, start_id)
    indices = np.linspace(start_id, end_id, 8).astype(int)

    frames = []
    container.seek(0)
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_id:
            break
        if i >= start_id and i in indices:
            frames.append(frame)
    assert len(frames) == 8, f"Got {len(frames)} frames but should be 8. Check the indices: {indices};, start_id: {start_id}, end_id: {end_id}. Len of video is {len(av_timestamps)} frames."
    return np.stack([x.to_ndarray(format="rgb24") for x in frames])

def collate_read_video(example, path):
    clip = read_video_pyav(example["video"], example.get("start", 1), example.get("end", 1e+10))
    example["clip"] = clip
    return example

def load_videos_from_directory(directory):
    data = {"video": [], "label": []}
    for label in ["True", "False"]:
        folder = os.path.join(directory, label)
        for filename in os.listdir(folder):
            if filename.endswith(".mp4"):
                data["video"].append(os.path.join(folder, filename))
                data["label"].append(1 if label == "True" else 0)
    return data

data = load_videos_from_directory("/mypath")
hf_dataset = HFDataset.from_dict(data)
dataset = hf_dataset.train_test_split(test_size=0.2)

dataset = dataset.map(collate_read_video, batched=False, fn_kwargs={"path": ""}, writer_batch_size=100)

processor = AutoProcessor.from_pretrained(MODEL_ID)
processor.tokenizer.padding_side = "right"  # during training, one always uses padding on the right

class VideoLlavaDataset(Dataset):
    """
    PyTorch Dataset for VideoLlavaDataset. This class takes a HuggingFace Dataset as input.
    """

    def __init__(
        self,
        dataset: HFDataset,
    ):
        super().__init__()
        self.dataset = dataset

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int):
        sample = self.dataset[idx]
        clip = np.array(sample["clip"])

        label = sample["label"]
        label_text = "True" if label == 1 else "False"
        mult_choice = "True or False"

        prompt = f"USER: <video>\nAnswer the following question based on the video by {mult_choice}. " \
                 f"ASSISTANT: Answer: {label_text}"

        return prompt, clip

def train_collate_fn(examples):
    texts, videos = list(zip(*examples))

    batch = processor(text=texts, videos=videos, padding=True, truncation=True, max_length=MAX_LENGTH, return_tensors="pt")

    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100
    batch["labels"] = labels

    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]
    pixel_values_videos = batch["pixel_values_videos"]
    labels = batch["labels"]

    return input_ids, attention_mask, pixel_values_videos, labels

def eval_collate_fn(examples):
    texts, videos = list(zip(*examples))

    batch = processor(text=texts, videos=videos, padding=True, truncation=True, max_length=MAX_LENGTH, return_tensors="pt")

    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]
    pixel_values_videos = batch["pixel_values_videos"]
    answer_choice = [text.split("Answer: ")[-1] for text in texts]  # Extract the answer text
    return input_ids, attention_mask, pixel_values_videos, answer_choice

train_dataset = VideoLlavaDataset(dataset["train"])
eval_dataset = VideoLlavaDataset(dataset["test"])

class VideoLlavaModelPLModule(L.LightningModule):
    def __init__(self, config, processor, model):
        super().__init__()
        self.config = config
        self.processor = processor
        self.model = model

        self.batch_size = config.get("batch_size")
        self.predictions = []
        self.answers = []

    def training_step(self, batch, batch_idx):
        input_ids, attention_mask, pixel_values_videos, labels = batch

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values_videos=pixel_values_videos,
            labels=labels
        )
        loss = outputs.loss

        self.log("train_loss", loss)

        return loss

    def validation_step(self, batch, batch_idx, dataset_idx=0):
        input_ids, attention_mask, pixel_values_videos, answers = batch

        # Autoregressively generate token IDs
        generated_ids = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values_videos=pixel_values_videos,
            max_new_tokens=MAX_LENGTH,
            do_sample=False,
        )

        # Decode the generated token IDs into text, chopping off the prompt
        decoded_predictions = self.processor.batch_decode(generated_ids, skip_special_tokens=True)

        # Extract the word after "Answer: "
        predictions = []
        for pred in decoded_predictions:
            if "Answer:" in pred:
                answer_part = pred.split("Answer:")[-1].strip()
                predictions.append(answer_part.split()[0])  # Get the first word after "Answer:"
            else:
                predictions.append("")  # Handle cases where "Answer:" is not found

        correct = 0
        for pred, answer in zip(predictions, answers):
            normalized_pred = pred.strip().lower()
            print(normalized_pred)
            normalized_answer = answer.strip().lower()
            print(normalized_answer)
            correct += (normalized_pred == normalized_answer)

        accuracy = correct / len(answers)

        # Store the predictions and answers for epoch-end processing
        self.predictions.extend(predictions)
        self.answers.extend(answers)

        return correct

    def on_validation_epoch_end(self):
        correct = sum([pred.strip().lower() == ans.strip().lower() for pred, ans in zip(self.predictions, self.answers)])
        accuracy = correct / len(self.answers)
        print(len(self.answers))

        print(f"on_Validation Accuracy: {accuracy * 100:.2f}%")

@n2nco

n2nco commented May 27, 2024

Do you think it'd be straightforward to swap the vicuna-7b base for a llama-3-8b one? e.g. https://huggingface.co/lmms-lab/llama3-llava-next-8b

@zucchini-nlp
Author

@BalloutAI , I am not sure where the "question" you're referring to is in the prompt, and it's weird that the model is getting 100%. Did you try verifying that the validation dataloader is correct (shapes and content) and turning on verbose mode to print the predictions/answers?

@n2nco yes, swapping the backbone LLM should be easy by tweaking the model's config, but the new model would require training. AFAIK the LLaVA-NeXT you're pointing to can do video inference even if it wasn't trained for that. We're working on adding those in transformers 😄
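
To illustrate the config tweaking, here is a rough sketch of swapping the text backbone; the Llama-3 checkpoint name is just an example, and the resulting model is randomly initialized, so it would need full training plus new special-token handling:

from transformers import AutoConfig, VideoLlavaConfig, VideoLlavaForConditionalGeneration

config = VideoLlavaConfig.from_pretrained("LanguageBind/Video-LLaVA-7B-hf")
config.text_config = AutoConfig.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
# the <image>/<video> tokens would have to be added to the new tokenizer, and
# config.image_token_index / config.video_token_index updated to match it
model = VideoLlavaForConditionalGeneration(config)  # built from a config -> random weights everywhere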

@BalloutAI

Yeah, I have tried printing, and it is getting them correctly: ['USER: \nAnswer the following question based on the video by True or False. ASSISTANT: Answer: True']. And it is answering them correctly no matter what the question is, for some reason. My guess was that I am feeding the answers to the model directly somehow, but I can't find the problem, because I am getting my answer from the decoded_predictions.

@zucchini-nlp
Author

@BalloutAI Ah, sorry, you're right! I didn't see that you had a different collate_fn. In eval_collate_fn, when you feed the text to the processor, you have to get rid of the answer first:

texts = [text.split("Answer: ")[0] for text in texts]  # keep the prompt without the answer
batch = processor(text=texts, videos=videos, padding=True, truncation=True, max_length=MAX_LENGTH, return_tensors="pt")

@BalloutAI

Awesome, thx! I expected that!

@caichuang0415

Thanks for your contribution. But I came across a bug: ValueError: Video pixel values should have exactly 8 frames but foung 24. I tried to sample 24 frames from a video and it raised this error. Does it only support sampling 8 frames from a video? How can we pass more frames, or even the whole video?

@zucchini-nlp
Author

@caichuang0415 hey! Yes, since VideoLlava was trained with 8 frames, we currently support only 8-frame videos. You can open a PR if you want to give it a try; otherwise I'll take a look at it next week :)

@zucchini-nlp
Author

@caichuang0415 now Video-LLaVa can work with any number of input frames, but note that inference with more than 8 frames degrades quality, as the model wasn't trained in that setting. I recommend tuning with 24 frames first if you want good performance.

To get the updated version, please update transformers with:
!pip install --upgrade git+https://github.com/huggingface/transformers.git
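
For example, a minimal sketch of sampling 24 frames with the read_video_pyav helper from the first post (video_path and prompt as defined there):

container = av.open(video_path)
total_frames = container.streams.video[0].frames
indices = np.linspace(0, total_frames - 1, 24).astype(int)  # 24 uniformly spaced frame indices
clip_24 = read_video_pyav(container, indices)

# note: as mentioned above, expect degraded quality unless the model is tuned on 24 frames
inputs = processor(text=prompt, videos=clip_24, return_tensors="pt")
generate_ids = model.generate(**inputs, max_new_tokens=80)
print(processor.batch_decode(generate_ids, skip_special_tokens=True)[0])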

@caichuang0415

Thanks for the update! I will take your advice and run more experiments.

@zucchini-nlp
Author

@s-s-la which notebook are you using? The one I linked above leads to VideoLlava and works in 4.42.

The error message mentions another model, which I'll merge into transformers on Monday and post about in the LlavaNext repo ;)

@sherlock666

sherlock666 commented Jun 27, 2024

@zucchini-nlp
How can I finetune with more sampled frames? The comment in the Video-LLaVA finetuning notebook says:

#We sample 8 frames for tuning following the original paper
#But we can increase the number of frames for longer videos and check out if it helps performance
#Change the below "8" to any number of frames you want, and note that more frames -> more computational resources needed

indices = np.linspace(start_id, end_id, 8).astype(int)
However, after I set it to 30 and finetune, it shows:

Traceback (most recent call last):
File "videollava_finetune_original_100.py", line 505, in
trainer.fit(model_module)
File "/usr/local/lib/python3.8/dist-packages/lightning/pytorch/trainer/trainer.py", line 544, in fit
call._call_and_handle_interrupt(
File "/usr/local/lib/python3.8/dist-packages/lightning/pytorch/trainer/call.py", line 44, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/lightning/pytorch/trainer/trainer.py", line 580, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/usr/local/lib/python3.8/dist-packages/lightning/pytorch/trainer/trainer.py", line 987, in _run
results = self._run_stage()
File "/usr/local/lib/python3.8/dist-packages/lightning/pytorch/trainer/trainer.py", line 1031, in _run_stage
self._run_sanity_check()
File "/usr/local/lib/python3.8/dist-packages/lightning/pytorch/trainer/trainer.py", line 1060, in _run_sanity_check
val_loop.run()
File "/usr/local/lib/python3.8/dist-packages/lightning/pytorch/loops/utilities.py", line 182, in _decorator
return loop_run(self, *args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/lightning/pytorch/loops/evaluation_loop.py", line 135, in run
self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
File "/usr/local/lib/python3.8/dist-packages/lightning/pytorch/loops/evaluation_loop.py", line 396, in _evaluation_step
output = call._call_strategy_hook(trainer, hook_name, *step_args)
File "/usr/local/lib/python3.8/dist-packages/lightning/pytorch/trainer/call.py", line 309, in _call_strategy_hook
output = fn(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/lightning/pytorch/strategies/strategy.py", line 412, in validation_step
return self.lightning_module.validation_step(*args, **kwargs)
File "videollava_finetune_original_100.py", line 435, in validation_step
generated_ids = self.model.generate(
File "/usr/local/lib/python3.8/dist-packages/peft/peft_model.py", line 647, in generate
return self.get_base_model().generate(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/transformers/generation/utils.py", line 1758, in generate
result = self._sample(
File "/usr/local/lib/python3.8/dist-packages/transformers/generation/utils.py", line 2397, in _sample
outputs = self(
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/accelerate/hooks.py", line 166, in new_forward
output = module._old_forward(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/transformers/models/video_llava/modeling_video_llava.py", line 513, in forward
image_outputs, video_outputs = self._get_vision_features(
File "/usr/local/lib/python3.8/dist-packages/transformers/models/video_llava/modeling_video_llava.py", line 377, in _get_vision_features
raise ValueError(f"Video pixel values should have exactly 8 frames but foung {num_frames}")
ValueError: Video pixel values should have exactly 8 frames but foung 30

Does it mean that if I really want to change 8 to 30, I need to fully train the model again? If so, I suggest removing that comment, since it's confusing.

Also, another question: if I set more than about 50 frames, it causes this error:

OverflowError: There was an overflow with type <class 'list'>. Try to reduce writer_batch_size to have batches smaller than 2GB.
(offset overflow while concatenating arrays)

How can I solve this if I really want to use that many frames?

thanks!!!

@zucchini-nlp
Author

@sherlock666 can you update your transformers version and install from main with !pip install --upgrade git+https://github.com/huggingface/transformers.git ?

@sherlock666

Thanks for the reply.

So do you mean that the latest transformers actually won't cause those two problems?

@zucchini-nlp
Author

It will solve the first problem. The second can be solved by decreasing writer_batch_size, as the error message says; the default is 1000 AFAIK.

The issue is that when you sample more frames, and your videos are high resolution, you end up with memory-consuming batches. I had a similar problem with another model (at 8 frames). You can also consider doing the collate and "read_video" in one dataset.map() so that we don't have to "write" the unprocessed video; in that case each video will have a fixed 336x336 size, which will lower your memory consumption per batch.

Hope it's clear :)
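
A rough sketch of that suggestion, reusing the names from the finetuning snippet earlier in the thread (read_video_pyav, processor, dataset); calling the image processor directly on the clip is an assumption on my side, and the writer_batch_size value is just an example:

def read_and_process(example):
    # decode and preprocess in one pass, so only fixed-size 336x336 pixel values
    # get written to the cache instead of raw full-resolution frames
    clip = read_video_pyav(example["video"], example.get("start", 1), example.get("end", 1e+10))
    processed = processor.image_processor(videos=clip, return_tensors="np")
    example["pixel_values_videos"] = processed["pixel_values_videos"][0]
    return example

dataset = dataset.map(read_and_process, batched=False, writer_batch_size=50)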

@sherlock666

sherlock666 commented Jun 27, 2024

I just checked: I'm using Docker with the latest transformers release (4.41.2). Or should I change to a certain older version? (I'm using the finetuning code.)

Thanks for the help.

Update 1:
Kindly note that the error leads to here, where the "8" is hard-coded:

https://github.com/huggingface/transformers/blob/dc76e9fa7f0d19ff7cfc33bd3a22acd7df167fce/src/transformers/models/video_llava/modeling_video_llava.py#L377

@zucchini-nlp
Author

Sorry if I wasn't clear; I meant updating to the version from main, since the latest release is planned for today and is not out yet. The CLI command above should update your install to the main branch :)

@FangXinyu-0913

FangXinyu-0913 commented Jul 17, 2024

Sorry for disturbing you @zucchini-nlp. When I run inference with the script you provided at the top of this issue, the special character '.Ъ' appears for some of the questions in MMBench-Video.
Before, when I used the MMBench-Video eval script, this special character didn't appear. Is it related to the transformers version?
My Transformers Version: 4.42.4
prompt = "USER: <video>Which browser plugin does the blogger recommend? ASSISTANT:"
My generated answer: The blogger recommends the Chrome browser plugin called "LastPass" for password management.Ъ
The video was attached at the end.
Thanks for your help!

lKNB3ZeTYiI_processed.mp4

@zucchini-nlp
Author

@FangXinyu-0913 yes, I had the same problem, and the prompt format should be "USER:

Also, we are starting to support chat templates for these cases, so we can avoid such errors. @LinB203 can you merge my PR on the hub when you have time? :)

@FangXinyu-0913

This prompt format doesn't seem to render correctly due to some mistake; can you post it in code format (using backticks to show the prompt format)?

@zucchini-nlp
Author

@FangXinyu-0913 sorry, forgot that GH doesn't like <> 😄

prompt = "USER: <video>\n{PROMPT} ASSISTANT:"
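
For reference, once the chat template is merged on the hub repo (and with a recent enough transformers), something like the sketch below should build that prompt string for you; treat the exact call as an assumption rather than a guaranteed API in 4.42:

conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Which browser plugin does the blogger recommend?"},
            {"type": "video"},
        ],
    },
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
inputs = processor(text=prompt, videos=clip, return_tensors="pt")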

@greeksharifa

greeksharifa commented Jul 31, 2024

Hello, thanks for your hard work!
I want to use the Video-LLaVA model through the transformers library, but I encountered an error.
My transformers version is 4.42, because 4.41 and 4.43 raise a _reorder_cache error.
The same error occurs when installed with pip install -e .

My code is like this:

from transformers import VideoLlavaProcessor, VideoLlavaForConditionalGeneration

model_name = "LanguageBind/Video-LLaVA-7B-hf"
self.processor = VideoLlavaProcessor.from_pretrained(model_name)
self.model = VideoLlavaForConditionalGeneration.from_pretrained(
    model_name, 
    cache_dir=os.path.join(cfg.model_cfg.cache_dir, "LanguageBind/"), 
    device_map="auto",
    attn_implementation=None,
)#.to(device)
# (omitted...)
# videos and text_inputs are list of videos, list of strings, respectively.
inputs = self.processor(videos=videos, text=text_inputs, return_tensors="pt", padding=True).to(self.device)
outputs = self.model.generate(
    **inputs,
    # do_sample=False,
    num_beams=5,
    max_new_tokens=10,
    min_length=1,
    length_penalty=-1,
    return_dict_in_generate=True,
    output_scores=True,
)
output_text = self.processor.batch_decode(
    outputs.sequences, skip_special_tokens=True
)
output_scores = torch.exp(outputs.sequences_scores).tolist()

but I got this error:

/home/ywjang/miniconda3/envs/LBA_uncertainty_v2/lib/python3.8/site-packages/transformers/feature_extraction_utils.py:142: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at /opt/conda/conda-bld/pytorch_1682343962757/work/torch/csrc/utils/tensor_new.cpp:245.)
  return torch.tensor(value)
We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)
/opt/conda/conda-bld/pytorch_1682343962757/work/aten/src/ATen/native/cuda/Indexing.cu:1093: indexSelectSmallIndex: block: [92,0,0], thread: [64,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/opt/conda/conda-bld/pytorch_1682343962757/work/aten/src/ATen/native/cuda/Indexing.cu:1093: indexSelectSmallIndex: block: [92,0,0], thread: [65,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
(omitted...)
/opt/conda/conda-bld/pytorch_1682343962757/work/aten/src/ATen/native/cuda/Indexing.cu:1093: indexSelectSmallIndex: block: [92,0,0], thread: [94,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/opt/conda/conda-bld/pytorch_1682343962757/work/aten/src/ATen/native/cuda/Indexing.cu:1093: indexSelectSmallIndex: block: [92,0,0], thread: [95,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
Traceback (most recent call last):
  File "main.py", line 166, in <module>
    main()
  File "main.py", line 102, in main
    sub_questions, _ = decomposer(images, text_inputs)
  File "/home/ywjang/miniconda3/envs/LBA_uncertainty_v2/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ywjang/LBA_LAVIS_uncertainty_v2/models/model.py", line 301, in forward
    outputs = self.model.generate(
  File "/home/ywjang/miniconda3/envs/LBA_uncertainty_v2/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/ywjang/miniconda3/envs/LBA_uncertainty_v2/lib/python3.8/site-packages/transformers/generation/utils.py", line 1953, in generate
    result = self._beam_search(
  File "/home/ywjang/miniconda3/envs/LBA_uncertainty_v2/lib/python3.8/site-packages/transformers/generation/utils.py", line 3011, in _beam_search
    model_kwargs["past_key_values"] = self._temporary_reorder_cache(
  File "/home/ywjang/miniconda3/envs/LBA_uncertainty_v2/lib/python3.8/site-packages/transformers/generation/utils.py", line 2756, in _temporary_reorder_cache
    past_key_values = self._reorder_cache(past_key_values, beam_idx)
  File "/home/ywjang/miniconda3/envs/LBA_uncertainty_v2/lib/python3.8/site-packages/transformers/models/video_llava/modeling_video_llava.py", line 689, in _reorder_cache
    return self.language_model._reorder_cache(*args, **kwargs)
  File "/home/ywjang/miniconda3/envs/LBA_uncertainty_v2/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 1300, in _reorder_cache
    tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
  File "/home/ywjang/miniconda3/envs/LBA_uncertainty_v2/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 1300, in <genexpr>
    tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
RuntimeError: CUDA error: device-side assert triggered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

What's the problem?

@zucchini-nlp
Author

@greeksharifa yes, beam search is currently broken on the latest versions and I have a fix for it that will be merged soon. But it should work on older versions; I just tried the script below and got no errors:

import av
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from transformers import VideoLlavaProcessor, VideoLlavaForConditionalGeneration

model_name = "LanguageBind/Video-LLaVA-7B-hf"
processor = VideoLlavaProcessor.from_pretrained(model_name)
model = VideoLlavaForConditionalGeneration.from_pretrained(
    model_name, 
    device_map="auto",
    attn_implementation=None,
)

def read_video_pyav(container, indices):
    frames = []
    container.seek(0)
    start_index = indices[0]
    end_index = indices[-1]
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            break
        if i >= start_index and i in indices:
            frames.append(frame)
    return np.stack([x.to_ndarray(format="rgb24") for x in frames])


video_path = hf_hub_download(repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset")
container = av.open(video_path)
total_frames = container.streams.video[0].frames
indices = np.arange(0, total_frames, total_frames / 8).astype(int)
videos = read_video_pyav(container, indices)

inputs = processor(videos=videos, text="USER: <video>\nWhat do you see here? ASSISTANT:", return_tensors="pt").to(model.device)

outputs = model.generate(
    **inputs,
    num_beams=5,
    max_new_tokens=40,
    min_length=1,
    length_penalty=-1,
    return_dict_in_generate=True,
    output_scores=True,
)

output_text = processor.batch_decode(outputs.sequences, skip_special_tokens=True)
print(output_text)
output_scores = torch.exp(outputs.sequences_scores).tolist()

I got:

  • transformers version: 4.42.0
  • Platform: Linux-5.4.0-166-generic-x86_64-with-glibc2.29
  • PyTorch version (GPU?): 2.3.0+cu121 (True)
  • GPU type: NVIDIA A100-SXM4-80GB

If the above script works for you but fails with your video/text inputs, can you share a fully reproducible code pls. You can upload your video to the hub or send a link to it.

@greeksharifa

greeksharifa commented Aug 1, 2024

Thanks for your quick answer. @zucchini-nlp

I tried the above code and got no errors but a very strange answer like this:

[2024-08-01 08:08:43,282] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████| 3/3 [00:14<00:00,  4.71s/it]
/opt/conda/lib/python3.10/site-packages/transformers/feature_extraction_utils.py:142: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at ../torch/csrc/utils/tensor_new.cpp:245.)
  return torch.tensor(value)
We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)
['USER: \nWhat do you see here? ASSISTANT:Mediaengoymbol abb abb abb abb abb abb abb abb abb abb abb abb abb abb abb abb abb abb abb abb abb abb abb abb abb abb abb abb abb abb abb abb abb abb abb abbcr']

I guess the tokenizer or some special tokens are broken, but I'm not certain.

I used transformers==4.42, torch==2.0.1 (following the readme), python==3.10.

In addition: the GPU is an A6000, and the same error also occurs with torch 2.3.0+cu121.

@zucchini-nlp
Author

@greeksharifa right, I found that Video-LLaVA had a bug related to the new cache format. I am opening a PR (huggingface/transformers#32417) to fix this, and you will be able to run beam search when it's merged. Make sure to update transformers with !pip install --upgrade git+https://github.com/huggingface/transformers.git

@IsabelJimenez99

Hi, I am trying Video-LLaVA to get the types of physical interactions it can see in a video, but it is not able to answer correctly.

Is there any way (without finetuning, using the prompt directly) to pass it 2 or 3 videos with their explanations and interaction types, and then a fourth video asking about its type of interaction, so that the model has visual context of what each type of interaction looks like?
Or does the model analyse each video independently of the rest, without having a 'memory' of its responses?
Thanks!!

@zucchini-nlp
Author

@IsabelJimenez99 hey! AFAIK Video-LLaVA was not trained in a few-shot setting, so we can't be sure it will pick up the pattern from a few examples and continue in the same format. You can try experimenting, and maybe ask the authors whether they can share any insights on few-shot inference.

In terms of implementation, transformers supports multi-turn chat formatted input so it should be no problem to run generation :)
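
If you want to experiment, a rough sketch of a few-shot style prompt with several clips could look like this (clip_1, clip_2, clip_3 are assumed to be 8-frame arrays from read_video_pyav; there is no guarantee the model follows the pattern, as noted above):

prompt = (
    "USER: <video>\nWhich types of physical contact do you see? ASSISTANT: hand-hand. "
    "USER: <video>\nWhich types of physical contact do you see? ASSISTANT: none. "
    "USER: <video>\nWhich types of physical contact do you see? ASSISTANT:"
)
inputs = processor(text=prompt, videos=[clip_1, clip_2, clip_3], return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=20)
print(processor.batch_decode(out, skip_special_tokens=True)[0])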

@IsabelJimenez99

Okey, thanks so much!!

@BalloutAI

BalloutAI commented Nov 8, 2024

@zucchini-nlp Hello, and thanks for your hard work!
What is the best way to extract features of videos using video-llava with the transformers library?

@sterzhang

Hello, do you know why this error occurs when running the beam-search script above?
RuntimeError: The size of tensor a (2074) must match the size of tensor b (19) at non-singleton dimension 3

@BalloutAI

@zucchini-nlp Hello, and thanks for your hard work! What is the best way to extract features of videos using video-llava with the transformers library?

Does this code make sense?

def extract_vision_features(self, video, text=None):
    """Extract features from the vision encoder"""

    text = "USER: <video>\n What do you see in this video? ASSISTANT:"
        
    # Process inputs
    inputs = self.processor(videos=video, text=text, return_tensors="pt")
    inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
    
    with torch.no_grad():
        # Get the video frames tensor
        video_frames = inputs['pixel_values_videos']  

        batch_size, num_frames, num_channels, height, width = video_frames.shape
        
        
        video_frames = video_frames.view(batch_size * num_frames, num_channels, height, width)
        
        # Process through video tower
        vision_features = self.model.video_tower(pixel_values=video_frames)
        
        # Reshape back to include frames dimension
        last_hidden_state = vision_features.last_hidden_state.view(
            batch_size, num_frames, -1, vision_features.last_hidden_state.size(-1)
        )
        print("last_hidden_state", last_hidden_state.shape)
        
        features = {
            'last_hidden_states': last_hidden_state
        }
        
        return features

def get_aggregated_features(self, video, text=None):
    """Get single feature vector for whole video"""
    vision_features = self.extract_vision_features(video, text)

    # Mean pool over frames and patches
    aggregated = vision_features['last_hidden_states'].mean(dim=[1, 2])

    return aggregated
