The high-level architecture, taken from the paper.

=== profiling manual attention ===
...
Self CPU time total: 52.389ms
Self CUDA time total: 52.545ms
=== profiling minimal flash attention ===
...
Self CPU time total: 11.452ms
Self CUDA time total: 3.908ms
- In the inner loop, I assign each thread to a row of the output matrix. This differs from the original implementation.
- This thread-per-row simplification makes the matrix multiplications very slow, which is probably why this gets slower than the manual implementation for longer sequences and larger block sizes (see the sketch below).
- Q, K, and V are in float32, unlike the original implementation, which uses float16.
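
To make the thread-per-row point concrete, here is a minimal, self-contained CUDA sketch of that inner-loop pattern. It is not the kernel from this repo: the tile sizes `Br`/`Bc`, the head dimension `D`, and the kernel name `scores_thread_per_row` are made up for illustration. Each thread computes one row of the Q·Kᵀ score tile with plain scalar loops, so none of the multiply-accumulate work is shared across threads or mapped onto warp-level matmul hardware, which is the source of the slowdown noted above. (It also uses float32, matching the last caveat.)

```cuda
// Sketch only (not this repo's flash.cu): one thread per row of the score
// tile. The Q*K^T "matmul" is fully serialized inside each thread -- no
// cooperative tiling, no tensor cores -- which is why it is slow for
// larger tiles and longer sequences.
#include <cstdio>
#include <cuda_runtime.h>

#define Br 32   // rows of Q handled per block (one thread per row), assumed
#define Bc 32   // rows of K handled per tile, assumed
#define D  64   // head dimension, assumed

__global__ void scores_thread_per_row(const float* Q, const float* K, float* S) {
    int tx = threadIdx.x;               // this thread owns row tx of the tile
    if (tx >= Br) return;
    for (int y = 0; y < Bc; ++y) {      // every K row in the tile...
        float acc = 0.f;
        for (int k = 0; k < D; ++k)     // ...is reduced with a scalar loop
            acc += Q[tx * D + k] * K[y * D + k];
        S[tx * Bc + y] = acc;           // S = Q K^T for this tile
    }
}

int main() {
    float *Q, *K, *S;
    cudaMallocManaged(&Q, Br * D * sizeof(float));
    cudaMallocManaged(&K, Bc * D * sizeof(float));
    cudaMallocManaged(&S, Br * Bc * sizeof(float));
    for (int i = 0; i < Br * D; ++i) Q[i] = 1.f;   // all-ones inputs so the
    for (int i = 0; i < Bc * D; ++i) K[i] = 1.f;   // expected dot product is D
    scores_thread_per_row<<<1, Br>>>(Q, K, S);
    cudaDeviceSynchronize();
    printf("S[0][0] = %f (expect %d)\n", S[0], D);
    return 0;
}
```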