Matrix Attention kernel #1610
Conversation
Current results on M3 Max with short sequences and head dim = 64:

[benchmark table omitted]

Clean up, additional testing and adjustments are still in progress.
Fantastic speedups! The following is on the M2 Air:

[benchmark table omitted]

and this is on the M2 Max:

[benchmark table omitted]
Hello Angelos & Jagrit, can this reduce the memory requirements of video2text inference and improve generation speed? Is it equivalent to flash_attention_2? Thanks!
M1 Max:

[benchmark table omitted]
This should help memory requirements for long sequences. As for generation speed, we expect to keep doing updates later, but at the moment it should be useful for smaller head-dims.
Force-pushed from a1d1d49 to dd427f6.
o.set_data(
    allocator::malloc_or_wait(o.nbytes()),
    o.data_size(),
@jagrit06 the bug is here. Using data_size() before the array has data set is invalid.
Thanks for the catch! I didn't realize it since nbytes() works - I've pushed a fix.
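For context, a minimal sketch of the pattern the review points to, not the exact change from this PR. It assumes the single-argument array::set_data(allocator::Buffer) overload and the mlx/array.h / mlx/allocator.h headers; the helper name is made up for illustration.

#include "mlx/allocator.h"
#include "mlx/array.h"

namespace mlx::core {

// Hypothetical helper: allocate a contiguous output before reading any
// metadata that assumes data is already attached.
void alloc_contiguous_output(array& o) {
  // nbytes() is derived from shape and dtype, so it is safe to call before
  // the array has data set.
  o.set_data(allocator::malloc_or_wait(o.nbytes()));
  // Only after set_data has run is it valid to query o.data_size().
}

} // namespace mlx::core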
@@ -93,6 +107,7 @@ if(NOT MLX_METAL_JIT)
build_kernel(steel/gemm/kernels/steel_gemm_fused ${STEEL_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_masked ${STEEL_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_splitk ${STEEL_HEADERS})
build_kernel(steel/attn/kernels/steel_attention ${STEEL_ATTN_HEADERS})
I think we should optionally put this in the JIT (can be a follow-on). Right now the behavior is to build the kernel only when the JIT is not enabled; if the JIT is enabled, the kernel won't be there. You can move it out of this condition to get it to pass the tests for now.
LGTM!
The kernel looks fantastic. The tile ops came out great!
I left a tiny comment for future consideration.
A general comment is that a lot of stuff has been copied over from the steel matmul and is not used as far as I can tell (the axpby transforms, etc.). I assume this will be refactored later.
/* Apply operation to threadgroup without bound checking */
template <typename UnaryOp>
METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const {
I wonder if applying the scale could be done while reading from device memory. Basically: read into registers, apply the op in place, and write into threadgroup memory.
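A rough Metal sketch of that fusion, with made-up names (load_scaled_to_threadgroup, kElemsPerThread) rather than the actual loader from this PR: each thread stages a run of elements in registers, scales them there, and only then writes to threadgroup memory, so no separate apply_inplace_op pass over the tile is needed.

template <typename T, short kElemsPerThread>
METAL_FUNC void load_scaled_to_threadgroup(
    const device T* src,
    threadgroup T* dst,
    T scale,
    uint offset) {
  // Device memory -> registers
  T regs[kElemsPerThread];
  for (short i = 0; i < kElemsPerThread; i++) {
    regs[i] = src[offset + i];
  }
  // Apply the scale while the values are still in registers
  for (short i = 0; i < kElemsPerThread; i++) {
    regs[i] *= scale;
  }
  // Registers -> threadgroup memory
  for (short i = 0; i < kElemsPerThread; i++) {
    dst[offset + i] = regs[i];
  }
}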
Force-pushed from 9702b37 to 716277d.
Yeah, a lot of refactoring and then merging with the steel matmuls is still needed - will keep working on it!
Proposed changes

Checklist

Put an x in the boxes that apply.

- I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes