From 3ca0d309c67ea996cc69f29691bc97ad7de00819 Mon Sep 17 00:00:00 2001
From: "Jane (Yuan) Xu" <31798555+janeyx99@users.noreply.github.com>
Date: Fri, 18 Oct 2024 15:00:20 -0400
Subject: [PATCH] Add offloading tests and fix obscure edge case (#1860)

---
 .../training/test_activation_offloading.py   | 131 ++++++++++++++++++
 torchtune/training/_activation_offloading.py |  26 +++-
 2 files changed, 153 insertions(+), 4 deletions(-)
 create mode 100644 tests/torchtune/training/test_activation_offloading.py

diff --git a/tests/torchtune/training/test_activation_offloading.py b/tests/torchtune/training/test_activation_offloading.py
new file mode 100644
index 0000000000..5d4c968e96
--- /dev/null
+++ b/tests/torchtune/training/test_activation_offloading.py
@@ -0,0 +1,131 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import pytest
+import torch
+from tests.test_utils import gpu_test
+from torch import nn
+from torchtune.training import OffloadActivations
+
+
+@gpu_test(gpu_count=1)
+@pytest.mark.parametrize("use_streams", [True, False])
+def test_offloading_is_same_as_without(use_streams) -> None:
+    with torch.device("cuda"):
+        torch.manual_seed(2024)
+        model = nn.Sequential(
+            nn.Linear(10, 10),
+            nn.Linear(10, 10),
+            nn.Linear(10, 10),
+            nn.ReLU(),
+        )
+        torch.manual_seed(2024)
+        model_c = nn.Sequential(
+            nn.Linear(10, 10),
+            nn.Linear(10, 10),
+            nn.Linear(10, 10),
+            nn.ReLU(),
+        )
+
+    inp = torch.randn((2, 10), device="cuda")
+    loss = model(inp).sum()
+    loss.backward()
+
+    with OffloadActivations(use_streams=use_streams):
+        loss_c = model_c(inp).sum()
+    loss_c.backward()
+
+    for param, param_c in zip(model.parameters(), model_c.parameters()):
+        assert torch.equal(param.grad, param_c.grad)
+
+
+@gpu_test(gpu_count=1)
+def test_offloading_works_with_view_outputs() -> None:
+    """
+    This test is quite contrived but guards against a very obscure situation where
+    any of the outputs of a backward node is a view of the unpacked tensor.
+
+    We want to ensure that if an unpacked tensor may be used later, we do not
+    free it too early.
+
+    How did we contrive this test? We need the backward to execute as follows:
+    1. We first need a node that unpacks a tensor and returns a view of that tensor.
+    2. The next node just needs to pass that view along--this NoOp node is needed
+       to bypass our heuristic where we delete the _previous_ node's stash after
+       executing the current node.
+    3. We need to allow the tensor to die so it can be contaminated with new info,
+       and we need a way to look into the contents of the contaminated tensor. We
+       separate these into two nodes (because having them in the same node does
+       not properly let the tensor reference die, as it stays within scope). The
+       "Compute" node queues up ~1 second of work on CUDA followed by a kernel
+       evaluating whether dX is full of 1s. The next node then inspects the
+       earlier activation and asserts that dX == 1, which is a sync!
+
+    Note that for the backward to execute in the above order, the fwd was made
+    to execute in reverse order.
+    """
+
+    class BwdReturnsViewOfActivation(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, cloned_activation):
+            cloned_activation = cloned_activation.t()
+            ctx.save_for_backward(cloned_activation)
+            return torch.rand(2, 4, device="cuda")
+
+        @staticmethod
+        def backward(ctx, dy):
+            unpacked_activation = ctx.saved_tensors[0]
+            return unpacked_activation.t()
+
+    class NoOp(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, cloned_activation):
+            ctx.save_for_backward(cloned_activation)
+            return cloned_activation.clone()
+
+        @staticmethod
+        def backward(ctx, viewed_activation):
+            rando_activation = ctx.saved_tensors[0]
+            return viewed_activation
+
+    class ComputeNode(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, activation):
+            return activation.clone()
+
+        @staticmethod
+        def backward(ctx, viewed_activation):
+            torch.cuda._sleep(2000000000)  # 2e9 is ~1s worth of GPU cycles
+            return viewed_activation == 1
+
+    class InspectEarlierActivation(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, activation):
+            ctx.save_for_backward(torch.ones_like(activation) * 5)
+            return activation
+
+        @staticmethod
+        def backward(ctx, viewed_activation_all_1):
+            corrupter = ctx.saved_tensors[0]
+            assert torch.all(
+                viewed_activation_all_1
+            )  # is the same as before (1s) and NOT W (5s)!!
+            return corrupter
+
+    def fwd(t):
+        a = InspectEarlierActivation.apply(t)
+        b = ComputeNode.apply(a)
+        c = NoOp.apply(b)
+        d = BwdReturnsViewOfActivation.apply(c)
+        return d.sum()
+
+    tensor_c = torch.ones(256, 1024, device="cuda", requires_grad=True)
+    ctx = OffloadActivations(use_streams=True)
+    with ctx:
+        loss_c = fwd(tensor_c)
+    # delete the fwd stash to avoid our peek-in-fwd-stash heuristic in the bwd
+    ctx.fwd_stash = {}
+    loss_c.backward()
diff --git a/torchtune/training/_activation_offloading.py b/torchtune/training/_activation_offloading.py
index 5156281aa8..c536e7f5ee 100644
--- a/torchtune/training/_activation_offloading.py
+++ b/torchtune/training/_activation_offloading.py
@@ -146,8 +146,12 @@ def pack_tensor(activation: torch.Tensor) -> int:
             num_bytes = get_num_bytes_tensor(activation)
             tensor_id = get_tensor_id()
 
-            # only offload hefty bois
-            if num_bytes >= self.min_tensor_size_bytes:
+            # only offload hefty bois if they're activations (our heuristic for that is to
+            # check if they're not params or buffers)!
+            if num_bytes >= self.min_tensor_size_bytes and (
+                not isinstance(activation, torch.nn.Parameter)
+                and not isinstance(activation, torch.nn.Buffer)
+            ):
                 if self.use_streams:
                     # First, sync back and dereference previously offloaded tensors
                     # as the offloading should be done sufficiently long ago.
@@ -281,8 +285,22 @@ def wait_and_del_remaining_references() -> None:
             def hook(outputs, inputs):
                 # create events for the current node inputs/outputs if they were streamed in
                 if brought_back_from_cpu:
-                    event = self.s0.record_event()
-                    self.bwd_ev_stash[unpack_tensor_id] = event
+                    # if any of the outputs is a view of the tensor, meaning the tensor might be used later,
+                    # we cannot presume to delete it after only the current node is done! So we use our frenemy,
+                    # record_stream, to ensure the Tensor stays unmessed with until it's done getting used
+                    # in the compute stream (s0 here). Note that the con here is we introduce non-deterministic
+                    # memory usage, but this case should not happen often.
+                    unpacked_tensor = self.bwd_tensor_stash[unpack_tensor_id]
+                    if any(
+                        o.untyped_storage() is unpacked_tensor.untyped_storage()
+                        for o in outputs
+                        if o is not None
+                    ):
+                        unpacked_tensor.record_stream(self.s0)
+                        del self.bwd_tensor_stash[unpack_tensor_id]
+                    else:
+                        event = self.s0.record_event()
+                        self.bwd_ev_stash[unpack_tensor_id] = event
 
                 # if there are still things in the fwd_stash, get rid of them as we're in bwd now
                 for id in [k for k in self.fwd_stash.keys()]:
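
To illustrate the first hunk: the pack hook now offloads a saved tensor only when it is large enough AND is not an nn.Parameter or nn.Buffer, since params and buffers stay resident for the whole step anyway. Below is a minimal, standalone sketch of that guard built directly on torch.autograd.graph.saved_tensors_hooks; it is not torchtune's OffloadActivations, and MIN_OFFLOAD_BYTES, pack_to_cpu, and unpack_from_cpu are made-up names for illustration (the nn.Buffer check assumes torch>=2.4).

import torch
from torch import nn

MIN_OFFLOAD_BYTES = 1024  # assumed threshold, analogous to min_tensor_size_bytes


def pack_to_cpu(t: torch.Tensor):
    # skip small tensors and anything that is a param/buffer -- the same
    # "is this actually an activation?" heuristic the patch adds
    num_bytes = t.untyped_storage().nbytes()
    is_param_or_buffer = isinstance(t, (nn.Parameter, nn.Buffer))
    if num_bytes >= MIN_OFFLOAD_BYTES and not is_param_or_buffer:
        return ("cpu", t.to("cpu"))
    return ("orig", t)


def unpack_from_cpu(packed):
    kind, t = packed
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return t.to(device) if kind == "cpu" else t


device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(2048, 2048, device=device)
x = torch.randn(4, 2048, device=device, requires_grad=True)

with torch.autograd.graph.saved_tensors_hooks(pack_to_cpu, unpack_from_cpu):
    loss = model(x).relu().sum()
loss.backward()  # offloaded activations are brought back during backward;
                 # the weight (an nn.Parameter) was never offloaded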
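
To illustrate the second hunk: when one of a backward node's outputs aliases the unpacked tensor's storage, that tensor may still be read after the node finishes, so the patch keeps it alive via record_stream instead of freeing it once the next node runs. The sketch below shows the aliasing check and the record_stream call in isolation; s_compute, s_copy, and the shapes are invented for the example, and this is not torchtune's internal code.

import torch

if torch.cuda.is_available():
    s_compute = torch.cuda.current_stream()  # plays the role of s0 in the patch
    s_copy = torch.cuda.Stream()  # side stream that brought the tensor back

    with torch.cuda.stream(s_copy):
        # allocated on the copy stream, like a tensor unpacked off the CPU
        unpacked = torch.ones(256, 1024, device="cuda")

    output = unpacked.t()  # a view: shares storage with `unpacked`

    # aliasing check, analogous to the patch's untyped_storage() comparison
    is_view_of_unpacked = (
        output.untyped_storage().data_ptr() == unpacked.untyped_storage().data_ptr()
    )

    if is_view_of_unpacked:
        # tell the caching allocator that the compute stream also uses this
        # memory, so it is not reused until s_compute's queued work finishes
        unpacked.record_stream(s_compute)
    else:
        # otherwise an event recorded on the compute stream is enough to gate
        # when the stashed tensor can be dropped
        done_event = s_compute.record_event()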