
Emulate flash attention unittests on CPU. #1021

Merged
merged 1 commit into from Feb 26, 2025
12 changes: 8 additions & 4 deletions axlearn/common/flash_attention/utils.py
@@ -107,7 +107,7 @@ def _repeat_kv_heads(num_q_heads: int, key_or_value: Tensor) -> Tensor:
 def flash_attention_implementation(
     backend: Literal["cpu", "tpu", "gpu", "xla", "neuron"],
     *,
-    softmax_scale: float,
+    softmax_scale: float = 1.0,
     block_size: int = 128,
     dropout_rate: Optional[float] = 0.0,
 ) -> MultiHeadAttentionImpl:
@@ -202,7 +202,7 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
                 mask_fn=mask_fn,
                 kv_seq_len=kv_seq_len,
                 softmax_scale=softmax_scale,
-                interpret=(backend == "cpu"),
+                interpret=_interpret(backend),
             )

         key = _repeat_kv_heads(query.shape[2], key)
@@ -237,7 +237,7 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
                 softmax_scale=softmax_scale,
                 mask_fn=mask.mask if mask.has_value() else None,
                 dropout_rate=dropout_rate,
-                interpret=(backend == "cpu"),
+                interpret=_interpret(backend),
             )
         else:
             causal, explicit_bias = split(
@@ -276,7 +276,7 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
                 mask=mask,
                 softmax_scale=softmax_scale,
                 block_size=block_size,
-                interpret=(backend == "cpu"),
+                interpret=_interpret(backend),
             )

     elif backend == "neuron":
@@ -332,3 +332,7 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
         raise NotImplementedError(f"Backend ({backend}) does not have an implementation.")

     return jit_attn
+
+
+def _interpret(backend: str):
+    return backend == "cpu"
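
For context, a minimal sketch (not part of the diff) of what the `_interpret` indirection buys: the new test below patches it so that the GPU/TPU Pallas kernels run in interpret mode on a CPU-only host. The patch target and the call are taken from the diff; the rest is illustrative.

```python
# Sketch only: mirrors how utils_test.py (below) uses the new hook.
from unittest.mock import patch

from axlearn.common.flash_attention import utils

with patch("axlearn.common.flash_attention.utils._interpret", return_value=True):
    # "gpu" selects the GPU Pallas kernel; with _interpret patched to True the
    # kernel runs through the Pallas interpreter, so no real GPU is required.
    attn_fn = utils.flash_attention_implementation("gpu")
```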
221 changes: 221 additions & 0 deletions axlearn/common/flash_attention/utils_test.py
@@ -0,0 +1,221 @@
# Copyright © 2025 Apple Inc.
"""Tests utils.py.

XLA_FLAGS="--xla_force_host_platform_device_count=8" \
pytest -m "for_8_devices" axlearn/common/flash_attention/utils_test.py -n auto

This test is expected to run on CPU and is designed to validate GPU/TPU code from a CPU environment.
It allows quick verification in CI and local environments to ensure that code changes do not break
GPU/TPU Flash Attention.
"""

from unittest.mock import patch

import jax
import jax.numpy as jnp
import pytest
from absl.testing import absltest, parameterized
from jax.experimental import mesh_utils
from jax.sharding import Mesh

from axlearn.common.attention_bias import (
    CausalAttentionBias,
    MaskFnAttentionBias,
    ZeroAttentionBias,
    sliding_window_causal_mask,
)
from axlearn.common.flash_attention import utils
from axlearn.common.test_utils import TestCase, is_supported_mesh_shape


def setUpModule():
    # Uncomment for local debugging.
    # import chex
    # chex.set_n_cpu_devices(8)
    if jax.default_backend() in ("gpu", "tpu"):
        pytest.skip(reason="This is a CPU only test.", allow_module_level=True)


def _get_inputs(
    *,
    batch: int,
    seq_len: int,
    num_heads: int,
    num_kv_heads: int,
    per_head_dim: int,
    input_dtype: jnp.dtype = jnp.bfloat16,
):
    query = jax.random.normal(
        jax.random.PRNGKey(0),
        [batch, seq_len, num_heads, per_head_dim],
        dtype=input_dtype,
    )
    key = jax.random.normal(
        jax.random.PRNGKey(1),
        [batch, seq_len, num_kv_heads, per_head_dim],
        dtype=input_dtype,
    )
    value = jax.random.normal(
        jax.random.PRNGKey(2),
        [batch, seq_len, num_kv_heads, per_head_dim],
        dtype=input_dtype,
    )
    return query, key, value


class TestFlashAttention(TestCase):
    """Tests FlashAttention layer."""

    _TEST_CONFIGS = [
        dict(
            batch=8,
            seq_len=256,
            num_heads=4,
            num_kv_heads=None,
            per_head_dim=128,
            mesh=(1, 1, 8, 1),
            mesh_axis_names=("data", "expert", "fsdp", "model"),
        ),
        dict(
            batch=8,
            seq_len=256,
            num_heads=4,
            num_kv_heads=1,
            per_head_dim=128,
            mesh=(2, 1, 2, 2),
            mesh_axis_names=("data", "expert", "fsdp", "model"),
        ),
    ]

    @parameterized.product(
        _TEST_CONFIGS,
        backend=["cpu", "gpu", "tpu"],
        bias_type=["full", "causal", "sliding"],
        input_dtype=[jnp.float32],
    )
    @pytest.mark.for_8_devices
    def test_forward(
        self,
        batch,
        seq_len,
        num_heads,
        num_kv_heads,
        per_head_dim,
        mesh,
        mesh_axis_names,
        backend,
        bias_type,
        input_dtype,
    ):
        if not is_supported_mesh_shape(mesh):
            pytest.skip(reason=f"Unsupported mesh {mesh}.")

        if bias_type == "full":
            bias = ZeroAttentionBias()
        elif bias_type == "causal":
            bias = CausalAttentionBias(shape=(seq_len, seq_len))
        else:
            assert bias_type == "sliding"
            bias = MaskFnAttentionBias(
                mask=sliding_window_causal_mask(sliding_window_size=4), shape=(seq_len, seq_len)
            )

        with patch("axlearn.common.flash_attention.utils._interpret", return_value=True):
            with Mesh(mesh_utils.create_device_mesh(mesh), mesh_axis_names):
                xla_fn = utils.flash_attention_implementation("xla")
                test_fn = utils.flash_attention_implementation(backend)

                query, key, value = _get_inputs(
                    batch=batch,
                    seq_len=seq_len,
                    num_heads=num_heads,
                    num_kv_heads=num_kv_heads or num_heads,
                    per_head_dim=per_head_dim,
                    input_dtype=input_dtype,
                )
                prng_key = jax.random.PRNGKey(0)

                ref_out = xla_fn(query, key, value, bias, prng_key)
                test_out = test_fn(query, key, value, bias, prng_key)
                self.assertNestedAllClose(ref_out, test_out, atol=0.01)
        jax.clear_caches()

    @parameterized.product(
        _TEST_CONFIGS,
        backend=["cpu", "gpu", "tpu"],
        bias_type=["causal", "sliding"],
        input_dtype=[jnp.float32],
        # TODO(hanzhi_zhou): support multi step gpu decoding.
        step_len=[1],
    )
    @pytest.mark.for_8_devices
    def test_decoding(
        self,
        batch,
        seq_len,
        num_heads,
        num_kv_heads,
        per_head_dim,
        mesh,
        mesh_axis_names,
        backend,
        bias_type,
        input_dtype,
        step_len,
    ):
        if not is_supported_mesh_shape(mesh):
            pytest.skip(reason=f"Unsupported mesh {mesh}.")

        if bias_type == "causal":
            bias = CausalAttentionBias(shape=(seq_len, seq_len))
        else:
            assert bias_type == "sliding"
            bias = MaskFnAttentionBias(
                mask=sliding_window_causal_mask(sliding_window_size=4), shape=(seq_len, seq_len)
            )

        with patch("axlearn.common.flash_attention.utils._interpret", return_value=True):
            with Mesh(mesh_utils.create_device_mesh(mesh), mesh_axis_names):
                test_fn = utils.flash_attention_implementation(backend)

                query, key, value = _get_inputs(
                    batch=batch,
                    seq_len=seq_len,
                    num_heads=num_heads,
                    num_kv_heads=num_kv_heads or num_heads,
                    per_head_dim=per_head_dim,
                    input_dtype=input_dtype,
                )
                prng_key = jax.random.PRNGKey(0)

                fwd_out = test_fn(query, key, value, bias, prng_key)
                # Limit generation length to 16 to save test time.
                query_len = 16
                query = query[:, :query_len]
                fwd_out = fwd_out[:, :query_len]

                decoding_output = []
                for t in range(0, query_len, step_len):
                    if bias_type == "causal":
                        bias_step = CausalAttentionBias(
                            shape=(step_len, seq_len),
                            target_positions=jnp.full([batch], fill_value=t),
                        )
                    else:
                        assert bias_type == "sliding"
                        bias_step = MaskFnAttentionBias(
                            mask=sliding_window_causal_mask(sliding_window_size=4),
                            shape=(step_len, seq_len),
                            target_positions=jnp.full([batch], fill_value=t),
                        )

                    query_step = query[:, t : t + step_len]
                    decoding_out = test_fn(query_step, key, value, bias_step, prng_key)
                    decoding_output.append(decoding_out)
                decoding_output = jnp.concatenate(decoding_output, axis=1)
                self.assertNestedAllClose(fwd_out, decoding_output, atol=0.01)
        jax.clear_caches()


if __name__ == "__main__":
    absltest.main()
1 change: 1 addition & 0 deletions pyproject.toml
@@ -186,6 +186,7 @@ markers = [
     "tpu: tests needing access to a TPU device.",
     "high_cpu: tests that require a lot of CPU.",
     "fp64: tests that require 64-bit floating point precision.",
+    "for_8_devices: tests that run on host platform device count of 8.",
     "golden_config: golden config tests.",
     "golden_init: golden init tests.",
     "golden_regularizer: golden regularizer scale tests.",
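
As a usage note, here is a hedged sketch of how the newly registered marker is consumed (the test name below is hypothetical; only the marker name comes from this PR): tests that assume eight host devices opt in with the decorator, and the default run_tests.sh invocation deselects them via `-m "not (... or for_8_devices)"`.

```python
import jax
import pytest


@pytest.mark.for_8_devices
def test_requires_eight_host_devices():
    # Meaningful only when XLA_FLAGS forces 8 host-platform devices.
    assert len(jax.devices()) == 8
```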
6 changes: 5 additions & 1 deletion run_tests.sh
@@ -52,13 +52,17 @@ fi

 UNQUOTED_PYTEST_FILES=$(echo $1 | tr -d "'")
 pytest --durations=100 -v -n auto \
-  -m "not (gs_login or tpu or high_cpu or fp64)" ${UNQUOTED_PYTEST_FILES} \
+  -m "not (gs_login or tpu or high_cpu or fp64 or for_8_devices)" ${UNQUOTED_PYTEST_FILES} \
   --dist worksteal &
 TEST_PIDS[$!]=1

 JAX_ENABLE_X64=1 pytest --durations=100 -v -n auto -v -m "fp64" --dist worksteal &
 TEST_PIDS[$!]=1

+XLA_FLAGS="--xla_force_host_platform_device_count=8" pytest --durations=100 -v \
+  -n auto -v -m "for_8_devices" --dist worksteal &
+TEST_PIDS[$!]=1
+
 # Use Bash 5.1's new wait -p feature to quit immediately if any subprocess fails to make error
 # finding a bit easier.
 while [ ${#TEST_PIDS[@]} -ne 0 ]; do
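
For readers unfamiliar with the flag, a standalone sketch (assuming a CPU-only host; the mesh shape and axis names are copied from the test configs) of why this invocation sets `XLA_FLAGS`: forcing eight host-platform devices lets JAX build the 8-way CPU meshes that the `for_8_devices` tests construct.

```python
import os

# Must be set before JAX initializes its backends.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh

assert len(jax.devices("cpu")) == 8  # Eight virtual CPU devices.

# Same mesh shape and axis names as the first test config in utils_test.py.
mesh = Mesh(
    mesh_utils.create_device_mesh((1, 1, 8, 1), devices=jax.devices("cpu")),
    ("data", "expert", "fsdp", "model"),
)
```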