Add vqa_dataset, update docs (#1820)
Co-authored-by: Rafi Ayub <33648637+RdoubleA@users.noreply.github.com>
Co-authored-by: krammnic <krammnic@krammnic.krammnic.com>
Co-authored-by: Joe Cummings <jrcummings27@gmail.com>
Co-authored-by: Salman Mohammadi <salman.mohammadi@outlook.com>
5 people authored Oct 17, 2024
1 parent 7d29c21 commit f8073ed
Showing 6 changed files with 239 additions and 9 deletions.
Binary file added tests/assets/rgb_pytorch.png
7 changes: 7 additions & 0 deletions tests/assets/vqa_tiny.json
@@ -0,0 +1,7 @@
[
{
"input": "What is presented on image?",
"output": "PyTorch logo.",
"image": "tests/assets/rgb_pytorch.png"
}
]
49 changes: 49 additions & 0 deletions tests/torchtune/datasets/multimodal/test_vqa_dataset.py
@@ -0,0 +1,49 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
from PIL.PngImagePlugin import PngImageFile
from tests.common import ASSETS
from tests.test_utils import DummyTokenizer

from torchtune.datasets.multimodal import vqa_dataset


class TestMultimodalInstructDataset:
@pytest.fixture
def tokenizer(self):
return DummyTokenizer()

def test_get_item(self, tokenizer):
system_prompt = "follow this prompt"

dataset = vqa_dataset(
model_transform=tokenizer,
source="json",
data_files=str(ASSETS / "vqa_tiny.json"),
split="train",
new_system_prompt=system_prompt,
)

expected_tokens = [
[0, 6, 4, 6, -2, 4, 2, 9, 2, 6, 7, 5, -1],
]

expected_labels = [
[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 7, 5, -1]
]

assert len(dataset) == 1

for i in range(len(dataset)):
prompt, label, image = (
dataset[i]["tokens"],
dataset[i]["labels"],
dataset[i]["images"],
)
assert prompt == expected_tokens[i]
assert label == expected_labels[i]
assert isinstance(image[0], PngImageFile)
52 changes: 43 additions & 9 deletions torchtune/data/_messages.py
@@ -161,6 +161,11 @@ class InputOutputToMessages(Transform):
             keeping the default "input" and "output" column names.
         new_system_prompt (Optional[str]): if specified, prepend a system message. This can
             serve as instructions to guide the model response. Default is None.
+        image_dir (Optional[Path]): path to the directory containing the images, prepended to all image
+            paths in the dataset. For example, if ``image_dir="/home/user/dataset/"`` and the sample image
+            path was ``"images/1.jpg"``, the final image path that will be loaded is
+            ``"/home/user/dataset/images/1.jpg"``. If None, assume images are available in the current
+            working directory or at a remote URL. For text-only datasets, leave as None. Default is None.
 
     Raises:
         ValueError: If ``column_map`` is provided and ``input`` not in ``column_map``, or
@@ -172,33 +177,62 @@ def __init__(
         train_on_input: bool = False,
         column_map: Optional[Dict[str, str]] = None,
         new_system_prompt: Optional[str] = None,
+        image_dir: Optional[Path] = None,
     ):
         self.train_on_input = train_on_input
         self.new_system_prompt = new_system_prompt
-        if column_map:
-            if "input" not in column_map:
+
+        self.column_map = column_map
+
+        if self.column_map is not None:
+            if "input" not in self.column_map:
                 raise ValueError(
-                    f"Expected a key of 'input' in column_map but found {column_map.keys()}."
+                    f"Expected a key of 'input' in column_map but found {self.column_map.keys()}."
                 )
-            if "output" not in column_map:
+            if "output" not in self.column_map:
                 raise ValueError(
-                    f"Expected a key of 'output' in column_map but found {column_map.keys()}."
+                    f"Expected a key of 'output' in column_map but found {self.column_map.keys()}."
                 )
-            self._column_map = column_map
         else:
-            self._column_map = {"input": "input", "output": "output"}
+            self.column_map = {"input": "input", "output": "output", "image": "image"}
+
+        self.image_dir = image_dir
 
     def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
+        is_multimodal = "image" in sample or (
+            "image" in self.column_map and self.column_map["image"] in sample
+        )
+
+        if is_multimodal:
+            image_path = sample[self.column_map["image"]]
+            if isinstance(image_path, str):
+                if self.image_dir is not None:
+                    image_path = self.image_dir / image_path
+                # Load if not loaded
+                pil_image = load_image(image_path)
+            else:
+                pil_image = image_path
+            content = [
+                {"type": "image", "content": pil_image},
+                {"type": "text", "content": sample[self.column_map["input"]]},
+            ]
+        else:
+            content = [{"type": "text", "content": sample[self.column_map["input"]]}]
+
+        output_content = [
+            {"type": "text", "content": sample[self.column_map["output"]]}
+        ]
+
         messages = [
             Message(
                 role="user",
-                content=sample[self._column_map["input"]],
+                content=content,
                 masked=not self.train_on_input,
                 eot=True,
             ),
             Message(
                 role="assistant",
-                content=sample[self._column_map["output"]],
+                content=output_content,
                 masked=False,
                 eot=True,
             ),
2 changes: 2 additions & 0 deletions torchtune/datasets/multimodal/__init__.py
@@ -7,9 +7,11 @@
 from ._llava_instruct import llava_instruct_dataset
 from ._multimodal import multimodal_chat_dataset
 from ._the_cauldron import the_cauldron_dataset
+from ._vqa import vqa_dataset
 
 __all__ = [
     "the_cauldron_dataset",
     "llava_instruct_dataset",
     "multimodal_chat_dataset",
+    "vqa_dataset",
 ]
138 changes: 138 additions & 0 deletions torchtune/datasets/multimodal/_vqa.py
@@ -0,0 +1,138 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from pathlib import Path
from typing import Any, Callable, Dict, Optional

from torchtune.data import InputOutputToMessages
from torchtune.datasets._sft import SFTDataset
from torchtune.modules.transforms import Transform


def vqa_dataset(
model_transform: Transform,
*,
source: str,
image_dir: Optional[str] = None,
column_map: Optional[Dict[str, str]] = None,
new_system_prompt: Optional[str] = None,
filter_fn: Optional[Callable] = None,
split: str = "train",
**load_dataset_kwargs: Dict[str, Any],
) -> SFTDataset:
"""
Configure a custom visual question answering dataset with separate columns for the user question, image, and model response.
This builder can be used to configure a custom VQA dataset directly from the yaml config,
as an alternative to :class:`~torchtune.datasets.SFTDataset`, since it is made to be config friendly.
The dataset should follow this format:
.. code-block:: text
| input | image | output |
|-----------------|-----------------|------------------|
| "user prompt" | images/1.jpg | "model response" |
If your column names are different, you can use the ``column_map`` parameter to change
the expected column names. For example, if your dataset has columns ``"question"``,
``"answer"`` and ``"picture"`` you can use:
column_map = {"input": "question", "output": "answer", "image": "picture"}
Args:
model_transform (Transform): callable that applies model-specific pre-processing to the sample.
This includes tokenization and any modality-specific transforms. It is expected to return at
minimum ``"tokens"`` and ``"mask"`` keys.
source (str): path to dataset repository on Hugging Face. For local datasets,
define source as the data file type (e.g. "json", "csv", "text"), pass
in the filepath in ``data_files``, and set ``split="train"``. See `Hugging Face's
<https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path>`_
``load_dataset`` for more details.
image_dir (Optional[str]): path to the directory containing the images, prepended to all image
paths in the dataset. For example, if ``image_dir="/home/user/dataset/"`` and the sample image path
was ``"images/1.jpg"``, the final image path that will be loaded is ``"/home/user/dataset/images/1.jpg"``.
If None, assume images are available in the current working directory or at a remote URL.
For text-only datasets, leave as None. Default is None.
column_map (Optional[Dict[str, str]]): a mapping to change the expected "input",
"output", and "image" column names to the actual column names in the dataset. Keys should be "input",
"output", and "image", and values should be the actual column names.
Default is None, keeping the default "input", "output", and "image" column names.
new_system_prompt (Optional[str]): if specified, prepend a system message. This can
serve as instructions to guide the model response. Setting this will OVERRIDE any system
messages already present in the dataset. Default is None.
filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See
the Hugging Face `docs <https://huggingface.co/docs/datasets/v2.20.0/process#select-and-filter>`_ for more
details.
split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset
of a given split, e.g. ``split="train[:10%]"``. Default is "train".
**load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``,
such as ``data_files`` or ``split``.
Examples:
::
my_dataset.json
[
{
"question": "What is presented on the image?",
"answer": "PyTorch logo.",
"picture": "rgb_pytorch.png"
},
{
...
},
...,
]
::
>>> from torchtune.datasets.multimodal import vqa_dataset
>>> dataset = vqa_dataset(
... model_transform=model_transform,
... source="json",
... data_files="my_dataset.json",
... column_map={
... "input": "question",
... "output": "answer",
... "image": "picture"
... },
... split="train",
... )
>>> tokens = dataset[0]["tokens"]
>>> model_transform.decode(tokens)
"What is presented on the image?PyTorch logo."
This can also be accomplished via the yaml config:
.. code-block:: yaml
dataset:
_component_: torchtune.datasets.multimodal.vqa_dataset
source: json
data_files: my_dataset.json
column_map:
input: question
output: answer
image: picture
split: train
Returns:
SFTDataset: the configured :class:`~torchtune.datasets.SFTDataset`
"""
message_transform = InputOutputToMessages(
    column_map=column_map,
    new_system_prompt=new_system_prompt,
    # InputOutputToMessages expects a Path so it can be joined with sample paths
    image_dir=Path(image_dir) if image_dir is not None else None,
)

ds = SFTDataset(
source=source,
message_transform=message_transform,
model_transform=model_transform,
filter_fn=filter_fn,
split=split,
**load_dataset_kwargs,
)
return ds
