Skip to content

Commit 6998a3e

Browse files
authored
Merge pull request #40 from ai-forever/kirillova/video_llava_fix_captions
fix: remove errors with captions
2 parents c79d77d + f8a0ad4 commit 6998a3e

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

DPF/filters/videos/video_llava_filter.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def check_caption(caption: str) -> Optional[str]:
3737
if max(sentences_dict.values()) == 1:
3838
return caption
3939
else:
40-
return None
40+
return ""
4141

4242

4343
class VideoLLaVAFilter(VideoFilter):
@@ -52,7 +52,7 @@ def __init__(
5252
model_base: Optional[str] = None,
5353
cache_path: str = "cache_dir",
5454
prompt: str = "detailed_video",
55-
temperature: float = 0.2,
55+
temperature: float = 0.8,
5656
max_new_tokens: int = 1024,
5757
load_4bit: bool = False,
5858
load_8bit: bool = False,
@@ -142,13 +142,14 @@ def process_batch(self, batch: list[Any]) -> dict[str, list[Any]]:
142142
do_sample=True if self.temperature > 0 else False,
143143
temperature=self.temperature,
144144
max_new_tokens=self.max_new_tokens,
145+
num_beams=1,
146+
no_repeat_ngram_size=2,
145147
use_cache=True,
146148
stopping_criteria=[self.stopping_criteria])
147149

148150
all_outputs: list[Optional[str]] = []
149151
for i in range(output_ids.shape[0]):
150152
caption = self.tokenizer.decode(output_ids[i, self.input_ids.shape[1]:]).strip().split('</s>')[0]
151-
all_outputs.append(caption)
152153
all_outputs.append(check_caption(caption))
153154
df_batch_labels[self.schema[1]].extend(all_outputs)
154155
df_batch_labels[self.key_column].extend(keys)

0 commit comments

Comments
 (0)