-
Notifications
You must be signed in to change notification settings - Fork 392
Flash Attention for Neuron #883
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,124 @@ | ||
| from functools import partial | ||
| import jax | ||
| import jax.numpy as jnp | ||
| from jax import custom_vjp | ||
|
|
||
| lnc = 2 if jax.devices()[0].device_kind == "NC_v3d" else 1 | ||
|
|
||
| @partial(custom_vjp, nondiff_argnums=(4, 5)) | ||
| def flash_attention(query, key, value, bias, causal, softmax_scale): | ||
| out, _ = _mha_forward(query, key, value, bias, causal, softmax_scale) | ||
| return out | ||
|
|
||
|
|
||
| def _mha_forward(query, key, value, bias, causal, softmax_scale): | ||
| # Get the batch size, sequence lengths, number of heads, and hidden dimension | ||
| batch_size, q_seq_len, num_heads, d_model = query.shape | ||
|
|
||
| # Transpose the query, key, and value tensors | ||
| q = query.transpose(0, 2, 3, 1) # [batch_size, num_heads, d_model, q_seq_len] | ||
| k = key.transpose(0, 2, 3, 1) # [batch_size, num_heads, d_model, kv_seq_len] | ||
| v = value.transpose(0, 2, 1, 3) # [batch_size, num_heads, kv_seq_len, d_model] | ||
|
|
||
| import neuronxcc.nki.language as nl | ||
apoorvtintin marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| from neuronxcc.nki.kernels.attention import flash_fwd | ||
| seed = jnp.array([1]) | ||
|
|
||
| # Call the NKI kernel, duplicate the kernel if we cannot shard on num_heads | ||
| if (num_heads % 2) == 0 and (num_heads // 2 > 0): | ||
| grid = batch_size, nl.nc(lnc) * (num_heads // lnc) | ||
| else: | ||
| grid = batch_size, num_heads | ||
|
|
||
| if bias != None: | ||
| assert bias.ndim == 4, f"Neuron flash_attention is only expecting bias.ndim = 4 but got {bias.ndim}" | ||
| attn_output, lse = flash_fwd[grid]( | ||
| q, | ||
| k, | ||
| v, | ||
| seed, | ||
| bias, | ||
| use_causal_mask=causal, | ||
| softmax_scale=softmax_scale, | ||
| mixed_precision=True, | ||
| dropout_p=0.0, | ||
| ) | ||
| else: | ||
| attn_output, lse = flash_fwd[grid]( | ||
| q, | ||
| k, | ||
| v, | ||
| seed, | ||
| use_causal_mask=causal, | ||
| softmax_scale=softmax_scale, | ||
| mixed_precision=True, | ||
| dropout_p=0.0, | ||
| ) | ||
| # Transpose the output back to the original shape | ||
| attn_output = attn_output.transpose(0, 2, 1, 3) # [batch_size, q_seq_len, num_heads, d_model] | ||
|
|
||
| return attn_output, (lse, attn_output, q, k, v, bias) | ||
|
|
||
|
|
||
| def _mha_backward(causal, softmax_scale, res, d_attn_output): | ||
| lse, o, q, k, v, bias = res | ||
| batch_size, num_heads, d_model, seq_len = q.shape | ||
|
|
||
| # Transpose the input tensors | ||
| o = o.transpose(0, 2, 3, 1) | ||
| dy = d_attn_output.transpose(0, 2, 3, 1) | ||
|
|
||
| # Transpose v tensor | ||
| v = jnp.transpose(v, axes=(0, 1, 3, 2)) | ||
| seed = jnp.array([1]) | ||
|
|
||
| from neuronxcc.nki.kernels.attention import flash_attn_bwd | ||
| import neuronxcc.nki.language as nl | ||
|
|
||
| # Call the NKI kernel, duplicate the kernel if we cannot shard on num_heads | ||
| if (num_heads % 2) == 0 and (num_heads // 2 > 0): | ||
| grid = batch_size, nl.nc(lnc) * (num_heads // lnc) | ||
| else: | ||
| grid = batch_size, num_heads | ||
|
|
||
| if bias != None: | ||
| assert bias.ndim == 4, f"Neuron flash_attention is only expecting bias.ndim = 4 but got {bias.ndim}" | ||
| d_query, d_key, d_value = flash_attn_bwd[grid]( | ||
| q, | ||
| k, | ||
| v, | ||
| o, | ||
| dy, | ||
| lse, | ||
| seed, | ||
| bias, | ||
| use_causal_mask=causal, | ||
| mixed_precision=True, | ||
| dropout_p=0.0, | ||
| softmax_scale=softmax_scale, | ||
| ) | ||
| else: | ||
| d_query, d_key, d_value = flash_attn_bwd[grid]( | ||
| q, | ||
| k, | ||
| v, | ||
| o, | ||
| dy, | ||
| lse, | ||
| seed, | ||
| use_causal_mask=causal, | ||
| mixed_precision=True, | ||
| dropout_p=0.0, | ||
| softmax_scale=softmax_scale, | ||
| ) | ||
|
|
||
| # Batch seq_len heads, head_dim | ||
| # Transpose the gradients back to the original shape | ||
| d_query = d_query.transpose(0, 3, 1, 2) | ||
| d_key = d_key.transpose(0, 3, 1, 2) | ||
| d_value = d_value.transpose(0, 3, 1, 2) | ||
|
|
||
| return d_query, d_key, d_value, None | ||
|
|
||
|
|
||
| flash_attention.defvjp(_mha_forward, _mha_backward) | ||
132 changes: 132 additions & 0 deletions
132
axlearn/common/flash_attention/neuron_attention_test.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,132 @@ | ||
| # Copyright © 2024 Amazon Inc. | ||
| """Tests for Flash attention on Neuron. Tested on trn1 & trn2.""" | ||
| import functools | ||
|
|
||
| import chex | ||
| import jax | ||
| import jax.numpy as jnp | ||
| import pytest | ||
|
|
||
| from axlearn.common.flash_attention.neuron_attention import flash_attention | ||
| from axlearn.common.flash_attention.utils import mha_reference | ||
|
|
||
|
|
||
| if jax.default_backend() != "neuron": | ||
| pytestmark = pytest.mark.skip(reason="Incompatible hardware, AWS Neuron only test.") | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "batch_size,seq_len,num_heads,per_head_dim", | ||
| [ | ||
| (1, 2048, 1, 64), | ||
| (2, 2048, 2, 64), | ||
| (1, 2048, 1, 128), | ||
| (2, 2048, 2, 128), | ||
| (1, 2048, 8, 128), | ||
| (2, 2048, 8, 128), | ||
| ], | ||
| ) | ||
| @pytest.mark.parametrize("use_fwd", [True, False]) | ||
| @pytest.mark.parametrize("causal", [True, False]) | ||
| @pytest.mark.parametrize("input_dtype", [jnp.float16, jnp.bfloat16, jnp.float32]) | ||
| def test_fwd_against_ref( | ||
| batch_size: int, | ||
| seq_len: int, | ||
| num_heads: int, | ||
| per_head_dim: int, | ||
| use_fwd: bool, | ||
| causal: bool, | ||
| input_dtype: jnp.dtype, | ||
| ): | ||
| sm_scale = 1.0 / (per_head_dim**0.5) | ||
| k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3) | ||
| q = jax.random.normal(k1, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype) | ||
| k = jax.random.normal(k2, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype) | ||
| v = jax.random.normal(k3, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype) | ||
|
|
||
| bias = None | ||
| segment_ids = None | ||
|
|
||
| if use_fwd: | ||
|
|
||
| @jax.jit | ||
| def impl(q, k, v, bias): | ||
| fn = functools.partial( | ||
| flash_attention, | ||
| causal=causal, | ||
| softmax_scale=sm_scale, | ||
| ) | ||
| out, _ = jax.vjp(fn, q, k, v, bias) | ||
| return out | ||
|
|
||
| else: | ||
| impl = functools.partial( | ||
| flash_attention, | ||
| causal=causal, | ||
| softmax_scale=sm_scale, | ||
| ) | ||
|
|
||
| o = impl(q, k, v, bias) | ||
| o_ref = mha_reference(q, k, v, bias, segment_ids, causal=causal, softmax_scale=sm_scale) | ||
| chex.assert_trees_all_close(o, o_ref, atol=0.05) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "batch_size,num_heads,seq_len,per_head_dim", | ||
| [ | ||
| (1, 1, 2048, 64), | ||
| (2, 2, 2048, 64), | ||
| (1, 1, 2048, 128), | ||
| (2, 2, 2048, 128), | ||
| (1, 8, 2048, 128), | ||
| (2, 8, 2048, 128), | ||
| ], | ||
| ) | ||
| @pytest.mark.parametrize("causal", [True, False]) | ||
| @pytest.mark.parametrize("input_dtype", [jnp.bfloat16, jnp.float16, jnp.float32]) | ||
| def test_bwd_against_ref( | ||
| batch_size: int, | ||
| num_heads: int, | ||
| seq_len: int, | ||
| per_head_dim: int, | ||
| causal: bool, | ||
| input_dtype: jnp.dtype, | ||
| ): | ||
| sm_scale = 1.0 / (per_head_dim**0.5) | ||
| q = jax.random.normal( | ||
| jax.random.PRNGKey(0), (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype | ||
| ) | ||
| k = jax.random.normal( | ||
| jax.random.PRNGKey(1), (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype | ||
| ) | ||
| v = jax.random.normal( | ||
| jax.random.PRNGKey(2), (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype | ||
| ) | ||
|
|
||
| bias = None | ||
| segment_ids = None | ||
|
|
||
| def fn(q, k, v, bias): | ||
| return flash_attention( | ||
| q, | ||
| k, | ||
| v, | ||
| bias, | ||
| causal=causal, | ||
| softmax_scale=sm_scale, | ||
| ).sum() | ||
|
|
||
| def ref_fn(q, k, v, bias, segment_ids): | ||
| return mha_reference( | ||
| q, | ||
| k, | ||
| v, | ||
| bias, | ||
| segment_ids, | ||
| causal=causal, | ||
| softmax_scale=sm_scale, | ||
| ).sum() | ||
|
|
||
| jax_grads = jax.grad(fn, argnums=(0, 1, 2))(q, k, v, bias) | ||
| jax_ref_grads = jax.grad(ref_fn, argnums=(0, 1, 2))(q, k, v, bias, segment_ids) | ||
| chex.assert_trees_all_close(jax_grads, jax_ref_grads, atol=0.07) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we get a support for segment ID and dropout as well? Both are quite needed nowadays.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Segment ID support is in progress and will be added soon