Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Flash Attention implementation (forward + backward) #1

Closed
wants to merge 72 commits into from
Closed
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
72 commits
Select commit Hold shift + click to select a range
f7bcfb0
cuda: add flash attention + test
FSSRepo Jan 17, 2024
e53de28
fix compilation
FSSRepo Jan 18, 2024
a1c004e
ggml : add ggml_flash_attn_ext API
ggerganov Jan 18, 2024
fa7ebcc
ggml : fix GQA support in ggml_flash_attn_ext
ggerganov Jan 19, 2024
09db1a7
Merge branch 'gg/flash-attn' of https://github.com/ggerganov/llama.cp…
FSSRepo Jan 19, 2024
fded2e6
apply suggestions
FSSRepo Jan 20, 2024
c3cdfff
Merge branch 'master' into gg/flash-attn
ggerganov Jan 20, 2024
a9681fe
ggml : online attention (CPU)
ggerganov Jan 20, 2024
1173f49
metal : initial implementation
ggerganov Jan 20, 2024
528da75
metal : f16 precision
ggerganov Jan 21, 2024
52ae085
metal : reduce branches
ggerganov Jan 21, 2024
b973258
metal : specialize for head size
ggerganov Jan 21, 2024
8cde449
wip : 8 rows per simd group
ggerganov Jan 21, 2024
f31955f
wip : 4 rows per simd group
ggerganov Jan 21, 2024
a4b6341
wip : template for rows per warp
ggerganov Jan 21, 2024
77d08f3
metal : parallelize across KV size
ggerganov Jan 21, 2024
17720fa
metal : parallel reduce across heads
ggerganov Jan 21, 2024
a689b02
Merge branch 'gg/flash-attn' of https://github.com/ggerganov/llama.cp…
FSSRepo Jan 23, 2024
6374bc5
cuda: port metal version flash_attn_ext
FSSRepo Jan 23, 2024
6416821
fix equivalent fp16 math functions, compiler error 'undefined'
FSSRepo Jan 24, 2024
972c2ad
use half2 instead half4
FSSRepo Jan 24, 2024
0fc36d8
match to metal impl
FSSRepo Jan 24, 2024
1446a12
metal : efficient flash_attn_f16 implementation
ggerganov Jan 23, 2024
d917746
metal : avoid redundant loads of the attention
ggerganov Jan 25, 2024
432ad04
metal : scale and mask in matrix form
ggerganov Jan 25, 2024
40ea8cd
metal : fix comment
ggerganov Jan 25, 2024
78da338
Merge branch 'gg/flash-attn' of https://github.com/ggerganov/llama.cp…
FSSRepo Jan 25, 2024
f9ca5dc
llama : avoid ggml_cast, use F32 query
ggerganov Jan 25, 2024
6e7cb0e
update implementation
FSSRepo Jan 25, 2024
6fea843
metal : add parallel reduce version (disabled)
ggerganov Jan 25, 2024
0a481fe
integrate tensor cores
FSSRepo Jan 27, 2024
7cea973
Merge branch 'gg/flash-attn' of https://github.com/ggerganov/llama.cp…
FSSRepo Jan 27, 2024
2455a8d
update impl
FSSRepo Jan 27, 2024
b3dd7d9
Merge branch 'master' into gg/flash-attn
ggerganov Jan 28, 2024
77f6976
metal : move output into local memory + optimize
ggerganov Jan 28, 2024
ecc466a
metal : add tests, fix scaling, support C > 32
ggerganov Jan 28, 2024
3a428a1
metal : improve precision
ggerganov Jan 28, 2024
8612864
ggml : fix f16 mad
ggerganov Jan 28, 2024
0ad44ba
Merge branch 'master' into gg/flash-attn
ggerganov Jan 28, 2024
134c81c
metal : minor
ggerganov Jan 28, 2024
1db22d7
metal : support Q > 8
ggerganov Jan 28, 2024
4794821
tests : add ATTN tests
ggerganov Jan 29, 2024
abeaf0d
metal : disable buffer allocation logs
ggerganov Jan 29, 2024
c6c1132
tests : more
ggerganov Jan 29, 2024
5fcb9c1
metal : faster inner loop for C == 32
ggerganov Jan 29, 2024
a1d5a12
fix compiler error
FSSRepo Jan 29, 2024
7980178
Merge branch 'gg/flash-attn' of https://github.com/ggerganov/llama.cp…
FSSRepo Jan 29, 2024
d073e4f
metal : fix array initialization
ggerganov Jan 30, 2024
78df552
tests : ifdef
ggerganov Jan 30, 2024
3d03bcb
Merge branch 'master' into gg/flash-attn
ggerganov Jan 30, 2024
3b0f74b
latest kernel update, wrong values
FSSRepo Jan 30, 2024
2ddc9bb
Merge branch 'master' into gg/flash-attn
ggerganov Jan 31, 2024
b1479df
fix kernel
FSSRepo Jan 31, 2024
8ad92dc
ggml : switch to padded F16 mask for ggml_soft_max, ggml_flash_attn_ext
ggerganov Jan 31, 2024
0afe47f
fix naive implementation
FSSRepo Jan 31, 2024
3df0b8d
Merge branch 'gg/flash-attn' of https://github.com/ggerganov/llama.cp…
FSSRepo Jan 31, 2024
fd878f7
cuda: mask as fp16
FSSRepo Jan 31, 2024
71b69aa
cuda : fix flash_attn kernel to produce same results as CPU
ggerganov Feb 1, 2024
2c04bee
cuda : avoid extra QxQ matrix in shared memory
ggerganov Feb 1, 2024
9a5c2a1
cuda : switch to F16 scalars + tune warps for RTX 2060
ggerganov Feb 1, 2024
ac26f27
cuda : increase C to 128 for better performance
ggerganov Feb 1, 2024
43f7156
Merge pull request #3 from ggerganov/flash-attn-cuda
FSSRepo Feb 1, 2024
9240a84
fix mask nullptr
FSSRepo Feb 1, 2024
8d7a606
don't require LLAMA_CUDA_F16 to compile
FSSRepo Feb 1, 2024
19e0b8e
#ifdef -> #if + fix check -inf
FSSRepo Feb 1, 2024
cae985c
cmake: remove unused changes
FSSRepo Feb 1, 2024
53621e3
refactor flash_attn function + improve tests
FSSRepo Feb 1, 2024
674d5ac
unroll 2 loops, int64_t -> int, 309 µs
JohannesGaessler Feb 3, 2024
8b51ab4
Merge pull request #4 from Pints-App/jg/flash-attn-cuda
FSSRepo Feb 3, 2024
a1f9ffe
bring optimizations from gg/flash-attn
FSSRepo Feb 3, 2024
ba7699d
Merge branch 'flash-attn-cuda' of https://github.com/Pints-App/llama.…
FSSRepo Feb 3, 2024
f659f57
fix merge conflicts
FSSRepo Feb 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
331 changes: 191 additions & 140 deletions ggml-metal.m

