Skip to content

Commit

Permalink
small fix for batch size
Browse files Browse the repository at this point in the history
  • Loading branch information
VladOS95-cyber committed Oct 24, 2024
1 parent 8cd9376 commit 0a58faa
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions examples/inference/distributed/llava_next_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,7 @@ def save_results(output_queue: queue.Queue, output_dir: pathlib.Path):
continue


def get_batches(videos, prompts):
batch_size = len(prompts)
def get_batches(videos, batch_size):
num_batches = (len(videos) + batch_size - 1) // batch_size
batches = []

Expand Down Expand Up @@ -112,6 +111,7 @@ def process_videos(video_files):
def main(
model_name: str = "llava-hf/LLaVA-NeXT-Video-7B-hf",
save_dir: str = "./evaluation/examples",
batch_size: int = 4,
dtype: str = "fp16",
num_workers: int = 1,
low_mem: bool = True,
Expand Down Expand Up @@ -183,7 +183,7 @@ def main(
processor.apply_chat_template(conversations[i], add_generation_prompt=True)
for i in range(0, len(conversations))
]
batches = get_batches(processed_videos, formatted_prompts)
batches = get_batches(processed_videos, batch_size)
output_queue = queue.Queue()
save_thread = ThreadPoolExecutor(max_workers=num_workers)
save_future = save_thread.submit(save_results, output_queue, save_dir)
Expand Down

0 comments on commit 0a58faa

Please sign in to comment.