import os
from io import BytesIO
from typing import Any, Optional

import gdown
import torch
from lita.constants import (
    DEFAULT_IM_END_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IMAGE_TOKEN,
    IMAGE_TOKEN_INDEX,
)
from lita.model.builder import load_pretrained_model
from lita.utils import load_video
from llava.conversation import SeparatorStyle, conv_templates
from llava.mm_utils import (
    KeywordsStoppingCriteria,
    get_model_name_from_path,
    tokenizer_image_token,
)

from DPF.types import ModalityToDataMapping

from .video_filter import VideoFilter

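# default_collate lives in a different module depending on the torch version, so try both import paths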
try:
    from torch.utils.data.dataloader import default_collate
except ImportError:
    from torch.utils.data import default_collate


def disable_torch_init() -> None:
    """
    Disable the redundant torch default initialization to accelerate model creation.
    """
    torch.nn.Linear.reset_parameters = lambda self: None  # type: ignore
    torch.nn.LayerNorm.reset_parameters = lambda self: None  # type: ignore


class LITAFilter(VideoFilter):
    """
    LITA inference class for captioning videos during auto-labeling.
    More info about the model: https://github.com/NVlabs/LITA
    """
    def __init__(
        self,
        weights_path: str = "./lita-vicuna-v1-3-13b-finetune",
        model_base: Optional[str] = None,
        prompt: str = "detailed_video",
        temperature: float = 0.2,
        max_new_tokens: int = 1024,
        load_4bit: bool = False,
        load_8bit: bool = False,
        device: str = "cuda:0",
        workers: int = 16,
        batch_size: int = 8,
        pbar: bool = True,
        _pbar_position: int = 0
    ):
        super().__init__(pbar, _pbar_position)
        self.model_name = get_model_name_from_path(weights_path)
        self.prompt_to_use = prompt
        prompt_templates = {
            'detailed_video': 'Describe this video and its style in a very detailed manner',
            'short_video': 'Describe this video and its style briefly'
        }

        self.num_workers = workers
        self.batch_size = batch_size
        self.device = device

        self.inp = prompt_templates[self.prompt_to_use]
        self.temperature = temperature
        self.max_new_tokens = max_new_tokens

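        # Download the pretrained LITA weights from Google Drive if they are not available locally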
        weights_url = "https://drive.google.com/drive/folders/1-P7p-tq5aXZzSoefEJx4PSFKH8jt8KWy"
        if not os.path.exists(weights_path):
            gdown.download_folder(weights_url, output=weights_path)

        disable_torch_init()

        pretrainers = load_pretrained_model(weights_path, model_base, self.model_name, load_8bit, load_4bit)
        self.tokenizer, self.model, self.processor, self.context_len = pretrainers

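        # Build the conversation prompt once; the same tokenized prompt is reused for every batch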
        self.conv_mode = "llava_v1"
        self.conv = conv_templates[self.conv_mode].copy()

        inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + self.inp
        self.conv.append_message(self.conv.roles[0], inp)
        self.conv.append_message(self.conv.roles[1], None)
        prompt = self.conv.get_prompt()
        self.input_ids = tokenizer_image_token(
            prompt,
            self.tokenizer,
            IMAGE_TOKEN_INDEX,
            return_tensors='pt'
        ).unsqueeze(0).to(self.device)
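        # Stop generation once the conversation separator token is produced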
        stop_str = self.conv.sep if self.conv.sep_style != SeparatorStyle.TWO else self.conv.sep2
        keywords = [stop_str]
        self.stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, self.input_ids)

    @property
    def result_columns(self) -> list[str]:
        return [f"caption {self.model_name} prompt {self.prompt_to_use}"]

    @property
    def dataloader_kwargs(self) -> dict[str, Any]:
        return {
            "num_workers": self.num_workers,
            "batch_size": self.batch_size,
            "drop_last": False,
        }

    def preprocess_data(
        self,
        modality2data: ModalityToDataMapping,
        metadata: dict[str, Any]
    ) -> Any:
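        # Decode the raw video bytes and preprocess the sampled frames into a half-precision tensor;
        # unsqueeze(0) adds a per-sample dimension that process_batch strips again with [:, 0]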
        key = metadata[self.key_column]
        video_file = BytesIO(modality2data['video'])
        video_file = load_video(video_file, self.processor, self.model.config.num_frames).unsqueeze(0).half()
        return key, video_file

    def process_batch(self, batch: list[Any]) -> dict[str, list[Any]]:
        df_batch_labels = self._get_dict_from_schema()

        keys, video_tensors = list(zip(*batch))

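        # Stack the per-sample tensors into one batch and replicate the prompt for every sample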
        video_tensors = default_collate(video_tensors).to(self.device)  # type: ignore
        input_ids_batch = self.input_ids.repeat_interleave(video_tensors.shape[0], 0).to(self.device)  # type: ignore

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids_batch,
                images=video_tensors[:, 0],  # type: ignore
                do_sample=True,
                temperature=self.temperature,
                top_p=0.85,
                num_beams=1,
                max_new_tokens=self.max_new_tokens,
                use_cache=True
            )

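        # Decode only the newly generated tokens and cut each caption at the </s> end-of-sequence marker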
        all_outputs: list[Optional[str]] = []
        for i in range(output_ids.shape[0]):
            caption = self.tokenizer.decode(output_ids[i, self.input_ids.shape[1]:]).strip().split('</s>')[0]
            all_outputs.append(caption)
        df_batch_labels[self.schema[1]].extend(all_outputs)
        df_batch_labels[self.key_column].extend(keys)
        return df_batch_labels
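
# Example usage (a minimal sketch, not part of the filter itself): it assumes the checkpoint
# folder above and a DPF dataset processor that applies VideoFilter subclasses; the `processor`
# variable and its `apply_data_filter` call are illustrative and may differ in your setup.
#
#     datafilter = LITAFilter(weights_path="./lita-vicuna-v1-3-13b-finetune", device="cuda:0")
#     processor.apply_data_filter(datafilter)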