Large diffs are not rendered by default.

238 changes: 227 additions & 11 deletions ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -1959,11 +1959,48 @@ kernel void kernel_leaky_relu_f32(
dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
}

typedef void (flash_attn_ext_f16_t)(
device const char * q,
device const char * k,
device const char * v,
device const char * mask,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant int64_t & ne13,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant uint64_t & nb13,
constant int64_t & ne31,
constant uint64_t & nb31,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant float & scale,
threadgroup half * shared,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]);

template<int64_t D, int64_t R> // head size, rows per threadgroup
kernel void kernel_flash_attn_ext_f16(
device const half * q,
device const half * k,
device const half * v,
device const float * mask,
device const char * q,
device const char * k,
device const char * v,
device const char * mask,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
Expand All @@ -1973,21 +2010,200 @@ kernel void kernel_flash_attn_ext_f16(
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant int64_t & ne13,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant uint64_t & nb13,
constant int64_t & ne31,
constant uint64_t & nb31,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant uint64_t & nb0,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
constant float & scale,
constant float & scale,
threadgroup half * shared [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
// TODO: implement
uint3 ntg[[threads_per_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
const uint nsg = ntg.y; // number of simdgroups
const uint tph = N_SIMDWIDTH/R; // threads per head

const int64_t iq3 = tgpig[2];
const int64_t iq2 = tgpig[1]*R + tiisg/tph;
const int64_t iq1 = tgpig[0];

if (iq2 >= ne02) {
return;
}

// assume K and V are same shape
const int64_t ne22 = ne12;
const int64_t ne23 = ne13;

const uint64_t nb21 = nb11;
const uint64_t nb22 = nb12;
const uint64_t nb23 = nb13;

// broadcast
const int64_t rk2 = ne02/ne12;
const int64_t rk3 = ne03/ne13;

const int64_t rv2 = ne02/ne22;
const int64_t rv3 = ne03/ne23;

// k indices
const int64_t ik2 = iq2 / rk2;
const int64_t ik3 = iq3 / rk3;

// v indices
const int64_t iv2 = iq2 / rv2;
const int64_t iv3 = iq3 / rv3;

const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;

device const float * mp = mask ? (device const float *) (mask + (ir%ne31)*nb31) : nullptr;

const int64_t D4 = D/4;

threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 0*R*D);
threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(R*D + 32) + 1*R*D);
threadgroup half * ss = (threadgroup half *) (shared + sgitg*(R*D + 32) + 2*R*D);

const uint tiih = tiisg%tph; // thread index in head
const uint hiisg = tiisg/tph; // head index in simdgroup

// load R heads from Q to shared memory
for (int64_t i = 0; i < D4/tph; ++i) {
if (sgitg == 0) {
pq4[hiisg*D4 + tph*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih];
}

ps4[hiisg*D4 + tph*i + tiih] = 0.0h;
}

threadgroup_barrier(mem_flags::mem_threadgroup);

half S = 0.0h;
half M = -INFINITY;

for (int64_t ic = sgitg; ic < ne11; ic += nsg) {
const half mv = mp ? mp[ic] : 0.0h;
if (mv == -INFINITY) {
continue;
}

device const half4 * pk4 = (device const half4 *) ((device char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13));
device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23));

half4 s4 = 0.0h;

#pragma unroll
for (int64_t i = 0; i < D4/tph; ++i) {
s4 += pq4[hiisg*D4 + tph*i + tiih] * pk4[tph*i + tiih];
}

ss[hiisg*tph + tiih] = (s4.x + s4.y + s4.z + s4.w);

simdgroup_barrier(mem_flags::mem_threadgroup);

if (tiih == 0) {
half s = 0.0h;

#pragma unroll
for (int64_t i = 0; i < tph; ++i) {
s += ss[hiisg*tph + i];
}

s = s*scale + mv;

const half m = M;

M = max(M, s);

const half ms = exp(m - M);
const half vs = exp(s - M);

S = S*ms + vs;

ss[2*hiisg + 0] = ms;
ss[2*hiisg + 1] = vs;
}

simdgroup_barrier(mem_flags::mem_threadgroup);

const half ms = ss[2*hiisg + 0];
const half vs = ss[2*hiisg + 1];

#pragma unroll
for (int64_t i = 0; i < D4/tph; ++i) {
ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms + pv4[tph*i + tiih]*vs;
}
}

if (tiih == 0) {
ss[2*hiisg + 0] = S;
ss[2*hiisg + 1] = M;
}

threadgroup_barrier(mem_flags::mem_threadgroup);

// reduce the warps
if (sgitg == 0) {
for (int64_t sg = 1; sg < nsg; ++sg) {
const half S0 = ss[ 2*hiisg + 0];
const half S1 = ss[sg*(R*D + 32) + 2*hiisg + 0];

const half M0 = ss[ 2*hiisg + 1];
const half M1 = ss[sg*(R*D + 32) + 2*hiisg + 1];

M = max(M0, M1);

const half ms0 = exp(M0 - M);
const half ms1 = exp(M1 - M);

S = S0*ms0 + S1*ms1;

if (tiih == 0) {
ss[2*hiisg + 0] = S;
ss[2*hiisg + 1] = M;
}

for (int64_t i = 0; i < D4/tph; ++i) {
ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms0 + ps4[sg*(R*D + 32)/4 + hiisg*D4 + tph*i + tiih]*ms1;
}
}

for (int64_t i = 0; i < D4/tph; ++i) {
ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]/S;
}
}

simdgroup_barrier(mem_flags::mem_threadgroup);

// dst indices
const int64_t i1 = iq1;
const int64_t i2 = iq2;
const int64_t i3 = iq3;

device float4 * dst4 = (device float4 *) dst;

if (sgitg == 0) {
for (int64_t i = 0; i < D4/tph; ++i) {
dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + tph*i + tiih] = (float4) ps4[hiisg*D4 + tph*i + tiih];
}
}
}

template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 2>;
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 2>;
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 2>;

kernel void kernel_cpy_f16_f16(
device const half * src0,
device half * dst,
Expand Down
Loading