Commit 488a46f

Merge pull request #45 from ai-forever/kirillova/lita_video_captioner
feat: add LITA video captioning filter
2 parents 4d06df9 + 68cd4f1

File tree

6 files changed: +429 −44 lines changed


DPF/filters/videos/lita_filter.py

Lines changed: 150 additions & 0 deletions
@@ -0,0 +1,150 @@
import os
from io import BytesIO
from typing import Any, Optional

import gdown
import torch
from lita.constants import (
    DEFAULT_IM_END_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IMAGE_TOKEN,
    IMAGE_TOKEN_INDEX,
)
from lita.model.builder import load_pretrained_model
from lita.utils import load_video
from llava.conversation import SeparatorStyle, conv_templates
from llava.mm_utils import (
    KeywordsStoppingCriteria,
    get_model_name_from_path,
    tokenizer_image_token,
)

from DPF.types import ModalityToDataMapping

from .video_filter import VideoFilter

try:
    from torch.utils.data.dataloader import default_collate
except ImportError:
    from torch.utils.data import default_collate


def disable_torch_init() -> None:
    """
    Disable the redundant torch default initialization to accelerate model creation.
    """
    torch.nn.Linear.reset_parameters = lambda self: None  # type: ignore
    torch.nn.LayerNorm.reset_parameters = lambda self: None  # type: ignore


class LITAFilter(VideoFilter):
    """
    LITA inference class to get captions for auto-labeling videos.
    More info about the model here: https://github.com/NVlabs/LITA
    """
    def __init__(
        self,
        weights_path: str = "./lita-vicuna-v1-3-13b-finetune",
        model_base: Optional[str] = None,
        prompt: str = "detailed_video",
        temperature: float = 0.2,
        max_new_tokens: int = 1024,
        load_4bit: bool = False,
        load_8bit: bool = False,
        device: str = "cuda:0",
        workers: int = 16,
        batch_size: int = 8,
        pbar: bool = True,
        _pbar_position: int = 0
    ):
        super().__init__(pbar, _pbar_position)
        self.model_name = get_model_name_from_path(weights_path)
        self.prompt_to_use = prompt
        prompt_templates = {
            'detailed_video': 'Describe this video and its style in a very detailed manner',
            'short_video': 'Describe this video and its style briefly'
        }

        self.num_workers = workers
        self.batch_size = batch_size
        self.device = device

        self.inp = prompt_templates[self.prompt_to_use]
        self.temperature = temperature
        self.max_new_tokens = max_new_tokens

        # Download the model weights if they are not present locally
        weights_url = "https://drive.google.com/drive/folders/1-P7p-tq5aXZzSoefEJx4PSFKH8jt8KWy"
        if not os.path.exists(weights_path):
            gdown.download_folder(weights_url)

        disable_torch_init()

        pretrainers = load_pretrained_model(weights_path, model_base, self.model_name, load_8bit, load_4bit)
        self.tokenizer, self.model, self.processor, self.context_len = pretrainers

        self.conv_mode = "llava_v1"
        self.conv = conv_templates[self.conv_mode].copy()

        # Build the prompt once: image tokens followed by the captioning instruction
        inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + self.inp
        self.conv.append_message(self.conv.roles[0], inp)
        self.conv.append_message(self.conv.roles[1], None)
        prompt = self.conv.get_prompt()
        self.input_ids = tokenizer_image_token(
            prompt,
            self.tokenizer,
            IMAGE_TOKEN_INDEX,
            return_tensors='pt'
        ).unsqueeze(0).to(self.device)
        stop_str = self.conv.sep if self.conv.sep_style != SeparatorStyle.TWO else self.conv.sep2
        keywords = [stop_str]
        self.stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, self.input_ids)

    @property
    def result_columns(self) -> list[str]:
        return [f"caption {self.model_name} prompt {self.prompt_to_use}"]

    @property
    def dataloader_kwargs(self) -> dict[str, Any]:
        return {
            "num_workers": self.num_workers,
            "batch_size": self.batch_size,
            "drop_last": False,
        }

    def preprocess_data(
        self,
        modality2data: ModalityToDataMapping,
        metadata: dict[str, Any]
    ) -> Any:
        key = metadata[self.key_column]
        video_file = BytesIO(modality2data['video'])
        video_file = load_video(video_file, self.processor, self.model.config.num_frames).unsqueeze(0).half()
        return key, video_file

    def process_batch(self, batch: list[Any]) -> dict[str, list[Any]]:
        df_batch_labels = self._get_dict_from_schema()

        keys, video_tensors = list(zip(*batch))

        video_tensors = default_collate(video_tensors).to(self.device)  # type: ignore
        # Repeat the shared prompt so every video in the batch gets the same instruction
        input_ids_batch = self.input_ids.repeat_interleave(video_tensors.shape[0], 0).to(self.device)  # type: ignore

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids_batch,
                images=video_tensors[:, 0],  # type: ignore
                do_sample=True,
                temperature=self.temperature,
                top_p=0.85,
                num_beams=1,
                max_new_tokens=self.max_new_tokens,
                use_cache=True
            )

        all_outputs: list[Optional[str]] = []
        for i in range(output_ids.shape[0]):
            caption = self.tokenizer.decode(output_ids[i, self.input_ids.shape[1]:]).strip().split('</s>')[0]
            all_outputs.append(caption)
        df_batch_labels[self.schema[1]].extend(all_outputs)
        df_batch_labels[self.key_column].extend(keys)
        return df_batch_labels
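
For context, a minimal usage sketch for the new filter follows. The DatasetReader, ShardsDatasetConfig, and apply_data_filter names follow DPF's usual filter-application pattern but are assumptions here, not part of this commit:

# Hypothetical usage sketch; the dataset config and DPF entry points below
# are assumptions, not part of this commit.
from DPF import DatasetReader, ShardsDatasetConfig  # assumed DPF entry points

from DPF.filters.videos.lita_filter import LITAFilter

config = ShardsDatasetConfig.from_path_and_columns(
    "examples/example_video_dataset",  # assumed dataset location
    video_name_col="video_name",
)
processor = DatasetReader().read_from_config(config)

datafilter = LITAFilter(prompt="short_video", batch_size=2, device="cuda:0")
processor.apply_data_filter(datafilter)

# Captions are written to the column named by result_columns
print(processor.df[datafilter.result_columns[0]].head())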
Lines changed: 125 additions & 0 deletions
@@ -0,0 +1,125 @@
def get_weights_path_from_url(url, md5sum=None):
    """Get the weights path from WEIGHTS_HOME; if it does not exist,
    download it from url.

    Args:
        url (str): download url
        md5sum (str): md5 sum of the downloaded package

    Returns:
        str: a local path to the downloaded weights.

    Examples:
        .. code-block:: python

            from hapi.download import get_weights_path_from_url

            resnet18_pretrained_weight_url = 'https://paddle-hapi.bj.bcebos.com/models/resnet18.pdparams'
            local_weight_path = get_weights_path_from_url(resnet18_pretrained_weight_url)

    """
    path = get_path_from_url(url, WEIGHTS_HOME, md5sum)
    return path


def _map_path(url, root_dir):
    # parse the path the download lands at under root_dir
    fname = osp.split(url)[-1]
    fpath = fname
    return osp.join(root_dir, fpath)


def get_path_from_url(url, root_dir, md5sum=None, check_exist=True):
    """Download from the given url to root_dir.
    If the file or directory specified by url exists under
    root_dir, return the path directly; otherwise download
    from url, decompress it, and return the path.

    Args:
        url (str): download url
        root_dir (str): root dir for downloading; it should be
            WEIGHTS_HOME or DATASET_HOME
        md5sum (str): md5 sum of the downloaded package

    Returns:
        str: a local path to the downloaded models, weights, and datasets.
    """
    assert is_url(url), "downloading from {} is not a url".format(url)
    # parse the path to decompress to under root_dir
    fullpath = _map_path(url, root_dir)

    if osp.exists(fullpath) and check_exist and _md5check(fullpath, md5sum):
        logger.info("Found {}".format(fullpath))
    else:
        if ParallelEnv().local_rank == 0:
            fullpath = _download(url, root_dir, md5sum)
        else:
            # other ranks wait until rank 0 finishes downloading
            while not os.path.exists(fullpath):
                time.sleep(1)
    return fullpath


def _download(url, path, md5sum=None):
    """
    Download from url, save to path.

    url (str): download url
    path (str): download to given path
    """
    if not osp.exists(path):
        os.makedirs(path)

    fname = osp.split(url)[-1]
    fullname = osp.join(path, fname)
    retry_cnt = 0

    while not (osp.exists(fullname) and _md5check(fullname, md5sum)):
        if retry_cnt < DOWNLOAD_RETRY_LIMIT:
            retry_cnt += 1
        else:
            raise RuntimeError("Download from {} failed. "
                               "Retry limit reached".format(url))

        logger.info("Downloading {} from {}".format(fname, url))

        req = requests.get(url, stream=True)
        if req.status_code != 200:
            raise RuntimeError("Downloading from {} failed with code "
                               "{}!".format(url, req.status_code))

        # To protect against interrupted downloads, download to
        # tmp_fullname first and move tmp_fullname to fullname
        # once the download has finished
        tmp_fullname = fullname + "_tmp"
        total_size = req.headers.get('content-length')
        with open(tmp_fullname, 'wb') as f:
            if total_size:
                with tqdm(total=(int(total_size) + 1023) // 1024) as pbar:
                    for chunk in req.iter_content(chunk_size=1024):
                        f.write(chunk)
                        pbar.update(1)
            else:
                for chunk in req.iter_content(chunk_size=1024):
                    if chunk:
                        f.write(chunk)
        shutil.move(tmp_fullname, fullname)

    return fullname


def _md5check(fullname, md5sum=None):
    if md5sum is None:
        return True

    logger.info("File {} md5 checking...".format(fullname))
    md5 = hashlib.md5()
    with open(fullname, 'rb') as f:
        for chunk in iter(lambda: f.read(4096), b""):
            md5.update(chunk)
    calc_md5sum = md5.hexdigest()

    if calc_md5sum != md5sum:
        logger.info("File {} md5 check failed, {}(calc) != "
                    "{}(base)".format(fullname, calc_md5sum, md5sum))
        return False
    return True
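
For reference, a minimal sketch of how these helpers compose, assuming WEIGHTS_HOME and the module-level names used above (logger, is_url, DOWNLOAD_RETRY_LIMIT, and so on) are defined:

# Minimal sketch; the URL is the one from the docstring example above.
weights_url = "https://paddle-hapi.bj.bcebos.com/models/resnet18.pdparams"

# md5sum=None skips checksum verification (_md5check returns True);
# supplying a digest forces a re-download when the cached file does not match.
local_path = get_weights_path_from_url(weights_url, md5sum=None)
print(local_path)  # e.g. <WEIGHTS_HOME>/resnet18.pdparams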

docs/filters.md

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@ List of implemented filters:
- [GunnarFarnebackFilter](../DPF/filters/videos/farneback_filter.py) - computes flow scores using Farneback's algorithm
- [RAFTOpticalFlowFilter](../DPF/filters/videos/raft_filter.py) - computes flow scores using [RAFT](https://github.com/princeton-vl/RAFT) model
- [VideoLLaVAFilter](../DPF/filters/videos/video_llava_filter.py) - captioning videos using Video-LLaVA
+ - [LITAFilter](../DPF/filters/videos/lita_filter.py) - captioning videos using [LITA model](https://github.com/NVlabs/LITA)

### Datafilter
