
Commit 8cd9376

add videos dataset for example, refactor logic
1 parent 1c23c91 commit 8cd9376

File tree

1 file changed (+101, -72 lines)

examples/inference/distributed/llava_next_video.py

@@ -12,19 +12,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import json
+import os
+import pathlib
+import queue
+import time
+from concurrent.futures import ThreadPoolExecutor
+
+import av
 import fire
-from transformers import LlavaNextVideoProcessor, LlavaNextVideoForConditionalGeneration
 import numpy as np
 import torch
-import time
+from huggingface_hub import snapshot_download
+from tqdm import tqdm
+from transformers import LlavaNextVideoForConditionalGeneration, LlavaNextVideoProcessor
+
 from accelerate import PartialState
-import os
-import av
-from huggingface_hub import hf_hub_download
-import json
-import queue
-from concurrent.futures import ThreadPoolExecutor
-import pathlib
+
 
 START_TIME = time.strftime("%Y%m%d_%H%M%S")
 DTYPE_MAP = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
@@ -33,10 +37,78 @@
 """
 Example:
 
-    accelerate launch llava_next_video.py
+accelerate launch llava_next_video.py
 """
 
 
+def save_results(output_queue: queue.Queue, output_dir: pathlib.Path):
+    count = 0
+    while True:
+        try:
+            item = output_queue.get(timeout=5)
+            if item is None:
+                break
+            example_file = f"example_{count}"
+            temp_dir = os.path.join(output_dir, example_file)
+
+            metadata = {
+                "prompt": item[0],
+                "generated_answer": item[1],
+            }
+            with open(temp_dir, "w") as f:
+                json.dump(metadata, f, indent=4)
+            count += 1
+
+        except queue.Empty:
+            continue
+
+
+def get_batches(videos, prompts):
+    batch_size = len(prompts)
+    num_batches = (len(videos) + batch_size - 1) // batch_size
+    batches = []
+
+    for i in range(num_batches):
+        start_index = i * batch_size
+        end_index = min((i + 1) * batch_size, len(videos))
+        batch = videos[start_index:end_index]
+        batches.append(batch)
+
+    return batches
+
+
+def read_video_pyav(container, indices):
+    """
+    Decode the video with PyAV decoder.
+    Args:
+        container (`av.container.input.InputContainer`): PyAV container.
+        indices (`List[int]`): List of frame indices to decode.
+    Returns:
+        result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
+    """
+    frames = []
+    container.seek(0)
+    start_index = indices[0]
+    end_index = indices[-1]
+    for i, frame in enumerate(container.decode(video=0)):
+        if i > end_index:
+            break
+        if i >= start_index and i in indices:
+            frames.append(frame)
+    return np.stack([x.to_ndarray(format="rgb24") for x in frames])
+
+
+def process_videos(video_files):
+    processed_videos = []
+    for video in video_files:
+        container = av.open(video)
+        total_frames = container.streams.video[0].frames
+        indices = np.arange(0, total_frames, total_frames / 8).astype(int)
+        processed_video = read_video_pyav(container, indices)
+        processed_videos.append(processed_video)
+    return processed_videos
+
+
 def main(
     model_name: str = "llava-hf/LLaVA-NeXT-Video-7B-hf",
     save_dir: str = "./evaluation/examples",
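
Note on the refactor above: the frame-sampling logic that previously lived inline in main is now split into process_videos and read_video_pyav, and each clip is reduced to 8 uniformly spaced frames via np.arange. A minimal sketch of that index arithmetic (the 120-frame clip length is made up for illustration):

    import numpy as np

    total_frames = 120  # hypothetical clip length
    indices = np.arange(0, total_frames, total_frames / 8).astype(int)
    print(indices)  # [  0  15  30  45  60  75  90 105] -> 8 evenly spaced frame indices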
@@ -49,7 +121,7 @@ def main(
 
     processor = LlavaNextVideoProcessor.from_pretrained(model_name)
     model = LlavaNextVideoForConditionalGeneration.from_pretrained(
-        model_name, torch_dtype=dtype[DTYPE_MAP], low_cpu_mem_usage=low_mem, device_map=distributed_state.device
+        model_name, torch_dtype=DTYPE_MAP[dtype], low_cpu_mem_usage=low_mem, device_map=distributed_state.device
     )
 
     if distributed_state.is_main_process:
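
The one-line change in this hunk is a bug fix: the old dtype[DTYPE_MAP] tried to index the user-supplied string with the dict, which raises a TypeError at startup; the corrected DTYPE_MAP[dtype] looks the string up in the dict. Illustrated:

    import torch

    DTYPE_MAP = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
    dtype = "fp16"

    # dtype[DTYPE_MAP]       # old code: TypeError, a str cannot be indexed by a dict
    print(DTYPE_MAP[dtype])  # torch.float16 -- the fixed lookup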
@@ -59,14 +131,14 @@ def main(
     else:
         print(f"Directory '{save_dir}' already exists.")
 
-    # Load the video as an np.array, sampling uniformly 8 frames (can sample more for longer videos)
-    video_path = hf_hub_download(
-        repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset"
+    video_path = os.path.join(
+        snapshot_download(repo_id="Wild-Heart/Disney-VideoGeneration-Dataset", repo_type="dataset"), "videos"
     )
-    container = av.open(video_path)
-    total_frames = container.streams.video[0].frames
-    indices = np.arange(0, total_frames, total_frames / 8).astype(int)
-    video = read_video_pyav(container, indices)
+
+    video_files = [
+        os.path.join(video_path, f) for f in os.listdir(video_path) if os.path.isfile(os.path.join(video_path, f))
+    ]
+    processed_videos = process_videos(video_files)
 
     conversations = [
         [
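
This hunk swaps hf_hub_download (which fetched a single sample file) for snapshot_download, which downloads the whole Wild-Heart/Disney-VideoGeneration-Dataset dataset repo and returns the local cache path; the script then collects every file under its videos/ subfolder. A standalone sketch of that flow, mirroring the hunk above:

    import os
    from huggingface_hub import snapshot_download

    # snapshot_download returns the local path of the cached dataset snapshot
    root = snapshot_download(repo_id="Wild-Heart/Disney-VideoGeneration-Dataset", repo_type="dataset")
    video_dir = os.path.join(root, "videos")
    video_files = [os.path.join(video_dir, f) for f in os.listdir(video_dir) if os.path.isfile(os.path.join(video_dir, f))]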
@@ -111,66 +183,23 @@ def main(
         processor.apply_chat_template(conversations[i], add_generation_prompt=True)
         for i in range(0, len(conversations))
     ]
-
-    def save_results(output_queue: queue.Queue, output_dir: pathlib.Path):
-        count = 0
-        while True:
-            try:
-                item = output_queue.get(timeout=5)
-                if item is None:
-                    break
-                example_file = f"example_{count}"
-                temp_dir = os.path.join(output_dir, example_file)
-
-                metadata = {
-                    "prompt": item[0],
-                    "generated_answer": item[1],
-                }
-                with open(temp_dir, "w") as f:
-                    json.dump(metadata, f, indent=4)
-                count += 1
-
-            except queue.Empty:
-                continue
-
-    distributed_state.num_processes = len(formatted_prompts)
+    batches = get_batches(processed_videos, formatted_prompts)
     output_queue = queue.Queue()
     save_thread = ThreadPoolExecutor(max_workers=num_workers)
     save_future = save_thread.submit(save_results, output_queue, save_dir)
-
-    try:
-        with distributed_state.split_between_processes(formatted_prompts) as prompt:
-            input = processor(text=prompt, videos=video, padding=True, return_tensors="pt").to(model.device)
-            output = model.generate(**input, max_new_tokens=60)
-            generated_text = processor.decode(output[0][2:], skip_special_tokens=True)
-            output_queue.put((prompt, generated_text))
-    finally:
-        output_queue.put(None)
-        save_thread.shutdown(wait=True)
+    for _, batch in tqdm(enumerate(batches), total=len(batches)):
+        try:
+            with distributed_state.split_between_processes(formatted_prompts) as prompt:
+                input = processor(text=prompt, videos=batch, padding=True, return_tensors="pt").to(model.device)
+                output = model.generate(**input, max_new_tokens=60)
+                generated_text = processor.decode(output[0][2:], skip_special_tokens=True)
+                output_queue.put((prompt, generated_text))
+        finally:
+            output_queue.put(None)
+            save_thread.shutdown(wait=True)
 
     save_future.result()
 
 
-def read_video_pyav(container, indices):
-    """
-    Decode the video with PyAV decoder.
-    Args:
-        container (`av.container.input.InputContainer`): PyAV container.
-        indices (`List[int]`): List of frame indices to decode.
-    Returns:
-        result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
-    """
-    frames = []
-    container.seek(0)
-    start_index = indices[0]
-    end_index = indices[-1]
-    for i, frame in enumerate(container.decode(video=0)):
-        if i > end_index:
-            break
-        if i >= start_index and i in indices:
-            frames.append(frame)
-    return np.stack([x.to_ndarray(format="rgb24") for x in frames])
-
-
 if __name__ == "__main__":
     fire.Fire(main)
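
Batching note: get_batches sizes every batch to len(prompts), so each pass through the new generation loop pairs one video with each formatted prompt, with a smaller final batch when the video count is not an exact multiple. A runnable sketch (get_batches condensed from the diff above; the counts are made up):

    def get_batches(videos, prompts):  # condensed from the diff above
        batch_size = len(prompts)
        num_batches = (len(videos) + batch_size - 1) // batch_size
        return [videos[i * batch_size : (i + 1) * batch_size] for i in range(num_batches)]

    videos = list(range(20))     # stand-ins for 20 processed video arrays
    prompts = ["<prompt>"] * 8   # stand-ins for 8 formatted prompts
    print([len(b) for b in get_batches(videos, prompts)])  # [8, 8, 4] -> ceil(20 / 8) = 3 batches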
