106 changes: 106 additions & 0 deletions DPF/filters/images/llava34b_captioning_filter.py
@@ -0,0 +1,106 @@
import re
from typing import Any

import torch
from torchvision import transforms as T
from transformers import LlavaNextForConditionalGeneration, LlavaNextProcessor

from DPF.filters.images.img_filter import ImageFilter
from DPF.types import ModalityToDataMapping
from DPF.utils import read_image_rgb_from_bytes


class Llava34b_Filter(ImageFilter):
"""
    Generates captions for input images using the llava-v1.6-34b-hf model.
"""

def __init__(
self,
model_path: str = 'llava-hf/llava-v1.6-34b-hf',
workers: int = 16,
batch_size: int = 8,
prompt: str = 'detailed-long',
device: str = "cuda:0",
pbar: bool = True,
crop_size_x: int = 336,
crop_size_y: int = 336,
resize: int = 336,
_pbar_position: int = 0
):
super().__init__(pbar, _pbar_position)
self.batch_size = batch_size
self.num_workers = workers
self.device = device
self.crop_size_x = crop_size_x
self.crop_size_y = crop_size_y
self.resize = resize
self.model_path = model_path
self.prompt_to_use = prompt
prompts = {
'detailed-long': 'Please provide a caption for this image. Speak confidently and describe everything clearly. Do not lie and describe only what you can see',
'pixart': 'Describe this image and its style in a very detailed manner',
'short': 'Describe this image very shortly in 1-2 short sentences',
'short-video': 'Describe this video very shortly in 1-2 short sentences. Describe what is happening in this video.'
}
self.input_ids = prompts[self.prompt_to_use]
print(self.input_ids)
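        # Wrap the selected prompt in the llava-v1.6-34b chat template; the <image>
        # placeholder marks where the processor inserts the image tokens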
self.prompt = "<|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n<image>\n" + f"{self.input_ids}" + "<|im_end|><|im_start|>assistant\n"
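        # Load the LLaVA-NeXT processor and model in fp16; flash_attention_2 assumes
        # the flash-attn package is installed and a compatible GPU is available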
self.processor = LlavaNextProcessor.from_pretrained(model_path)
self.model = LlavaNextForConditionalGeneration.from_pretrained(
model_path,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
attn_implementation="flash_attention_2",
device_map=self.device
)

@property
def result_columns(self) -> list[str]:
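        # Single output column named after the model path, e.g. "caption llava-hf/llava-v1.6-34b-hf"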
return [f"caption {self.model_path}"]

@property
def dataloader_kwargs(self) -> dict[str, Any]:
return {
"num_workers": self.num_workers,
"batch_size": self.batch_size,
"drop_last": False,
}

def preprocess_data(
self,
modality2data: ModalityToDataMapping,
metadata: dict[str, Any]
) -> Any:
key = metadata[self.key_column]
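        # Decode the raw image bytes into an RGB PIL image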
        pil_img = read_image_rgb_from_bytes(modality2data['image']).convert('RGB')
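        # Resize the shorter side, then center-crop to the model's expected input size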
transform = T.Compose([
T.Resize(self.resize),
            T.CenterCrop((self.crop_size_y, self.crop_size_x))  # CenterCrop expects (height, width)
])
cropped_image = transform(pil_img)
return key, cropped_image

def process_batch(self, batch: list[Any]) -> dict[str, list[Any]]:
df_batch_labels = self._get_dict_from_schema()
keys, images = list(zip(*batch))
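        # Build one prompt per image; len(images) (not batch_size) handles a final
        # partial batch, since the dataloader uses drop_last=False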
        prompts = [self.prompt for _ in range(len(images))]
        inputs = self.processor(prompts, list(images), return_tensors="pt").to(self.device)
with torch.inference_mode():
output_ids = self.model.generate(
**inputs, max_new_tokens=512, use_cache=True)

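        # Decode each generated sequence and strip everything up to the assistant tag,
        # leaving only the generated caption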
all_outputs = []
for i in range(output_ids.shape[0]):
output = self.processor.decode(
output_ids[i], skip_special_tokens=True, clean_up_tokenization_spaces=True)
output = re.sub(r'.*?assistant', '', output, flags=re.DOTALL)
output = re.sub(r'\n', '', output, count=1)
all_outputs.append(output)

df_batch_labels[self.schema[1]].extend(all_outputs)
df_batch_labels[self.key_column].extend(keys)

return df_batch_labels
1 change: 1 addition & 0 deletions docs/filters.md
@@ -11,6 +11,7 @@ List of implemented filters:
- [BLIPCaptioningFilter](../DPF/filters/images/blip_captioning_filter.py) - captioning images using BLIP model
- [CLIPLabelsFilter](../DPF/filters/images/cliplabels_filter.py) - calculate similarity of images with provided texts using CLIP model
- [LLaVaCaptioningFilter](../DPF/filters/images/llava_captioning_filter.py) - captioning images using LLaVA models
- [LLaVa34bCaptioningFilter](../DPF/filters/images/llava34b_captioning_filter.py) - captioning images using the LLaVA llava-v1.6-34b-hf model
- [NSFWFilter](../DPF/filters/images/nsfw_filter.py) - NSFW images detection
- [CRAFTFilter](../DPF/filters/images/text_detection_filter.py) - text detection on image
- [OCRFilter](../DPF/filters/images/ocr_filter.py) - text recognition
111 changes: 111 additions & 0 deletions examples/image_filters_example.ipynb
@@ -1077,6 +1077,117 @@
"processor.df['caption liuhaotian/llava-v1.5-13b prompt pixart']\n"
]
},
{
"cell_type": "markdown",
"id": "9e4f15f0",
"metadata": {},
"source": [
"## LLaVa34bCaptioningFilter"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "7c629325",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/jovyan/.mlspace/envs/env3.11/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"100%|██████████| 1/1 [00:00<00:00, 17.88it/s]\n"
]
}
],
"source": [
"import sys\n",
"sys.path.append('../')\n",
"from DPF import ShardsDatasetConfig, DatasetReader\n",
"\n",
"config = ShardsDatasetConfig.from_path_and_columns(\n",
" 'example_dataset',\n",
" image_name_col='image_name',\n",
" text_col=\"text\"\n",
")\n",
"\n",
"reader = DatasetReader()\n",
"processor = reader.read_from_config(config)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "f5cfd34d",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[2024-05-12 12:42:04,347] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers\n",
"You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour\n",
"Loading checkpoint shards: 100%|██████████| 15/15 [00:31<00:00, 2.10s/it]\n",
"100%|██████████| 250/250 [1:30:04<00:00, 21.62s/it]\n"
]
}
],
"source": [
"from DPF.filters.images.llava34b_captioning_filter import Llava34b_Filter\n",
"\n",
"datafilter = Llava34b_Filter(\n",
" workers=1, \n",
" batch_size=4, \n",
" device='cuda:0',\n",
" crop_size_x = 336,\n",
" crop_size_y = 336\n",
")\n",
"\n",
"processor.apply_data_filter(datafilter)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "e471161b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0 The image depicts an older couple in a kitchen...\n",
"1 The image shows a close-up of two bowls of gre...\n",
"2 The image shows a golden retriever dog swimmin...\n",
"3 The image depicts a tranquil scene at what app...\n",
"4 The image depicts a serene landscape featuring...\n",
" ... \n",
"995 The image depicts an aerial view of a densely ...\n",
"996 The image depicts an impressionist painting of...\n",
"997 The image depicts a modern and stylish interio...\n",
"998 The image depicts a stylized, fantasy-themed l...\n",
"999 The image shows a pair of dark blue trousers w...\n",
"Name: fix all caption liuhaotian/llava-v1.6-34b prompt short, Length: 1000, dtype: object"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"processor.df['llava-v1.6-34b']"
]
},
{
"cell_type": "markdown",
"id": "76d24a11",