# See the License for the specific language governing permissions and
# limitations under the License.

+ import json
+ import os
+ import pathlib
+ import queue
+ import time
+ from concurrent.futures import ThreadPoolExecutor
+
+ import av
import fire
- from transformers import LlavaNextVideoProcessor, LlavaNextVideoForConditionalGeneration
import numpy as np
import torch
- import time
+ from huggingface_hub import snapshot_download
+ from tqdm import tqdm
+ from transformers import LlavaNextVideoForConditionalGeneration, LlavaNextVideoProcessor
+
from accelerate import PartialState
- import os
- import av
- from huggingface_hub import hf_hub_download
- import json
- import queue
- from concurrent.futures import ThreadPoolExecutor
- import pathlib
+

START_TIME = time.strftime("%Y%m%d_%H%M%S")
DTYPE_MAP = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
"""
Example:

- accelerate launch llava_next_video.py
+ accelerate launch llava_next_video.py
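+
+ To run on multiple GPUs, pass an explicit process count (one per GPU), e.g.:
+
+ accelerate launch --num_processes 2 llava_next_video.py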
"""


+ def save_results(output_queue: queue.Queue, output_dir: pathlib.Path):
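+     # Consumer: pull (prompt, generated_answer) pairs off the queue and write
+     # each one as a JSON file; a None item signals that generation is finished.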
+     count = 0
+     while True:
+         try:
+             item = output_queue.get(timeout=5)
+             if item is None:
+                 break
+             example_file = f"example_{count}.json"
+             file_path = os.path.join(output_dir, example_file)
+
+             metadata = {
+                 "prompt": item[0],
+                 "generated_answer": item[1],
+             }
+             with open(file_path, "w") as f:
+                 json.dump(metadata, f, indent=4)
+             count += 1
+
+         except queue.Empty:
+             continue
+
+
+ def get_batches(videos, prompts):
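+     # Chunk the processed videos into batches of len(prompts) so each
+     # generation step pairs one video batch with the full prompt list.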
+     batch_size = len(prompts)
+     num_batches = (len(videos) + batch_size - 1) // batch_size
+     batches = []
+
+     for i in range(num_batches):
+         start_index = i * batch_size
+         end_index = min((i + 1) * batch_size, len(videos))
+         batch = videos[start_index:end_index]
+         batches.append(batch)
+
+     return batches
+
+
+ def read_video_pyav(container, indices):
+     """
+     Decode the video with the PyAV decoder.
+     Args:
+         container (`av.container.input.InputContainer`): PyAV container.
+         indices (`List[int]`): List of frame indices to decode.
+     Returns:
+         result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
+     """
+     frames = []
+     container.seek(0)
+     start_index = indices[0]
+     end_index = indices[-1]
+     for i, frame in enumerate(container.decode(video=0)):
+         if i > end_index:
+             break
+         if i >= start_index and i in indices:
+             frames.append(frame)
+     return np.stack([x.to_ndarray(format="rgb24") for x in frames])
+
+
+ def process_videos(video_files):
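+     # Load each video and sample 8 uniformly spaced frames as an np.ndarray
+     # (sample more frames for longer videos).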
+     processed_videos = []
+     for video in video_files:
+         container = av.open(video)
+         total_frames = container.streams.video[0].frames
+         indices = np.arange(0, total_frames, total_frames / 8).astype(int)
+         processed_video = read_video_pyav(container, indices)
+         processed_videos.append(processed_video)
+     return processed_videos
+
+
def main(
    model_name: str = "llava-hf/LLaVA-NeXT-Video-7B-hf",
    save_dir: str = "./evaluation/examples",
@@ -49,7 +121,7 @@ def main(
    processor = LlavaNextVideoProcessor.from_pretrained(model_name)
    model = LlavaNextVideoForConditionalGeneration.from_pretrained(
-         model_name, torch_dtype=dtype[DTYPE_MAP], low_cpu_mem_usage=low_mem, device_map=distributed_state.device
+         model_name, torch_dtype=DTYPE_MAP[dtype], low_cpu_mem_usage=low_mem, device_map=distributed_state.device
    )

    if distributed_state.is_main_process:
@@ -59,14 +131,14 @@ def main(
    else:
        print(f"Directory '{save_dir}' already exists.")

-     # Load the video as an np.array, sampling uniformly 8 frames (can sample more for longer videos)
-     video_path = hf_hub_download(
-         repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset"
+     video_path = os.path.join(
+         snapshot_download(repo_id="Wild-Heart/Disney-VideoGeneration-Dataset", repo_type="dataset"), "videos"
    )
-     container = av.open(video_path)
-     total_frames = container.streams.video[0].frames
-     indices = np.arange(0, total_frames, total_frames / 8).astype(int)
-     video = read_video_pyav(container, indices)
+
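+     # Collect every file in the downloaded dataset's "videos" directory.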
+     video_files = [
+         os.path.join(video_path, f) for f in os.listdir(video_path) if os.path.isfile(os.path.join(video_path, f))
+     ]
+     processed_videos = process_videos(video_files)

    conversations = [
        [
@@ -111,66 +183,23 @@ def main(
        processor.apply_chat_template(conversations[i], add_generation_prompt=True)
        for i in range(0, len(conversations))
    ]
-
-     def save_results(output_queue: queue.Queue, output_dir: pathlib.Path):
-         count = 0
-         while True:
-             try:
-                 item = output_queue.get(timeout=5)
-                 if item is None:
-                     break
-                 example_file = f"example_{count}"
-                 temp_dir = os.path.join(output_dir, example_file)
-
-                 metadata = {
-                     "prompt": item[0],
-                     "generated_answer": item[1],
-                 }
-                 with open(temp_dir, "w") as f:
-                     json.dump(metadata, f, indent=4)
-                 count += 1
-
-             except queue.Empty:
-                 continue
-
-     distributed_state.num_processes = len(formatted_prompts)
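+     # Group the preprocessed videos into prompt-sized batches.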
+     batches = get_batches(processed_videos, formatted_prompts)
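+     # A background thread drains the queue so JSON writes don't block generation.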
    output_queue = queue.Queue()
    save_thread = ThreadPoolExecutor(max_workers=num_workers)
    save_future = save_thread.submit(save_results, output_queue, save_dir)
-
-     try:
-         with distributed_state.split_between_processes(formatted_prompts) as prompt:
-             input = processor(text=prompt, videos=video, padding=True, return_tensors="pt").to(model.device)
-             output = model.generate(**input, max_new_tokens=60)
-             generated_text = processor.decode(output[0][2:], skip_special_tokens=True)
-             output_queue.put((prompt, generated_text))
-     finally:
-         output_queue.put(None)
-         save_thread.shutdown(wait=True)
+     try:
+         for batch in tqdm(batches):
+             # Give each process its share of the prompts for this batch of videos.
+             with distributed_state.split_between_processes(formatted_prompts) as prompt:
+                 inputs = processor(text=prompt, videos=batch, padding=True, return_tensors="pt").to(model.device)
+                 output = model.generate(**inputs, max_new_tokens=60)
+                 generated_text = processor.decode(output[0][2:], skip_special_tokens=True)
+                 output_queue.put((prompt, generated_text))
+     finally:
+         # Enqueue the sentinel and shut the saver down only after all batches are done.
+         output_queue.put(None)
+         save_thread.shutdown(wait=True)

    save_future.result()


- def read_video_pyav(container, indices):
-     """
-     Decode the video with PyAV decoder.
-     Args:
-         container (`av.container.input.InputContainer`): PyAV container.
-         indices (`List[int]`): List of frame indices to decode.
-     Returns:
-         result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
-     """
-     frames = []
-     container.seek(0)
-     start_index = indices[0]
-     end_index = indices[-1]
-     for i, frame in enumerate(container.decode(video=0)):
-         if i > end_index:
-             break
-         if i >= start_index and i in indices:
-             frames.append(frame)
-     return np.stack([x.to_ndarray(format="rgb24") for x in frames])
-
-

if __name__ == "__main__":
    fire.Fire(main)