Commit
Merge branch 'main' into toggle_kv_cache
SalmanMohammadi committed Oct 20, 2024
2 parents 84e8cc5 + 3ca0d30 commit 7906807
Showing 8 changed files with 392 additions and 13 deletions.
Binary file added tests/assets/rgb_pytorch.png
7 changes: 7 additions & 0 deletions tests/assets/vqa_tiny.json
@@ -0,0 +1,7 @@
[
{
"input": "What is presented on image?",
"output": "PyTorch logo.",
"image": "tests/assets/rgb_pytorch.png"
}
]
49 changes: 49 additions & 0 deletions tests/torchtune/datasets/multimodal/test_vqa_dataset.py
@@ -0,0 +1,49 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
from PIL.PngImagePlugin import PngImageFile
from tests.common import ASSETS
from tests.test_utils import DummyTokenizer

from torchtune.datasets.multimodal import vqa_dataset


class TestMultimodalInstructDataset:
@pytest.fixture
def tokenizer(self):
return DummyTokenizer()

def test_get_item(self, tokenizer):
system_prompt = "follow this prompt"

dataset = vqa_dataset(
model_transform=tokenizer,
source="json",
data_files=str(ASSETS / "vqa_tiny.json"),
split="train",
new_system_prompt=system_prompt,
)

expected_tokens = [
[0, 6, 4, 6, -2, 4, 2, 9, 2, 6, 7, 5, -1],
]

expected_labels = [
[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 7, 5, -1]
]

assert len(dataset) == 1

for i in range(len(dataset)):
prompt, label, image = (
dataset[i]["tokens"],
dataset[i]["labels"],
dataset[i]["images"],
)
assert prompt == expected_tokens[i]
assert label == expected_labels[i]
assert isinstance(image[0], PngImageFile)
131 changes: 131 additions & 0 deletions tests/torchtune/training/test_activation_offloading.py
@@ -0,0 +1,131 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch
from tests.test_utils import gpu_test
from torch import nn
from torchtune.training import OffloadActivations


@gpu_test(gpu_count=1)
@pytest.mark.parametrize("use_streams", [True, False])
def test_offloading_is_same_as_without(use_streams) -> None:
with torch.device("cuda"):
torch.manual_seed(2024)
model = nn.Sequential(
nn.Linear(10, 10),
nn.Linear(10, 10),
nn.Linear(10, 10),
nn.ReLU(),
)
torch.manual_seed(2024)
model_c = nn.Sequential(
nn.Linear(10, 10),
nn.Linear(10, 10),
nn.Linear(10, 10),
nn.ReLU(),
)

inp = torch.randn((2, 10), device="cuda")
loss = model(inp).sum()
loss.backward()

with OffloadActivations(use_streams=use_streams):
loss_c = model_c(inp).sum()
loss_c.backward()

for param, param_c in zip(model.parameters(), model_c.parameters()):
assert torch.equal(param.grad, param_c.grad)


@gpu_test(gpu_count=1)
def test_offloading_works_with_view_outputs() -> None:
"""
This test is quite contrived but tests against a very obscure situation where
any of the outputs of a backward node are a view of the unpacked tensor.
We want to ensure that if an unpacked tensor may be used later that we do not
free it too early.
How did we contrive this test? We need the backward to execute as so:
1. We first need a node that unpacks a tensor and returns a view of the tensor
2. The next node just needs to pass that view along--this NoOp node is needed
to bypass our heuristic where we delete the _previous_ node's stash after
executing the current node.
3. We need to allow the tensor to die to be contaminated with new info, and
we need a way to look into the contents of the contaminated tensor. We
separate these into two nodes (because having them in the same node does
not properly let the tensor reference die as it is within scope.) The
"Compute" Node queues up ~1 second of work on CUDA followed by a kernel
evaluating whether dX is full of 1s. The next Node then inspects the
earlier activation and asserts the result of dX == 1, which is a sync!
Note that for the backward to execute in the above order, the fwd was made
to execute in reverse order.
"""

class BwdReturnsViewOfActivation(torch.autograd.Function):
@staticmethod
def forward(ctx, cloned_activation):
cloned_activation = cloned_activation.t()
ctx.save_for_backward(cloned_activation)
return torch.rand(2, 4, device="cuda")

@staticmethod
def backward(ctx, dy):
unpacked_activation = ctx.saved_tensors[0]
return unpacked_activation.t()

class NoOp(torch.autograd.Function):
@staticmethod
def forward(ctx, cloned_activation):
ctx.save_for_backward(cloned_activation)
return cloned_activation.clone()

@staticmethod
def backward(ctx, viewed_activation):
            rando_activation = ctx.saved_tensors[0]  # unpack the stashed activation; intentionally unused
return viewed_activation

class ComputeNode(torch.autograd.Function):
@staticmethod
def forward(ctx, activation):
return activation.clone()

@staticmethod
def backward(ctx, viewed_activation):
torch.cuda._sleep(2000000000) # 2e9 is ~1s worth of GPU cycles
return viewed_activation == 1

class InspectEarlierActivation(torch.autograd.Function):
@staticmethod
def forward(ctx, activation):
ctx.save_for_backward(torch.ones_like(activation) * 5)
return activation

@staticmethod
def backward(ctx, viewed_activation_all_1):
corrupter = ctx.saved_tensors[0]
            assert torch.all(
                viewed_activation_all_1
            )  # still all 1s as before, and NOT the corrupter values (5s)!!
return corrupter

def fwd(t):
a = InspectEarlierActivation.apply(t)
b = ComputeNode.apply(a)
c = NoOp.apply(b)
d = BwdReturnsViewOfActivation.apply(c)
return d.sum()

tensor_c = torch.ones(256, 1024, device="cuda", requires_grad=True)
ctx = OffloadActivations(use_streams=True)
with ctx:
loss_c = fwd(tensor_c)
# delete the fwd stash to avoid our peek-in-fwd-stash heuristic in the bwd
ctx.fwd_stash = {}
loss_c.backward()
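
Outside of these tests, the context manager is meant to wrap the forward and backward of a training step. Below is a minimal sketch, assuming a CUDA device; the toy model, optimizer, and batch shape are made up for illustration, and only the `OffloadActivations` call itself comes from this commit's API.

```python
import torch
from torch import nn

from torchtune.training import OffloadActivations

# Hypothetical toy model, optimizer, and batch; only OffloadActivations comes from torchtune.
model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 8)).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
batch = torch.randn(16, 512, device="cuda")

# Activations saved for backward inside this context are offloaded to CPU and
# brought back on demand during the backward pass.
with OffloadActivations(use_streams=True):
    loss = model(batch).sum()
    loss.backward()

optimizer.step()
optimizer.zero_grad()
```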
52 changes: 43 additions & 9 deletions torchtune/data/_messages.py
@@ -161,6 +161,11 @@ class InputOutputToMessages(Transform):
             keeping the default "input" and "output" column names.
         new_system_prompt (Optional[str]): if specified, prepend a system message. This can
             serve as instructions to guide the model response. Default is None.
+        image_dir (Optional[Path]): path to the directory containing the images; it is prepended to all image
+            paths in the dataset. For example, if ``image_dir="/home/user/dataset/"`` and the sample image path
+            is ``"images/1.jpg"``, the final image path that will be loaded is ``"/home/user/dataset/images/1.jpg"``.
+            If None, assume images are available in the current working directory or are located
+            at a remote URL. For text-only datasets, leave as None. Default is None.

     Raises:
         ValueError: If ``column_map`` is provided and ``input`` not in ``column_map``, or
@@ -172,33 +177,62 @@ def __init__(
         train_on_input: bool = False,
         column_map: Optional[Dict[str, str]] = None,
         new_system_prompt: Optional[str] = None,
+        image_dir: Optional[Path] = None,
     ):
         self.train_on_input = train_on_input
         self.new_system_prompt = new_system_prompt
-        if column_map:
-            if "input" not in column_map:
+
+        self.column_map = column_map
+
+        if self.column_map is not None:
+            if "input" not in self.column_map:
                 raise ValueError(
-                    f"Expected a key of 'input' in column_map but found {column_map.keys()}."
+                    f"Expected a key of 'input' in column_map but found {self.column_map.keys()}."
                 )
-            if "output" not in column_map:
+            if "output" not in self.column_map:
                 raise ValueError(
-                    f"Expected a key of 'output' in column_map but found {column_map.keys()}."
+                    f"Expected a key of 'output' in column_map but found {self.column_map.keys()}."
                 )
-            self._column_map = column_map
         else:
-            self._column_map = {"input": "input", "output": "output"}
+            self.column_map = {"input": "input", "output": "output", "image": "image"}
+
+        self.image_dir = image_dir

     def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
+        is_multimodal = "image" in sample or (
+            "image" in self.column_map and self.column_map["image"] in sample
+        )
+
+        if is_multimodal:
+            image_path = sample[self.column_map["image"]]
+            if isinstance(image_path, str):
+                if self.image_dir is not None:
+                    image_path = self.image_dir / image_path
+                # Load if not loaded
+                pil_image = load_image(image_path)
+            else:
+                pil_image = image_path
+            content = [
+                {"type": "image", "content": pil_image},
+                {"type": "text", "content": sample[self.column_map["input"]]},
+            ]
+        else:
+            content = [{"type": "text", "content": sample[self.column_map["input"]]}]
+
+        output_content = [
+            {"type": "text", "content": sample[self.column_map["output"]]}
+        ]
+
         messages = [
             Message(
                 role="user",
-                content=sample[self._column_map["input"]],
+                content=content,
                 masked=not self.train_on_input,
                 eot=True,
             ),
             Message(
                 role="assistant",
-                content=sample[self._column_map["output"]],
+                content=output_content,
                 masked=False,
                 eot=True,
             ),
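
To make the new multimodal path concrete, here is a minimal sketch of `InputOutputToMessages` resolving a relative image path against `image_dir`. It reuses the `tests/assets/rgb_pytorch.png` asset added in this commit, assumes it is run from the repository root, and assumes the transform still returns a dict with a `"messages"` key, as in the existing implementation.

```python
from pathlib import Path

from torchtune.data import InputOutputToMessages

# A minimal sketch (not part of the diff): relative image paths are joined
# onto image_dir before load_image is called.
transform = InputOutputToMessages(image_dir=Path("tests"))
sample = {
    "input": "What is presented on image?",
    "output": "PyTorch logo.",
    "image": "assets/rgb_pytorch.png",  # resolved to tests/assets/rgb_pytorch.png
}
messages = transform(sample)["messages"]
# messages[0] is the user turn: an "image" content entry (a PIL image) followed by the
# question text; messages[1] is the assistant turn with the output text.
```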
2 changes: 2 additions & 0 deletions torchtune/datasets/multimodal/__init__.py
@@ -7,9 +7,11 @@
 from ._llava_instruct import llava_instruct_dataset
 from ._multimodal import multimodal_chat_dataset
 from ._the_cauldron import the_cauldron_dataset
+from ._vqa import vqa_dataset

 __all__ = [
     "the_cauldron_dataset",
     "llava_instruct_dataset",
     "multimodal_chat_dataset",
+    "vqa_dataset",
 ]