[dynamo] support torch.nn.attention.sdpa_kernel context manager (pyto…
williamwen42 authored and pytorchmergebot committed Sep 12, 2024
1 parent 3de9e47 commit 63d6cd3
Showing 6 changed files with 238 additions and 26 deletions.
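
For context, here is a minimal sketch of the usage pattern this commit enables, mirroring the new tests below; the function name and tensor shape are illustrative, and backend="eager" is used only to exercise tracing.

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

def f(x):
    # Restrict scaled_dot_product_attention to the math backend; with this
    # change, dynamo traces the context manager instead of graph-breaking on it.
    with sdpa_kernel([SDPBackend.MATH]):
        return torch.nn.functional.scaled_dot_product_attention(x, x, x)

opt_f = torch.compile(f, backend="eager", fullgraph=True)
out = opt_f(torch.randn(2, 2, 2, 2))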
111 changes: 111 additions & 0 deletions test/dynamo/test_ctx_manager.py
@@ -1533,6 +1533,117 @@ def fn(x):
self.assertEqual(fn(x).requires_grad, opt_fn(x).requires_grad)
self.assertEqual(cnts.frame_count, 2)

def test_sdpa_kernel_ctx_manager1(self):
modified_backend_state = [torch.nn.attention.SDPBackend.MATH]

@torch._dynamo.allow_in_graph
def check_backend_state_is_modified():
self.assertEqual(
torch.nn.attention._cur_sdpa_kernel_backends(), modified_backend_state
)

def f(x):
with torch.nn.attention.sdpa_kernel(
# pyre-fixme[16]: Module `torch.nn.attention` has no attribute `SDPBackend`.
[torch.nn.attention.SDPBackend.MATH]
):
output = torch.nn.functional.scaled_dot_product_attention(x, x, x).to(
torch.float32
)
check_backend_state_is_modified()

return output

opt_f = torch.compile(f, backend="eager", fullgraph=True)
opt_f(torch.randn(2, 2, 2, 2).to(dtype=torch.float16))

def test_sdpa_kernel_ctx_manager2(self):
original_backend_state = set(torch.nn.attention._cur_sdpa_kernel_backends())
modified_backend_state = [torch.nn.attention.SDPBackend.MATH]

@torch._dynamo.allow_in_graph
def check_backend_state_is_original():
self.assertEqual(
set(torch.nn.attention._cur_sdpa_kernel_backends()),
original_backend_state,
)

@torch._dynamo.allow_in_graph
def check_backend_state_is_modified():
self.assertEqual(
torch.nn.attention._cur_sdpa_kernel_backends(), modified_backend_state
)

def g(x):
torch._dynamo.graph_break()
output = torch.nn.functional.scaled_dot_product_attention(x, x, x).to(
torch.float32
)
check_backend_state_is_modified()
return output

def f(x):
check_backend_state_is_original()
with torch.nn.attention.sdpa_kernel(
# pyre-fixme[16]: Module `torch.nn.attention` has no attribute `SDPBackend`.
[torch.nn.attention.SDPBackend.MATH]
):
output1 = torch.nn.functional.scaled_dot_product_attention(x, x, x).to(
torch.float32
)
check_backend_state_is_modified()

# graph break
output2 = g(x)

output3 = torch.nn.functional.scaled_dot_product_attention(x, x, x).to(
torch.float32
)
check_backend_state_is_modified()

check_backend_state_is_original()

return output1 + output2 + output3

cnts = torch._dynamo.testing.CompileCounter()
opt_f = torch.compile(f, backend=cnts)
opt_f(torch.randn(2, 2, 2, 2).to(dtype=torch.float16))
self.assertEqual(cnts.frame_count, 3)

# test sdpa_kernel graph break with 2 arguments
def test_sdpa_kernel_ctx_manager3(self):
modified_backend_state = {
torch.nn.attention.SDPBackend.MATH,
torch.nn.attention.SDPBackend.FLASH_ATTENTION,
}

@torch._dynamo.allow_in_graph
def check_backend_state_is_modified():
self.assertEqual(
set(torch.nn.attention._cur_sdpa_kernel_backends()),
modified_backend_state,
)

def f(x):
with torch.nn.attention.sdpa_kernel(
# pyre-fixme[16]: Module `torch.nn.attention` has no attribute `SDPBackend`.
[
torch.nn.attention.SDPBackend.MATH,
torch.nn.attention.SDPBackend.FLASH_ATTENTION,
]
):
# FLASH_ATTENTION may not be supported, but we're not actually
# doing any sdpa
x = x + 1
torch._dynamo.graph_break()
check_backend_state_is_modified()
x = x + 1

return x

opt_f = torch.compile(f, backend="eager")
opt_f(torch.randn(2, 2))


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
1 change: 1 addition & 0 deletions torch/_dynamo/variables/__init__.py
@@ -14,6 +14,7 @@
GradModeVariable,
InferenceModeVariable,
JvpIncrementNestingCtxManagerVariable,
SDPAKernelVariable,
SetFwdGradEnabledContextManager,
StreamContextVariable,
StreamVariable,
3 changes: 3 additions & 0 deletions torch/_dynamo/variables/builder.py
@@ -843,6 +843,9 @@ def build_key_value(i, k, v):
elif isinstance(value, (torch._C._SDPAParams)):
self.install_guards(GuardBuilder.TYPE_MATCH)
return SDPAParamsVariable.create(self.tx, value, self.source)
elif isinstance(value, torch._C._SDPBackend):
self.install_guards(GuardBuilder.ID_MATCH)
return ConstantVariable(value)
elif isinstance(value, _EventBase):
self.install_guards(GuardBuilder.ID_MATCH)
torch._dynamo.utils.store_user_object_weakref(value)
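
The SDPBackend values that reach the builder this way are enum members exposed as class attributes, so identity is stable across lookups and an ID_MATCH guard suffices; a small sketch of that assumption:

from torch.nn.attention import SDPBackend

# Repeated attribute lookups return the same enum object, so an identity-based
# guard on the constant stays valid for the lifetime of the process.
assert SDPBackend.MATH is SDPBackend.MATH
backends = [SDPBackend.MATH, SDPBackend.FLASH_ATTENTION]
assert backends[0] is SDPBackend.MATH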
74 changes: 74 additions & 0 deletions torch/_dynamo/variables/ctx_manager.py
@@ -977,6 +977,80 @@ def fn_name(self):
return "use_training_state"


class SDPAKernelVariable(ContextWrappingVariable):
"""represents torch.nn.attention.sdpa_kernel"""

@staticmethod
def create(tx: "InstructionTranslator", backends, **kwargs):
if isinstance(backends, torch.nn.attention.SDPBackend):
backends = [backends]
var = SDPAKernelVariable(
target_values=backends,
initial_values=None,
**kwargs,
)
return var

def __init__(
self,
target_values: List[torch.nn.attention.SDPBackend],
initial_values=None,
**kwargs,
) -> None:
super().__init__(
target_values=target_values, initial_values=initial_values, **kwargs
)

@staticmethod
def _backends_to_nodes(tx, backends):
nodes = []
for backend in backends:
# convert to/from string in order to bake the backend into FX graph
nodes.append(
tx.output.create_node(
"call_function",
torch.nn.attention._backend_from_string,
(backend.name,),
{},
)
)
return nodes

def enter(self, tx):
self.prev_backends = torch.nn.attention._cur_sdpa_kernel_backends()
self.set_cleanup_hook(
tx, lambda: torch.nn.attention._sdpa_kernel(self.prev_backends)
)
torch.nn.attention._sdpa_kernel(self.target_values)
arg = self._backends_to_nodes(tx, self.target_values)
tx.output.create_node(
"call_function",
torch.nn.attention._sdpa_kernel,
(arg,),
{},
)
return variables.ConstantVariable.create(None)

def exit(self, tx: "InstructionTranslator", *args):
self.state.cleanup_assert()
arg = self._backends_to_nodes(tx, self.prev_backends)
tx.output.create_node(
"call_function",
torch.nn.attention._sdpa_kernel,
(arg,),
{},
)
return variables.ConstantVariable.create(None)

def module_name(self):
return "torch.nn.attention"

# use a private version of sdpa_kernel that accepts variadic arguments
# since dynamo reconstructs the contents of target_values one-by-one
def fn_name(self):
return "_sdpa_kernel_variadic"


class StreamVariable(VariableTracker):
def __init__(self, proxy, value, device, **kwargs) -> None:
if proxy is not None and "example_value" in proxy.node.meta:
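
At runtime, the enter/exit nodes emitted above amount to saving the current backend set, applying the target set, and restoring the saved set on exit. A rough sketch using the private helpers added in this commit (not the exact FX code dynamo generates):

import torch
from torch.nn.attention import (
    _backend_from_string,
    _cur_sdpa_kernel_backends,
    _sdpa_kernel,
)

prev = _cur_sdpa_kernel_backends()            # saved as self.prev_backends in enter()
_sdpa_kernel([_backend_from_string("MATH")])  # enter: enable only the requested backends
# ... attention ops in the traced region run under the restricted backend set ...
_sdpa_kernel(prev)                            # exit: restore the previous backend set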
11 changes: 11 additions & 0 deletions torch/_dynamo/variables/torch.py
@@ -89,6 +89,8 @@
torch.autograd.graph.disable_saved_tensors_hooks,
torch.cpu.amp.autocast_mode.autocast,
torch.cuda.amp.autocast_mode.autocast,
torch.nn.attention.sdpa_kernel,
torch.nn.attention._sdpa_kernel_variadic,
]
)

@@ -229,6 +231,7 @@ def call_function(
GradModeVariable,
InferenceModeVariable,
JvpIncrementNestingCtxManagerVariable,
SDPAKernelVariable,
SetFwdGradEnabledContextManager,
StreamVariable,
VmapIncrementNestingCtxManagerVariable,
@@ -329,6 +332,14 @@ def call_function(
return FSDPParamGroupUseTrainingStateVariable.create(
tx, args[0], args[1].as_python_constant()
)
elif self.value is torch.nn.attention.sdpa_kernel:
assert len(args) == 1 or (len(kwargs) == 1 and "backends" in kwargs)
backends = args[0] if len(args) == 1 else kwargs["backends"]
return SDPAKernelVariable.create(tx, backends.as_python_constant())
elif self.value is torch.nn.attention._sdpa_kernel_variadic:
return SDPAKernelVariable.create(
tx, [arg.as_python_constant() for arg in args]
)

return super().call_function(tx, args, kwargs)

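
When dynamo rebuilds a live sdpa_kernel context manager (for example across a graph break), it calls the variadic private form registered above instead of packing a list. A sketch of the intended equivalence, assuming the private helper keeps this signature:

from torch.nn.attention import SDPBackend, _sdpa_kernel_variadic

# Passing backends one-by-one has the same effect as sdpa_kernel([...]) with a list.
with _sdpa_kernel_variadic(SDPBackend.MATH, SDPBackend.FLASH_ATTENTION):
    pass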
64 changes: 38 additions & 26 deletions torch/nn/attention/__init__.py
@@ -1,21 +1,14 @@
# mypy: allow-untyped-defs
""" This module contains functions and classes that alter the behavior of torch.nn.functional.scaled_dot_product_attention """
import contextlib
from typing import List, Union
from typing import Iterable, List, Union
from warnings import warn

import torch.backends.cuda
from torch._C import _SDPBackend as SDPBackend
from torch.backends.cuda import (
can_use_efficient_attention,
can_use_flash_attention,
cudnn_sdp_enabled,
enable_cudnn_sdp,
enable_flash_sdp,
enable_math_sdp,
enable_mem_efficient_sdp,
flash_sdp_enabled,
math_sdp_enabled,
mem_efficient_sdp_enabled,
SDPAParams,
)

Expand Down Expand Up @@ -67,6 +60,32 @@ def _raise_kernel_warnings(params: SDPAParams) -> None:
can_use_flash_attention(params, True)


_backend_names = {
"cudnn": "CUDNN_ATTENTION",
"flash": "FLASH_ATTENTION",
"mem_efficient": "EFFICIENT_ATTENTION",
"math": "MATH",
}


def _backend_from_string(name: str):
return getattr(SDPBackend, name)


def _cur_sdpa_kernel_backends():
backends: List[SDPBackend] = []
for name, val in _backend_names.items():
if getattr(torch.backends.cuda, f"{name}_sdp_enabled")():
backends.append(getattr(SDPBackend, val))
return backends


def _sdpa_kernel(backends: Iterable[SDPBackend]):
for name, val in _backend_names.items():
enabled = getattr(SDPBackend, val) in backends
getattr(torch.backends.cuda, f"enable_{name}_sdp")(enabled)


@contextlib.contextmanager
def sdpa_kernel(backends: Union[List[SDPBackend], SDPBackend]):
r"""
@@ -102,26 +121,19 @@ def sdpa_kernel(backends: Union[List[SDPBackend], SDPBackend]):
backends = [backends]

backends = set(backends)
previous_cudnn: bool = cudnn_sdp_enabled()
previous_flash: bool = flash_sdp_enabled()
previous_mem_efficient: bool = mem_efficient_sdp_enabled()
previous_math: bool = math_sdp_enabled()
previous_backends = _cur_sdpa_kernel_backends()
try:
enable_cudnn = SDPBackend.CUDNN_ATTENTION in backends
enable_flash = SDPBackend.FLASH_ATTENTION in backends
enable_mem_efficient = SDPBackend.EFFICIENT_ATTENTION in backends
enable_math = SDPBackend.MATH in backends

enable_cudnn_sdp(enable_cudnn)
enable_flash_sdp(enable_flash)
enable_mem_efficient_sdp(enable_mem_efficient)
enable_math_sdp(enable_math)
_sdpa_kernel(backends)
yield {}
finally:
enable_cudnn_sdp(previous_cudnn)
enable_flash_sdp(previous_flash)
enable_mem_efficient_sdp(previous_mem_efficient)
enable_math_sdp(previous_math)
_sdpa_kernel(previous_backends)


# variadic version of sdpa_kernel for dynamo to use while reconstructing
@contextlib.contextmanager
def _sdpa_kernel_variadic(*backends: SDPBackend):
with sdpa_kernel(list(backends)):
yield


def _get_flash_version() -> str:
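
Both the public context manager and the dynamo paths now route through _sdpa_kernel, which flips the four per-backend flags on torch.backends.cuda. A small sketch of that mapping, assuming a default build where the math backend is available:

import torch
from torch.nn.attention import SDPBackend, _cur_sdpa_kernel_backends, sdpa_kernel

with sdpa_kernel([SDPBackend.MATH]):
    # Only the math flag is enabled inside the context ...
    assert torch.backends.cuda.math_sdp_enabled()
    assert not torch.backends.cuda.flash_sdp_enabled()
    # ... and the helper reports the same state as a list of SDPBackend values.
    assert _cur_sdpa_kernel_backends() == [SDPBackend.MATH]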
