Skip to content

Commit 596e50b

Browse files
authored
Revert flashmask bidirectional attention optimization (#67811)
* Revert flashmask bidirectional attention optimization * update * fix unit test
1 parent bf822fa commit 596e50b

File tree

3 files changed

+7
-5
lines changed

3 files changed

+7
-5
lines changed

python/paddle/nn/functional/flash_attention.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1475,6 +1475,9 @@ def flashmask_attention(
14751475
has_end = False
14761476
elif startend_row_indices.shape[-1] == 4:
14771477
has_end = True
1478+
raise NotImplementedError(
1479+
"ending row index is not implemented yet."
1480+
)
14781481
else:
14791482
raise ValueError(
14801483
f"Invalid shape of startend_row_indices, when causal is False, the last dimension should be either 2 or 4 but got {startend_row_indices.shape[-1]}"

test/legacy_test/test_flash_attention.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -477,7 +477,7 @@ def setUp(self):
477477
self.enable_mem_efficient = False
478478

479479

480-
class TestFlashAttenionWithMaskAPITest(TestFlashAttentionWithMaskAPI):
480+
class TestFlashAttentionWithMaskAPITest(TestFlashAttentionWithMaskAPI):
481481
def setUp(self):
482482
self.place = paddle.CUDAPlace(0)
483483
self.shape = (8, 1024, 16, 128)
@@ -869,7 +869,6 @@ def generate_mask_matrix_from_mask_indices(start_rows):
869869
for j in range(seq_len):
870870
start_row = start_rows[bz_idx, head_idx, j]
871871
matrix[bz_idx, head_idx, start_row:, j] = -np.inf
872-
matrix[bz_idx, head_idx, j, j] = 0.0
873872
return matrix
874873

875874

@@ -945,7 +944,7 @@ def test_dot_scale_product(self):
945944
np.testing.assert_allclose(out.numpy(), out_, rtol=5e-03, atol=1e-03)
946945

947946

948-
class TestFlashAttenionWithSparseMaskAPITest(
947+
class TestFlashAttentionWithSparseMaskAPITest(
949948
TestFlashAttentionWithSparseMaskAPI
950949
):
951950
def setUp(self):
@@ -956,7 +955,7 @@ def setUp(self):
956955
self.causal = True
957956

958957

959-
class TestFlashAttenionWithSparseMaskBF16APITest(
958+
class TestFlashAttentionWithSparseMaskBF16APITest(
960959
TestFlashAttentionWithSparseMaskAPI
961960
):
962961
def setUp(self):

0 commit comments

Comments
 (0)