diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py
index f5bd03d6cc..2d9079bda5 100644
--- a/tests/test_checkpoint.py
+++ b/tests/test_checkpoint.py
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 
+from contextlib import nullcontext
 from copy import deepcopy
 
 import pytest
@@ -329,6 +330,7 @@ def test_optimal_checkpoint_policy(
 
 @pytest.mark.skipif(torch.__version__ < "2.3", reason="Only new PyTorch supported")
 @cuda_only
+@pytest.mark.parametrize("no_grad", [False, True])
 @pytest.mark.parametrize("device", ["cuda"])
 @pytest.mark.parametrize("memory_budget", [0, 0.1, 0.3, 1.0])
 @pytest.mark.parametrize("inplace", [False])
@@ -336,7 +338,9 @@ def test_optimal_checkpoint_policy(
 @torch._dynamo.config.patch(  # type: ignore
     "_experimental_support_context_fn_in_torch_utils_checkpoint", True
 )
-def test_selective_checkpoint_wrapper_compile(device, memory_budget, inplace, random):
+def test_selective_checkpoint_wrapper_compile(
+    device, no_grad, memory_budget, inplace, random
+):
     torch.manual_seed(42)
     dtype = torch.float16
     modules = _get_model_blocks(
@@ -352,18 +356,26 @@ def test_selective_checkpoint_wrapper_compile(device, memory_budget, inplace, ra
 
     grad = torch.rand_like(inputs)
 
-    torch.manual_seed(42)
-    out = model(inputs.clone())
-    out.backward(grad)
+    context = torch.no_grad() if no_grad else nullcontext()
 
-    torch.manual_seed(42)
-    out_ref = model_ref(inputs.clone())
-    out_ref.backward(grad)
+    with context:
+        torch.manual_seed(42)
+        out = model(inputs.clone())
+        if not no_grad:
+            out.backward(grad)
+
+        torch.manual_seed(42)
+        out_ref = model_ref(inputs.clone())
+        if not no_grad:
+            out_ref.backward(grad)
 
     atol = 3e-4
     rtol = 1e-3
     torch.testing.assert_close(out, out_ref, atol=atol, rtol=rtol)
 
+    if no_grad:
+        return
+
     for p, p_ref in zip(model.parameters(), model_ref.parameters()):
         atol = 4e-4
         rtol = 2e-3
diff --git a/xformers/checkpoint.py b/xformers/checkpoint.py
index 0ca56379db..4cfa7494b4 100644
--- a/xformers/checkpoint.py
+++ b/xformers/checkpoint.py
@@ -7,7 +7,6 @@
 import functools
 import time
 from collections import defaultdict
-from contextlib import nullcontext
 from copy import deepcopy
 from dataclasses import astuple, dataclass
 from typing import Any, Callable, ContextManager, Dict, List, Optional, Tuple
@@ -124,6 +123,13 @@ def pop_from_storage(self, func, args, kwargs):
         return func(*args, **kwargs)
 
 
+class NullTorchDispatchMode(TorchDispatchMode):
+    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+        return func(*args, **kwargs)
+
+
 def selective_checkpoint_context_fn(policy_fn=None):
     """An activation checkpoint context_fn for selectively deciding what to
     store and what to recompute. Accepts a custom policy.
@@ -149,7 +155,7 @@ def selective_checkpoint_context_fn(policy_fn=None):
     if torch.is_grad_enabled():
         caching_mode = _CachingTorchDispatchMode(deepcopy(policy_fn), temp_storage)
     else:
-        caching_mode = nullcontext()
+        caching_mode = NullTorchDispatchMode()
     cached_mode = CachedTorchDispatchMode(deepcopy(policy_fn), temp_storage)
 
     return caching_mode, cached_mode
@@ -467,6 +473,11 @@ def __init__(self, mod, memory_budget=None, policy_fn=None):
         self.memory_budget = memory_budget
         self.policy_fn = policy_fn
 
+        # TODO: this should be enabled by default in PyTorch
+        torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = (
+            True
+        )
+
     @torch.compiler.disable
     def _get_policy_fn(self, *args, **kwargs):
         if not torch.is_grad_enabled():
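
For context, a minimal, hedged sketch of the behaviour the new `no_grad` parametrization exercises: a checkpointed region driven by `selective_checkpoint_context_fn`, compiled, and run both with gradients enabled and under `torch.no_grad()`. The toy block, shapes, and dtype below are illustrative assumptions; only `selective_checkpoint_context_fn`, the `context_fn=` keyword of `torch.utils.checkpoint.checkpoint`, and the dynamo config flag come from the diff and PyTorch's public API. A CUDA device is assumed.

```python
# Hedged usage sketch -- not part of the patch.
import torch
from torch.utils.checkpoint import checkpoint

from xformers.checkpoint import selective_checkpoint_context_fn

# Mirrors the flag the test patches via its decorator and the wrapper now sets
# in __init__.
torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = True


def block(x):
    # Toy stand-in for a transformer block; anything differentiable works.
    return torch.nn.functional.gelu(x @ x.transpose(-1, -2)) @ x


def fwd(x):
    # Non-reentrant checkpoint is required for context_fn support.
    return checkpoint(
        block, x, use_reentrant=False, context_fn=selective_checkpoint_context_fn
    )


compiled = torch.compile(fwd)
x = torch.randn(4, 64, 64, device="cuda", dtype=torch.float16, requires_grad=True)

# Training path: recomputation follows the (default) policy during backward.
compiled(x).sum().backward()

# Forward-only path, the case the new `no_grad` test parameter covers: with the
# patch, the caching context returned while grads are disabled is a
# NullTorchDispatchMode rather than a plain nullcontext.
with torch.no_grad():
    y = compiled(x)
```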