
Commit a5bd1e3

Merge pull request #59 from ai-forever/dev_alisa
merge pull for pllava filter
2 parents 2b56db2 + cf9a179 commit a5bd1e3

22 files changed: +3532 −21 lines

DPF/filters/utils/fp16_module.py

Lines changed: 2 additions & 2 deletions

@@ -57,12 +57,12 @@ def forward(self, *inputs, **kwargs):
     def state_dict(self, destination=None, prefix="", keep_vars=False):
         return self.module.state_dict(destination, prefix, keep_vars)
 
-    def load_state_dict(self, state_dict, strict=True):
+    def load_state_dict(self, state_dict, strict=True): # type: ignore
         self.module.load_state_dict(state_dict, strict=strict)
 
     def get_param(self, item):
         return self.module.get_param(item)
 
-    def to(self, device, *args, **kwargs):
+    def to(self, device, *args, **kwargs): # type: ignore
         self.module.to(device)
         return super().to(device, *args, **kwargs)
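
Both `# type: ignore` comments are needed because these pass-through overrides use narrower signatures than the `torch.nn.Module` methods they shadow (for example, `nn.Module.load_state_dict` accepts a mapping, takes extra keyword arguments, and returns `_IncompatibleKeys`), which mypy flags as an incompatible override. A minimal sketch of the pattern; the `FP16Module` wrapper name and its `__init__` are assumed context, not shown in this hunk:

import torch.nn as nn


class FP16Module(nn.Module):  # assumed wrapper class from fp16_module.py
    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    # The base method returns _IncompatibleKeys and accepts extra keyword
    # arguments, so mypy rejects this simplified override without the ignore.
    def load_state_dict(self, state_dict, strict=True):  # type: ignore
        self.module.load_state_dict(state_dict, strict=strict)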

DPF/filters/videos/image_filter_adapter.py

Lines changed: 1 addition & 1 deletion

@@ -68,7 +68,7 @@ def preprocess_data(
         frame = iio.imread(io.BytesIO(video_bytes), index=frame_index, plugin="pyav")
 
         buff = io.BytesIO()
-        Image.fromarray(frame).convert('RGB').save(buff, format='JPEG', quality=95) # type: ignore
+        Image.fromarray(frame).convert('RGB').save(buff, format='JPEG', quality=95)
         modality2data['image'] = buff.getvalue()
         metadata[self.image_filter.key_column] = ''
         return key, self.image_filter.preprocess_data(modality2data, metadata)
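
For reference, a self-contained sketch of the roundtrip this hunk performs: decode one frame from in-memory video bytes via imageio's pyav plugin, then re-encode it as JPEG bytes for the wrapped image filter. It assumes `iio` is `imageio.v3`, as the call style suggests; `frame_to_jpeg_bytes` is a name introduced here for illustration:

import io

import imageio.v3 as iio
from PIL import Image


def frame_to_jpeg_bytes(video_bytes: bytes, frame_index: int) -> bytes:
    # Decode a single frame from the in-memory video with the pyav plugin.
    frame = iio.imread(io.BytesIO(video_bytes), index=frame_index, plugin="pyav")
    # Re-encode the frame as JPEG, matching what the adapter hands to the image filter.
    buff = io.BytesIO()
    Image.fromarray(frame).convert('RGB').save(buff, format='JPEG', quality=95)
    return buff.getvalue()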
Lines changed: 193 additions & 0 deletions

@@ -0,0 +1,193 @@
import os
from io import BytesIO
from typing import Any, Optional

import numpy as np
import torch
import torchvision
from decord import VideoReader, cpu
from huggingface_hub import snapshot_download
from PIL import Image

from DPF.filters.videos.video_filter import VideoFilter
from DPF.types import ModalityToDataMapping

from .pllava_filter_core.tasks.eval.eval_utils import conv_templates
from .pllava_filter_core.tasks.eval.model_utils import load_pllava


def get_index(num_frames: int, num_segments: int) -> np.ndarray[Any, Any]:
    # Evenly spaced frame indices, centred within each segment,
    # e.g. get_index(100, 8) -> [6, 18, 31, 43, 56, 68, 80, 93].
    seg_size = float(num_frames - 1) / num_segments
    start = int(seg_size / 2)
    return np.array([
        start + int(np.round(seg_size * idx)) for idx in range(num_segments)
    ])


def load_video(video_bytes: BytesIO, num_segments: int = 8, return_msg: bool = False, num_frames: int = 16, resolution: int = 336) -> Any:
    transforms = torchvision.transforms.Resize(size=resolution)
    vr = VideoReader(video_bytes, ctx=cpu(0), num_threads=1)
    num_frames = len(vr)  # note: the num_frames argument is always overridden here
    frame_indices = get_index(num_frames, num_segments)
    images_group = []
    for frame_index in frame_indices:
        img = Image.fromarray(vr[frame_index].asnumpy())
        images_group.append(transforms(img))
    if return_msg:
        fps = float(vr.get_avg_fps())
        sec = ", ".join([str(round(f / fps, 1)) for f in frame_indices])
        msg = f"The video contains {len(frame_indices)} frames sampled at {sec} seconds."
        return images_group, msg
    else:
        return images_group


class PllavaFilter(VideoFilter):
    """
    PLLaVA inference class to get captions for videos.
    More info about the model here: https://pllava.github.io
    """
    def __init__(
        self,
        model_path: str,
        weights_path: str,
        weights_dir: str,
        prompt: str = "short",
        prompts: Optional[dict[str, str]] = None,
        do_sample: bool = True,
        batch_size: int = 16,
        conv_mode: str = 'eval_vcg_llavanext',
        device: str = "cuda:0",
        workers: int = 16,
        num_frames: int = 32,
        max_new_tokens: int = 100,
        num_segments: int = 32,
        resolution: int = 672,
        temperature: float = 0.1,
        use_lora: bool = True,
        lora_alpha: int = 4,
        pbar: bool = True,
        _pbar_position: int = 0,
        use_multi_gpus: bool = False,
        use_cache: bool = True,
    ):
        super().__init__(pbar, _pbar_position)
        self.weights_dir = weights_dir
        self.max_new_tokens = max_new_tokens
        self.conv_mode = conv_mode
        self.use_lora = use_lora
        self.do_sample = do_sample
        self.lora_alpha = lora_alpha
        self.weights_path = weights_path
        self.batch_size = batch_size
        self.num_workers = workers
        self.device = device
        self.prompt_to_use = prompt
        self.temperature = temperature
        self.resolution = resolution
        self.num_segments = num_segments
        self.num_frames = num_frames
        self.use_cache = use_cache
        self.use_multi_gpus = use_multi_gpus

        self.model_name = model_path.split('/')[-1]

        if prompts is None:
            self.prompts = {
                'detailed_video': 'Please provide a caption for this image. Speak confidently and describe everything clearly. Do not lie and describe only what you can see',
                'pixart': 'Describe this image and its style in a very detailed manner',
                'short': 'Describe this image very shortly in 1-2 short sentences',
                'short-video': 'Describe this video very shortly in 1-2 short sentences. Describe what is happening in this video.'
            }
        else:
            self.prompts = prompts

        self.input_ids = self.prompts[self.prompt_to_use]

        self.conv = conv_templates[self.conv_mode].copy()  # type: ignore
        self.conv.user_query(self.input_ids, is_mm=True)
        self.prompt = self.conv.get_prompt()

        if not os.path.exists(weights_path):
            read_token = '...'
            local_dir = model_path.replace('ermu2001', 'weights')
            snapshot_download(
                model_path,
                local_dir=local_dir,
                repo_type='model',
                local_dir_use_symlinks=True,
                token=read_token,
            )

        self.model, self.processor = load_pllava(
            self.weights_path,
            self.num_frames,
            use_lora=self.use_lora,
            weight_dir=self.weights_dir,
            lora_alpha=self.lora_alpha,
            use_multi_gpus=self.use_multi_gpus
        )  # type: ignore

        if not self.use_multi_gpus:
            self.model = self.model.to(self.device)

    @property
    def result_columns(self) -> list[str]:
        return [f"caption {self.model_name} prompt {self.prompt_to_use}"]

    @property
    def dataloader_kwargs(self) -> dict[str, Any]:
        return {
            "num_workers": self.num_workers,
            "batch_size": self.batch_size,
            "drop_last": False,
        }

    def preprocess_data(
        self,
        modality2data: ModalityToDataMapping,
        metadata: dict[str, Any]
    ) -> Any:
        key = metadata[self.key_column]
        video_file = BytesIO(modality2data['video'])
        video_file, _ = load_video(video_file, num_segments=self.num_segments, return_msg=True, resolution=self.resolution)
        return key, video_file

    def process_batch(self, batch: list[Any]) -> dict[str, list[Any]]:
        df_batch_labels = self._get_dict_from_schema()
        keys, video_tensors = list(zip(*batch))
        input_ids_batch = [self.prompt] * len(video_tensors)
        inputs = self.processor(text=input_ids_batch, images=video_tensors, return_tensors="pt")
        inputs = inputs.to(self.model.device)
        with torch.no_grad():
            output_token = self.model.generate(
                **inputs,
                media_type='video',
                do_sample=self.do_sample,
                max_new_tokens=self.max_new_tokens,
                temperature=self.temperature,
                use_cache=self.use_cache
            )
        output_texts = self.processor.batch_decode(output_token, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        split_tag = self.conv.roles[-1]
        bug_split_tag = "<|im_start|> assistant\n"
        all_outputs: list[Optional[str]] = []
        for output_text in output_texts:
            # Keep only the assistant's reply and strip the conversation separator.
            output_text = output_text.split(split_tag)[-1].split(bug_split_tag)[-1]
            ending = self.conv.sep if isinstance(self.conv.sep, str) else self.conv.sep[1]
            output_text = output_text.removesuffix(ending).strip()
            all_outputs.append(output_text)
        df_batch_labels[self.schema[1]].extend(all_outputs)
        df_batch_labels[self.key_column].extend(keys)
        return df_batch_labels


class Pllava13bFilter(PllavaFilter):
    def __init__(self, **kwargs: Any) -> None:
        model_path: str = 'ermu2001/pllava-13b'
        weights_path: str = 'weights/pllava-13b'
        weights_dir: str = 'weights/pllava-13b'

        super().__init__(model_path=model_path, weights_path=weights_path, weights_dir=weights_dir, prompts=None, **kwargs)
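
A hedged usage sketch of the new filter; the import path is an assumption (the class lives under DPF/filters/videos/), and `apply_data_filter` mirrors how other DPF filters are typically applied rather than anything added in this diff:

# Hypothetical usage; the module path is assumed for illustration.
from DPF.filters.videos.pllava_filter import Pllava13bFilter

video_filter = Pllava13bFilter(
    prompt="short-video",  # selects one of the built-in prompt templates
    batch_size=4,
    num_segments=32,
    device="cuda:0",
)
# Filters are normally run through a DPF processor, e.g.:
# processor.apply_data_filter(video_filter)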

DPF/filters/videos/pllava_filter_core/models/__init__.py

Whitespace-only changes.
Lines changed: 53 additions & 0 deletions

@@ -0,0 +1,53 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from transformers.utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available

_import_structure = {"configuration_pllava": ["PLLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP", "PllavaConfig"]}

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_pllava"] = [
        "PLLAVA_PRETRAINED_MODEL_ARCHIVE_LIST",
        "PllavaForConditionalGeneration",
        "PllavaPreTrainedModel",
    ]
    _import_structure["processing_pllava"] = ["PllavaProcessor"]


if TYPE_CHECKING:
    from .configuration_pllava import PLLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP, PllavaConfig

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_pllava import (
            PLLAVA_PRETRAINED_MODEL_ARCHIVE_LIST,
            PllavaForConditionalGeneration,
            PllavaPreTrainedModel,
        )
        from .processing_pllava import PllavaProcessor

else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
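
This file follows Hugging Face's standard lazy-module pattern: the package's entry in `sys.modules` is replaced with a `_LazyModule` that resolves names from `_import_structure` on first attribute access, so the torch-heavy modeling and processing modules are only imported when actually used. A simplified stand-in for illustration (not the transformers implementation):

import importlib
import types


class LazyModuleSketch(types.ModuleType):
    # Simplified illustration of the _LazyModule idea.
    def __init__(self, name: str, import_structure: dict[str, list[str]]) -> None:
        super().__init__(name)
        # Map each exported symbol to the submodule that defines it.
        self._symbol_to_submodule = {
            symbol: submodule
            for submodule, symbols in import_structure.items()
            for symbol in symbols
        }

    def __getattr__(self, symbol: str):
        submodule = self._symbol_to_submodule.get(symbol)
        if submodule is None:
            raise AttributeError(symbol)
        # Import the defining submodule only when the symbol is first used.
        module = importlib.import_module(f".{submodule}", self.__name__)
        return getattr(module, symbol)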
