[Inductor] Make FlexAttention block_mask argument as tuple (pytorch#129831)

Re-organize the ```block_mask```-related arguments into a tuple to reduce the number of individual arguments. I tried using a named tuple, but AOT autograd doesn't work well with named tuples. The only downside of a plain tuple compared to a named tuple is that its elements have to be accessed by index, but since we only need to do that in one place, it should be fine.

Pull Request resolved: pytorch#129831
Approved by: https://github.com/Chillee, https://github.com/drisspg
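
The gist of the change in a minimal, self-contained sketch (illustrative only: `consume_block_mask` is a made-up stand-in for the single place that unpacks the tuple, and the shapes and block sizes simply mirror the traced graph in the diff below; this is not the real FlexAttention API):

```python
import torch

# The four block-sparsity tensors and two block sizes that were previously
# passed to the higher-order op as six separate arguments.
kv_num_blocks = torch.ones(1, 1, 1, dtype=torch.int32)
kv_indices = torch.zeros(1, 1, 1, 1, dtype=torch.int32)
q_num_blocks = torch.ones(1, 1, 1, dtype=torch.int32)
q_indices = torch.zeros(1, 1, 1, 1, dtype=torch.int32)
KV_BLOCK_SIZE = Q_BLOCK_SIZE = 8

# After this change they travel together as one plain tuple ...
block_mask = (kv_num_blocks, kv_indices, q_num_blocks, q_indices,
              KV_BLOCK_SIZE, Q_BLOCK_SIZE)

def consume_block_mask(block_mask):
    # ... and the single consumer unpacks it by index, which is the one
    # downside of a plain tuple versus a named tuple.
    kv_num_blocks, kv_indices = block_mask[0], block_mask[1]
    kv_block_size, q_block_size = block_mask[4], block_mask[5]
    return kv_num_blocks.shape, kv_indices.shape, kv_block_size, q_block_size

print(consume_block_mask(block_mask))
```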
yanboliang authored and pytorchmergebot committed Jul 2, 2024
1 parent 9105d54 commit 34e94c5
Showing 5 changed files with 127 additions and 324 deletions.
76 changes: 28 additions & 48 deletions test/inductor/test_flex_attention.py
@@ -18,8 +18,8 @@
from torch.nn.attention._flex_attention import (
_causal,
_compose,
_create_block_sparse_mask,
_create_empty_block_sparse_mask,
_create_block_mask,
_create_empty_block_mask,
_flex_attention,
_generate_alibi_bias,
_identity,
@@ -46,17 +46,17 @@
index = torch.ops.aten.index


def create_attention(score_mod, block_sparse_mask):
def create_attention(score_mod, block_mask):
return functools.partial(
_flex_attention, score_mod=score_mod, block_sparse_mask=block_sparse_mask
_flex_attention, score_mod=score_mod, block_mask=block_mask
)


def create_block_sparse_mask_from_score_mod(score_mod, query, key, value):
def create_block_mask_from_score_mod(score_mod, query, key, value):
Q_LEN = query.shape[-2]
KV_LEN = key.shape[-2]
if score_mod == _causal:
return _create_block_sparse_mask(
return _create_block_mask(
torch.tril(
torch.ones(Q_LEN, KV_LEN, dtype=torch.bool, device=query.device)
),
@@ -233,8 +233,8 @@ def run_test(
)
q_ref, k_ref, v_ref = query_key_value_clones(q, k, v)
q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)
block_sparse_mask = create_block_sparse_mask_from_score_mod(score_mod, q, k, v)
sdpa_partial = create_attention(score_mod, block_sparse_mask)
block_mask = create_block_mask_from_score_mod(score_mod, q, k, v)
sdpa_partial = create_attention(score_mod, block_mask)
compiled_sdpa = torch.compile(sdpa_partial)
golden_out = sdpa_partial(q_gold, k_gold, v_gold)
ref_out = sdpa_partial(q_ref, k_ref, v_ref)
@@ -503,9 +503,9 @@ def test_strided_inputs(self, dtype: torch.dtype, q_s, k_s, v_s):
assert v_strides[-1] == 1
v = torch.as_strided(v1, v_shape, v_strides, v_offset)

block_mask = _create_empty_block_sparse_mask(q, k, v)
block_mask = _create_empty_block_mask(q, k, v)
sdpa_partial = create_attention(
score_mod=_generate_alibi_bias(8), block_sparse_mask=block_mask
score_mod=_generate_alibi_bias(8), block_mask=block_mask
)
compiled_sdpa = torch.compile(sdpa_partial)
ref_out = sdpa_partial(q, k, v)
@@ -517,7 +517,7 @@ def test_strided_inputs(self, dtype: torch.dtype, q_s, k_s, v_s):
)

@supported_platform
def test_create_block_sparse_mask_is_compiled(self):
def test_create_block_mask_is_compiled(self):
make_tensor = functools.partial(
torch.randn,
(B, H, S, D),
@@ -529,7 +529,7 @@ def test_create_block_sparse_mask_is_compiled(self):

@torch.compile
def func(q, k, v):
block_sparse_mask = _create_block_sparse_mask(
block_mask = _create_block_mask(
torch.tril(
torch.ones(
q.shape[-2], k.shape[-2], dtype=torch.bool, device=q.device
@@ -544,17 +544,17 @@ def func(q, k, v):
k,
v,
_causal,
block_sparse_mask,
block_mask,
)
return out

_, code = run_and_get_code(func, q, k, v)
# Ensure _create_block_sparse_mask is compiled and generates 3 kernels,
# Ensure _create_block_mask is compiled and generates 3 kernels,
# flex_attention generates 1 kernel.
FileCheck().check_count(".run(", 4, True).run(code[0])

@supported_platform
def test_block_sparse_mask_is_reused(self):
def test_block_mask_is_reused(self):
make_tensor = functools.partial(
torch.randn,
(B, H, S, D),
@@ -568,7 +568,7 @@ def test_block_sparse_mask_is_reused(self):

@torch.compile
def func(q, k, v, k2, v2):
block_sparse_mask = _create_block_sparse_mask(
block_mask = _create_block_mask(
torch.tril(
torch.ones(
q.shape[-2], k.shape[-2], dtype=torch.bool, device=q.device
@@ -583,19 +583,19 @@ def func(q, k, v, k2, v2):
k,
v,
_causal,
block_sparse_mask,
block_mask,
)
out = _flex_attention(
q,
k2,
v2,
_causal,
block_sparse_mask,
block_mask,
)
return out

_, code = run_and_get_code(func, q, k, v, k2, v2)
# Ensure _create_block_sparse_mask is compiled and generates 3 kernels,
# Ensure _create_block_mask is compiled and generates 3 kernels,
# 2 flex_attention generates 2 kernels.
FileCheck().check_count(".run(", 5, True).run(code[0])

@@ -1013,7 +1013,7 @@ def test_logsumexp_correctness(self, dtype, score_mod):
requires_grad=True,
)
q, k, v = make_tensor(), make_tensor(), make_tensor()
block_mask = _create_empty_block_sparse_mask(q, k, v)
block_mask = _create_empty_block_mask(q, k, v)

@torch.compile
def sdpa_hop(q, k, v, score_mod, block_mask):
@@ -1022,12 +1022,7 @@ def sdpa_hop(q, k, v, score_mod, block_mask):
k,
v,
score_mod,
block_mask.kv_num_blocks,
block_mask.kv_indices,
block_mask.q_num_blocks,
block_mask.q_indices,
block_mask.KV_BLOCK_SIZE,
block_mask.Q_BLOCK_SIZE,
block_mask,
)

@torch.compile(backend="aot_eager")
@@ -1041,12 +1036,7 @@ def eager_sdpa_hop(q, k, v, score_mod, block_mask):
k,
v,
score_mod,
block_mask.kv_num_blocks,
block_mask.kv_indices,
block_mask.q_num_blocks,
block_mask.q_indices,
block_mask.KV_BLOCK_SIZE,
block_mask.Q_BLOCK_SIZE,
block_mask,
)

ref_out, ref_lse = eager_sdpa_hop(
@@ -1094,7 +1084,7 @@ def test_logsumexp_only_return(self):
requires_grad=True,
)
q, k, v = make_tensor(), make_tensor(), make_tensor()
block_mask = _create_empty_block_sparse_mask(q, k, v)
block_mask = _create_empty_block_mask(q, k, v)

@torch.compile
def func(q, k, v, score_mod, block_mask):
@@ -1103,12 +1093,7 @@ def func(q, k, v, score_mod, block_mask):
k,
v,
score_mod,
block_mask.kv_num_blocks,
block_mask.kv_indices,
block_mask.q_num_blocks,
block_mask.q_indices,
block_mask.KV_BLOCK_SIZE,
block_mask.Q_BLOCK_SIZE,
block_mask,
)
lse_2 = lse * 2
return lse_2
@@ -1127,7 +1112,7 @@ def test_logsumexp_is_not_fused(self):
requires_grad=True,
)
q, k, v = make_tensor(), make_tensor(), make_tensor()
block_mask = _create_empty_block_sparse_mask(q, k, v)
block_mask = _create_empty_block_mask(q, k, v)

@torch.compile
def func(q, k, v, score_mod, block_mask):
@@ -1136,12 +1121,7 @@ def func(q, k, v, score_mod, block_mask):
k,
v,
score_mod,
block_mask.kv_num_blocks,
block_mask.kv_indices,
block_mask.q_num_blocks,
block_mask.q_indices,
block_mask.KV_BLOCK_SIZE,
block_mask.Q_BLOCK_SIZE,
block_mask,
)
lse_2 = lse * 2
return out, lse_2
@@ -1239,7 +1219,7 @@ def forward(self, L_args_0_: "f64[2, 2, 8, 4]", L_args_1_: "f64[2, 2, 8, 4]", L_
new_empty_3: "i32[]" = l_args_0_.new_empty([], dtype = torch.int32)
new_empty_4: "i32[]" = l_args_0_.new_empty([], dtype = torch.int32)
flex_attention_0 = self.flex_attention_0
flex_attention = torch.ops.higher_order.flex_attention(l_args_0_, l_args_1_, l_args_2_, flex_attention_0, ones, zeros, ones_1, zeros_1, 8, 8); l_args_0_ = l_args_1_ = l_args_2_ = flex_attention_0 = ones = zeros = ones_1 = zeros_1 = None
flex_attention = torch.ops.higher_order.flex_attention(l_args_0_, l_args_1_, l_args_2_, flex_attention_0, (ones, zeros, ones_1, zeros_1, 8, 8)); l_args_0_ = l_args_1_ = l_args_2_ = flex_attention_0 = ones = zeros = ones_1 = zeros_1 = None
out: "f64[2, 2, 8, 4]" = flex_attention[0]; flex_attention = None
return (out,)
@@ -1273,7 +1253,7 @@ class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f64[2, 2, 8, 4]", primals_2: "f64[2, 2, 8, 4]", primals_3: "f64[2, 2, 8, 4]", full_default: "i32[1, 1, 1]", full_default_1: "i32[1, 1, 1, 1]", getitem: "f64[2, 2, 8, 4]", getitem_1: "f32[2, 2, 8]", tangents_1: "f64[2, 2, 8, 4]"):
fw_graph = self.fw_graph
joint_graph = self.joint_graph
flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem, getitem_1, tangents_1, fw_graph, joint_graph, full_default, full_default_1, full_default, full_default_1, 8, 8); primals_1 = primals_2 = primals_3 = getitem = getitem_1 = tangents_1 = fw_graph = joint_graph = full_default = full_default_1 = None
flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem, getitem_1, tangents_1, fw_graph, joint_graph, (full_default, full_default_1, full_default, full_default_1, 8, 8)); primals_1 = primals_2 = primals_3 = getitem = getitem_1 = tangents_1 = fw_graph = joint_graph = full_default = full_default_1 = None
getitem_2: "f64[2, 2, 8, 4]" = flex_attention_backward[0]
getitem_3: "f64[2, 2, 8, 4]" = flex_attention_backward[1]
getitem_4: "f64[2, 2, 8, 4]" = flex_attention_backward[2]; flex_attention_backward = None
25 changes: 7 additions & 18 deletions torch/_dynamo/variables/higher_order_ops.py
@@ -1482,13 +1482,13 @@ def call_function(
class TemplatedAttentionHigherOrderVariable(TorchHigherOrderOperatorVariable):
@staticmethod
def normalize_to_args(args, kwargs):
# input signature is (query, key, value, score_mod, *other_buffers)
# Flatten args and kwargs into lists
flat_args = pytree.tree_flatten(args)[0]
# input signature is (query, key, value, score_mod, block_mask, *other_buffers),
# block_mask is a tuple, and we don't want to flatten it.
# only flatten kwargs into lists
flat_kwargs = pytree.tree_flatten(kwargs)[0]

# Combine the flattened lists
all_args = flat_args + flat_kwargs
all_args = args + flat_kwargs
return all_args

def create_wrapped_node(
@@ -1563,25 +1563,15 @@ def call_function(
key,
value,
score_mod,
sparse_kv_num_blocks,
sparse_kv_indices,
sparse_q_num_blocks,
sparse_q_indices,
SPARSE_KV_BLOCK_SIZE,
SPARSE_Q_BLOCK_SIZE,
block_mask,
) = self.normalize_to_args(args, kwargs)

p_args = self.create_wrapped_node(tx, query, score_mod)
proxied_args = [
query,
key,
value,
sparse_kv_num_blocks,
sparse_kv_indices,
sparse_q_num_blocks,
sparse_q_indices,
SPARSE_KV_BLOCK_SIZE,
SPARSE_Q_BLOCK_SIZE,
block_mask,
]

# Store the invocation as a call
@@ -1599,8 +1589,7 @@ def call_function(
example_value = (out_meta, lse_meta)

# Compose the ordered HOO args from two parts:
# - inp_args: [query, key, value, sparse_kv_num_blocks, sparse_kv_indices,
# sparse_q_num_blocks, sparse_q_indices, SPARSE_KV_BLOCK_SIZE, SPARSE_Q_BLOCK_SIZE]
# - inp_args: [query, key, value, block_mask]
# - p_args: [score_mod, *other_buffers]
return wrap_fx_proxy(
tx=tx,
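A minimal standalone sketch of the `normalize_to_args` behavior from the diff above (assumptions: plain Python strings stand in for Dynamo's VariableTrackers, and `list(args)` stands in for the positional-argument list; the point is only that kwargs get flattened while the `block_mask` tuple passes through intact):

```python
import torch.utils._pytree as pytree

def normalize_to_args(args, kwargs):
    # Input signature is (query, key, value, score_mod, block_mask, *other_buffers).
    # block_mask is a tuple and must not be flattened, so only kwargs are flattened.
    flat_kwargs = pytree.tree_flatten(kwargs)[0]
    return list(args) + flat_kwargs

block_mask = ("kv_num_blocks", "kv_indices", "q_num_blocks", "q_indices", 8, 8)
args = ["q", "k", "v", "score_mod", block_mask]
print(normalize_to_args(args, {"scale": 0.125}))
# The block_mask tuple stays a single element of the result instead of being
# exploded into six separate arguments.
```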
(Diffs for the remaining 3 changed files are not shown.)
