Make checkpoint work with compile in no_grad mode (fairinternal/xformers#1049)

* Make checkpoint work with compile in no_grad mode

* Enable context_fn in compile

* Add test

__original_commit__ = fairinternal/xformers@8a3f7b8
fmassa authored and xFormers Bot committed Mar 11, 2024
1 parent 8c7d37f commit f4e487c
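
Note (not part of the commit): a minimal sketch of the scenario this change targets, namely a checkpointed forward that uses the selective-checkpoint context_fn, compiled with torch.compile and run under torch.no_grad(). The dynamo flag and selective_checkpoint_context_fn are taken from the diff below; the toy module, shapes, and the fwd helper are made up for illustration, and the exact call pattern may differ from how xformers wires this up internally.

import torch
from torch.utils.checkpoint import checkpoint
from xformers.checkpoint import selective_checkpoint_context_fn

# Experimental flag referenced by this commit: lets dynamo trace a checkpoint
# call that passes a custom context_fn.
torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = True

# Toy module, purely for illustration.
model = torch.nn.Sequential(
    torch.nn.Linear(16, 16),
    torch.nn.GELU(),
    torch.nn.Linear(16, 16),
)

def fwd(x):
    # checkpoint calls context_fn and expects two context managers back;
    # selective_checkpoint_context_fn (from this repo) plays that role.
    return checkpoint(
        model,
        x,
        use_reentrant=False,
        context_fn=selective_checkpoint_context_fn,
    )

compiled = torch.compile(fwd)

x = torch.randn(2, 16)
with torch.no_grad():  # the inference path this commit makes work under compile
    y = compiled(x)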
Showing 2 changed files with 32 additions and 9 deletions.
26 changes: 19 additions & 7 deletions tests/test_checkpoint.py
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.


+from contextlib import nullcontext
 from copy import deepcopy

 import pytest
@@ -329,14 +330,17 @@ def test_optimal_checkpoint_policy(

 @pytest.mark.skipif(torch.__version__ < "2.3", reason="Only new PyTorch supported")
 @cuda_only
+@pytest.mark.parametrize("no_grad", [False, True])
 @pytest.mark.parametrize("device", ["cuda"])
 @pytest.mark.parametrize("memory_budget", [0, 0.1, 0.3, 1.0])
 @pytest.mark.parametrize("inplace", [False])
 @pytest.mark.parametrize("random", [False])
 @torch._dynamo.config.patch(  # type: ignore
     "_experimental_support_context_fn_in_torch_utils_checkpoint", True
 )
-def test_selective_checkpoint_wrapper_compile(device, memory_budget, inplace, random):
+def test_selective_checkpoint_wrapper_compile(
+    device, no_grad, memory_budget, inplace, random
+):
     torch.manual_seed(42)
     dtype = torch.float16
     modules = _get_model_blocks(
@@ -352,18 +356,26 @@ def test_selective_checkpoint_wrapper_compile(device, memory_budget, inplace, ra

     grad = torch.rand_like(inputs)

-    torch.manual_seed(42)
-    out = model(inputs.clone())
-    out.backward(grad)
+    context = torch.no_grad() if no_grad else nullcontext()

-    torch.manual_seed(42)
-    out_ref = model_ref(inputs.clone())
-    out_ref.backward(grad)
+    with context:
+        torch.manual_seed(42)
+        out = model(inputs.clone())
+        if not no_grad:
+            out.backward(grad)
+
+        torch.manual_seed(42)
+        out_ref = model_ref(inputs.clone())
+        if not no_grad:
+            out_ref.backward(grad)

     atol = 3e-4
     rtol = 1e-3
     torch.testing.assert_close(out, out_ref, atol=atol, rtol=rtol)

+    if no_grad:
+        return
+
     for p, p_ref in zip(model.parameters(), model_ref.parameters()):
         atol = 4e-4
         rtol = 2e-3
15 changes: 13 additions & 2 deletions xformers/checkpoint.py
@@ -7,7 +7,6 @@
 import functools
 import time
 from collections import defaultdict
-from contextlib import nullcontext
 from copy import deepcopy
 from dataclasses import astuple, dataclass
 from typing import Any, Callable, ContextManager, Dict, List, Optional, Tuple
@@ -124,6 +123,13 @@ def pop_from_storage(self, func, args, kwargs):
         return func(*args, **kwargs)


+class NullTorchDispatchMode(TorchDispatchMode):
+    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+        return func(*args, **kwargs)
+
+
 def selective_checkpoint_context_fn(policy_fn=None):
     """An activation checkpoint context_fn for selectively deciding what to
     store and what to recompute. Accepts a custom policy.
@@ -149,7 +155,7 @@ def selective_checkpoint_context_fn(policy_fn=None):
     if torch.is_grad_enabled():
         caching_mode = _CachingTorchDispatchMode(deepcopy(policy_fn), temp_storage)
     else:
-        caching_mode = nullcontext()
+        caching_mode = NullTorchDispatchMode()
     cached_mode = CachedTorchDispatchMode(deepcopy(policy_fn), temp_storage)

     return caching_mode, cached_mode
@@ -467,6 +473,11 @@ def __init__(self, mod, memory_budget=None, policy_fn=None):
         self.memory_budget = memory_budget
         self.policy_fn = policy_fn

+        # TODO: this should be enabled by default in PyTorch
+        torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = (
+            True
+        )
+
     @torch.compiler.disable
     def _get_policy_fn(self, *args, **kwargs):
         if not torch.is_grad_enabled():
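Note (not from the commit, my reading of the diff): torch.utils.checkpoint calls context_fn and expects back two context managers, one entered around the original forward and one around recomputation. With grad disabled there is nothing to cache, and the old code filled the first slot with contextlib.nullcontext; this change substitutes a no-op TorchDispatchMode instead, presumably so the compiled checkpoint path always sees dispatch modes. A small pass-through check, assuming an xformers build that includes this commit:

import torch
from xformers.checkpoint import selective_checkpoint_context_fn

with torch.no_grad():
    # Under no_grad the first slot is now NullTorchDispatchMode rather than
    # contextlib.nullcontext.
    fwd_mode, recompute_mode = selective_checkpoint_context_fn()
    x = torch.randn(4, 4)
    with fwd_mode:  # every op is simply forwarded, so results are unchanged
        y = x @ x
    assert torch.allclose(y, x @ x)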
