Matrix Attention kernel #1610

Merged

jagrit06 merged 17 commits into main from digani/mat-attn on Nov 22, 2024

Conversation

jagrit06
Member

Proposed changes

  • Adds support for a fused matrix attention kernel that runs for query sequence lengths >= 32
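
For context, here is a minimal usage sketch (not part of the PR) of how this path is exercised from Python through mx.fast.scaled_dot_product_attention; the shapes and dtype are arbitrary examples chosen to satisfy the query sequence length >= 32 condition above.

import math
import mlx.core as mx

# Example shapes (hypothetical) chosen so the fused matrix kernel can apply (qsl >= 32)
B, n_heads, qsl, ksl, head_dim = 1, 32, 128, 128, 64

q = mx.random.normal((B, n_heads, qsl, head_dim)).astype(mx.float16)
k = mx.random.normal((B, n_heads, ksl, head_dim)).astype(mx.float16)
v = mx.random.normal((B, n_heads, ksl, head_dim)).astype(mx.float16)

scale = 1.0 / math.sqrt(head_dim)

# Fused scaled dot-product attention; for qsl >= 32 this is where the new kernel comes in
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None)
mx.eval(out)
print(out.shape)  # (1, 32, 128, 64)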

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

@jagrit06
Member Author

Current results on M3 Max with short sequences and head dim = 64

python benchmarks/python/sdpa_bench.py 
  B,   qsl,   ksl, hdim, n_qh, n_kvh, tpose,   dtype, t_unfs, t_fuse, diff%
  1,    32,    32,   64,   32,    32,     0, float16,  0.034,  0.013, +150.49%
  1,    64,    64,   64,   32,    32,     0, float16,  0.018,  0.012, +43.78%
  1,   128,   128,   64,   32,    32,     0, float16,  0.024,  0.017, +43.82%
  1,   256,   256,   64,   32,    32,     0, float16,  0.049,  0.033, +47.54%
  1,   512,   512,   64,   32,    32,     0, float16,  0.141,  0.077, +83.23%
  1,  1024,  1024,   64,   32,    32,     0, float16,  0.474,  0.251, +88.75%
  1,  2048,  2048,   64,   32,    32,     0, float16,  1.400,  0.954, +46.81%
  1,  4096,  4096,   64,   32,    32,     0, float16,  5.424,  3.747, +44.75%
  1,    32,    32,   64,   32,    32,     0, float32,  0.029,  0.010, +177.74%
  1,    64,    64,   64,   32,    32,     0, float32,  0.019,  0.012, +54.88%
  1,   128,   128,   64,   32,    32,     0, float32,  0.025,  0.017, +47.18%
  1,   256,   256,   64,   32,    32,     0, float32,  0.053,  0.034, +52.88%
  1,   512,   512,   64,   32,    32,     0, float32,  0.171,  0.099, +71.86%
  1,  1024,  1024,   64,   32,    32,     0, float32,  0.588,  0.344, +70.62%
  1,  2048,  2048,   64,   32,    32,     0, float32,  2.195,  1.349, +62.76%
  1,  4096,  4096,   64,   32,    32,     0, float32,  8.739,  5.905, +48.00%
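
A rough way to reproduce a comparison like the table above without the repo's benchmark script (this is a simplified sketch, not benchmarks/python/sdpa_bench.py itself): time an explicit matmul/softmax/matmul baseline against the fused op, assuming the mlx Python package.

import math
import time
import mlx.core as mx

def unfused_sdpa(q, k, v, scale):
    # Reference path: explicit score matmul, softmax, and value matmul
    scores = (q * scale) @ mx.transpose(k, (0, 1, 3, 2))
    return mx.softmax(scores, axis=-1) @ v

def fused_sdpa(q, k, v, scale):
    return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None)

def bench_ms(fn, *args, warmup=5, iters=20):
    # mx.eval forces the lazy graph to run, so each iteration is timed end to end
    for _ in range(warmup):
        mx.eval(fn(*args))
    start = time.perf_counter()
    for _ in range(iters):
        mx.eval(fn(*args))
    return (time.perf_counter() - start) / iters * 1e3

B, n_heads, qsl, ksl, head_dim = 1, 32, 1024, 1024, 64
q = mx.random.normal((B, n_heads, qsl, head_dim)).astype(mx.float16)
k = mx.random.normal((B, n_heads, ksl, head_dim)).astype(mx.float16)
v = mx.random.normal((B, n_heads, ksl, head_dim)).astype(mx.float16)
scale = 1.0 / math.sqrt(head_dim)

t_unfused = bench_ms(unfused_sdpa, q, k, v, scale)
t_fused = bench_ms(fused_sdpa, q, k, v, scale)
print(f"t_unfs={t_unfused:.3f} ms  t_fuse={t_fused:.3f} ms  diff={100 * (t_unfused / t_fused - 1):+.2f}%")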

@jagrit06
Member Author

Clean-up, additional testing, and adjustments are still in progress

@angeloskath
Member

Fantastic speedups! The following is on the M2 Air

  B,   qsl,   ksl, hdim, n_qh, n_kvh, tpose,   dtype, t_unfs, t_fuse, diff%
  1,    32,    32,   64,   32,    32,     0, float16,  0.036,  0.017, +112.35%
  1,    64,    64,   64,   32,    32,     0, float16,  0.038,  0.017, +125.08%
  1,   128,   128,   64,   32,    32,     0, float16,  0.041,  0.034, +20.16%
  1,   256,   256,   64,   32,    32,     0, float16,  0.109,  0.095, +14.68%
  1,   512,   512,   64,   32,    32,     0, float16,  0.390,  0.316, +23.21%
  1,  1024,  1024,   64,   32,    32,     0, float16,  1.474,  1.181, +24.84%
  1,  2048,  2048,   64,   32,    32,     0, float16,  5.663,  4.584, +23.54%
  1,  4096,  4096,   64,   32,    32,     0, float16,  22.423,  18.157, +23.50%
  1,  1024,  1024,   80,   32,    32,     0, float16,  2.165,  1.449, +49.45%
  1,  2048,  2048,   80,   32,    32,     0, float16,  8.531,  5.641, +51.23%
  1,  4096,  4096,   80,   32,    32,     0, float16,  33.792,  22.369, +51.07%
  1,  1024,  1024,  128,   32,    32,     0, float16,  2.287,  2.440, -6.30%
  1,  2048,  2048,  128,   32,    32,     0, float16,  8.939,  9.551, -6.41%
  1,  4096,  4096,  128,   32,    32,     0, float16,  35.482,  37.958, -6.52%

and this is on the M2 Max

  B,   qsl,   ksl, hdim, n_qh, n_kvh, tpose,   dtype, t_unfs, t_fuse, diff%
  1,    32,    32,   64,   32,    32,     0, float16,  0.031,  0.017, +79.55%
  1,    64,    64,   64,   32,    32,     0, float16,  0.029,  0.013, +120.58%
  1,   128,   128,   64,   32,    32,     0, float16,  0.023,  0.016, +42.49%
  1,   256,   256,   64,   32,    32,     0, float16,  0.044,  0.036, +23.91%
  1,   512,   512,   64,   32,    32,     0, float16,  0.112,  0.099, +13.03%
  1,  1024,  1024,   64,   32,    32,     0, float16,  0.384,  0.337, +13.80%
  1,  2048,  2048,   64,   32,    32,     0, float16,  1.479,  1.252, +18.11%
  1,  4096,  4096,   64,   32,    32,     0, float16,  5.781,  4.868, +18.77%
  1,  1024,  1024,   80,   32,    32,     0, float16,  0.546,  0.409, +33.42%
  1,  2048,  2048,   80,   32,    32,     0, float16,  2.115,  1.533, +37.94%
  1,  4096,  4096,   80,   32,    32,     0, float16,  8.330,  6.000, +38.83%
  1,  1024,  1024,  128,   32,    32,     0, float16,  0.609,  0.681, -10.56%
  1,  2048,  2048,  128,   32,    32,     0, float16,  2.374,  2.583, -8.09%
  1,  4096,  4096,  128,   32,    32,     0, float16,  9.351,  10.098, -7.40%

@southkorea2013

southkorea2013 commented Nov 21, 2024

Hello Angelos & Jagrit,

Can this reduce video2text inference memory requirements and improve the generation speed? Is it equivalent to flash_attention_2?

Thanks,
Nan

@awni
Member

awni commented Nov 21, 2024

M1 Max:

  B,   qsl,   ksl, hdim, n_qh, n_kvh, tpose,   dtype, t_unfs, t_fuse, diff%
  1,    32,    32,   64,   32,    32,     0, float16,  0.032,  0.018, +75.62%
  1,    64,    64,   64,   32,    32,     0, float16,  0.031,  0.016, +97.35%
  1,   128,   128,   64,   32,    32,     0, float16,  0.032,  0.020, +62.86%
  1,   256,   256,   64,   32,    32,     0, float16,  0.054,  0.042, +29.19%
  1,   512,   512,   64,   32,    32,     0, float16,  0.136,  0.121, +12.45%
  1,  1024,  1024,   64,   32,    32,     0, float16,  0.450,  0.419, +7.53%
  1,  2048,  2048,   64,   32,    32,     0, float16,  1.749,  1.579, +10.77%
  1,  4096,  4096,   64,   32,    32,     0, float16,  6.889,  6.193, +11.24%
  1,  1024,  1024,   80,   32,    32,     0, float16,  0.671,  0.511, +31.34%
  1,  2048,  2048,   80,   32,    32,     0, float16,  2.593,  1.943, +33.46%
  1,  4096,  4096,   80,   32,    32,     0, float16,  10.218,  7.643, +33.69%
  1,  1024,  1024,  128,   32,    32,     0, float16,  0.743,  0.845, -12.04%
  1,  2048,  2048,  128,   32,    32,     0, float16,  2.896,  3.293, -12.04%
  1,  4096,  4096,  128,   32,    32,     0, float16,  11.496,  12.853, -10.56%

@jagrit06
Member Author

Hello Angelos & Jagrit,

Can this reduce video2text inference memory requirements and improve the generation speed? Is it equivalent to flash_attention_2?

Thanks, Nan

This should help with memory requirements for long sequences. As for generation speed, we expect to keep doing updates later, but at the moment it should be most useful for smaller head dims.

@jagrit06 jagrit06 marked this pull request as ready for review November 21, 2024 19:21

o.set_data(
    allocator::malloc_or_wait(o.nbytes()),
    o.data_size(),
Member

@jagrit06 the bug is here. Using data_size before the array has data set is invalid.

Member Author

Thanks for the catch! I didn't realize it since nbytes() works - I've pushed a fix.

@@ -93,6 +107,7 @@ if(NOT MLX_METAL_JIT)
build_kernel(steel/gemm/kernels/steel_gemm_fused ${STEEL_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_masked ${STEEL_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_splitk ${STEEL_HEADERS})
build_kernel(steel/attn/kernels/steel_attention ${STEEL_ATTN_HEADERS})
Member

I think we should optionally put this in the JIT (can be a follow-on). Right now the behavior is just to build if the JIT is not enabled; if the JIT is enabled, the kernel won't be there. You can move it out of this condition to get it to pass the tests for now.

@awni awni requested a review from angeloskath November 21, 2024 21:14
Member

@awni awni left a comment

LGTM!

Member

@angeloskath angeloskath left a comment

The kernel looks fantastic. The tile ops came out great!

I left a tiny comment for future consideration.

A general comment is that a lot of stuff has been copied from the steel matmul and is not used as far as I can tell, like the axpby transforms, etc. I assume this will be refactored later.


/* Apply operation to threadgroup without bound checking */
template <typename UnaryOp>
METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const {
Member

I wonder if applying the scale could be done while reading from device memory.

Basically read into registers, apply op in place and write into thread group memory.
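
For intuition only (this is not the Metal code): folding the softmax scale into Q as it is loaded is numerically equivalent to scaling the score matrix after the matmul, which is what makes this kind of fusion safe. A quick Python check of that equivalence, assuming mlx:

import math
import mlx.core as mx

qsl, ksl, head_dim = 64, 64, 64
scale = 1.0 / math.sqrt(head_dim)

q = mx.random.normal((qsl, head_dim))
k = mx.random.normal((ksl, head_dim))

# Scale applied to the scores after the matmul ...
scores_after = scale * (q @ k.T)
# ... versus the scale folded into Q before the matmul (as it would be on load)
scores_folded = (scale * q) @ k.T

print(mx.allclose(scores_after, scores_folded).item())  # True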

@jagrit06
Member Author

The kernel looks fantastic. The tile ops came out great!

I left a tiny comment for future consideration.

A general comment is that a lot of stuff has been copied from the steel matmul and is not used as far as I can tell, like the axpby transforms, etc. I assume this will be refactored later.

Yeah, a lot of refactoring and then merging with the steel matmuls is still needed - will keep working on it!

@jagrit06 jagrit06 merged commit 02bec0b into main Nov 22, 2024
5 checks passed
@jagrit06 jagrit06 deleted the digani/mat-attn branch November 22, 2024 18:34