@@ -37,7 +37,7 @@ def check_caption(caption: str) -> Optional[str]:
3737 if max (sentences_dict .values ()) == 1 :
3838 return caption
3939 else :
40- return None
40+ return ""
4141
4242
4343class VideoLLaVAFilter (VideoFilter ):
@@ -52,7 +52,7 @@ def __init__(
5252 model_base : Optional [str ] = None ,
5353 cache_path : str = "cache_dir" ,
5454 prompt : str = "detailed_video" ,
55- temperature : float = 0.2 ,
55+ temperature : float = 0.8 ,
5656 max_new_tokens : int = 1024 ,
5757 load_4bit : bool = False ,
5858 load_8bit : bool = False ,
@@ -142,13 +142,14 @@ def process_batch(self, batch: list[Any]) -> dict[str, list[Any]]:
142142 do_sample = True if self .temperature > 0 else False ,
143143 temperature = self .temperature ,
144144 max_new_tokens = self .max_new_tokens ,
145+ num_beams = 1 ,
146+ no_repeat_ngram_size = 2 ,
145147 use_cache = True ,
146148 stopping_criteria = [self .stopping_criteria ])
147149
148150 all_outputs : list [Optional [str ]] = []
149151 for i in range (output_ids .shape [0 ]):
150152 caption = self .tokenizer .decode (output_ids [i , self .input_ids .shape [1 ]:]).strip ().split ('</s>' )[0 ]
151- all_outputs .append (caption )
152153 all_outputs .append (check_caption (caption ))
153154 df_batch_labels [self .schema [1 ]].extend (all_outputs )
154155 df_batch_labels [self .key_column ].extend (keys )
0 commit comments