ggml : add Flash Attention #5021

Merged 145 commits into master from gg/flash-attn on Apr 30, 2024

Commits
a1c004e
ggml : add ggml_flash_attn_ext API
ggerganov Jan 18, 2024
fa7ebcc
ggml : fix GQA support in ggml_flash_attn_ext
ggerganov Jan 19, 2024
c3cdfff
Merge branch 'master' into gg/flash-attn
ggerganov Jan 20, 2024
a9681fe
ggml : online attention (CPU)
ggerganov Jan 20, 2024
1173f49
metal : initial implementation
ggerganov Jan 20, 2024
528da75
metal : f16 precision
ggerganov Jan 21, 2024
52ae085
metal : reduce branches
ggerganov Jan 21, 2024
b973258
metal : specialize for head size
ggerganov Jan 21, 2024
8cde449
wip : 8 rows per simd group
ggerganov Jan 21, 2024
f31955f
wip : 4 rows per simd group
ggerganov Jan 21, 2024
a4b6341
wip : template for rows per warp
ggerganov Jan 21, 2024
77d08f3
metal : parallelize across KV size
ggerganov Jan 21, 2024
17720fa
metal : parallel reduce across heads
ggerganov Jan 21, 2024
1446a12
metal : efficient flash_attn_f16 implementation
ggerganov Jan 23, 2024
d917746
metal : avoid redundant loads of the attention
ggerganov Jan 25, 2024
432ad04
metal : scale and mask in matrix form
ggerganov Jan 25, 2024
40ea8cd
metal : fix comment
ggerganov Jan 25, 2024
f9ca5dc
llama : avoid ggml_cast, use F32 query
ggerganov Jan 25, 2024
6fea843
metal : add parallel reduce version (disabled)
ggerganov Jan 25, 2024
b3dd7d9
Merge branch 'master' into gg/flash-attn
ggerganov Jan 28, 2024
77f6976
metal : move output into local memory + optimize
ggerganov Jan 28, 2024
ecc466a
metal : add tests, fix scaling, support C > 32
ggerganov Jan 28, 2024
3a428a1
metal : improve precision
ggerganov Jan 28, 2024
8612864
ggml : fix f16 mad
ggerganov Jan 28, 2024
0ad44ba
Merge branch 'master' into gg/flash-attn
ggerganov Jan 28, 2024
134c81c
metal : minor
ggerganov Jan 28, 2024
1db22d7
metal : support Q > 8
ggerganov Jan 28, 2024
4794821
tests : add ATTN tests
ggerganov Jan 29, 2024
abeaf0d
metal : disable buffer allocation logs
ggerganov Jan 29, 2024
c6c1132
tests : more
ggerganov Jan 29, 2024
5fcb9c1
metal : faster inner loop for C == 32
ggerganov Jan 29, 2024
d073e4f
metal : fix array initialization
ggerganov Jan 30, 2024
78df552
tests : ifdef
ggerganov Jan 30, 2024
3d03bcb
Merge branch 'master' into gg/flash-attn
ggerganov Jan 30, 2024
2ddc9bb
Merge branch 'master' into gg/flash-attn
ggerganov Jan 31, 2024
8ad92dc
ggml : switch to padded F16 mask for ggml_soft_max, ggml_flash_attn_ext
ggerganov Jan 31, 2024
910b15b
ggml : fix ggml_soft_max mask requirement
ggerganov Feb 1, 2024
2e46013
cuda : fix soft_max to use correct mask size
ggerganov Feb 1, 2024
5a19a9f
cuda : add flash_attn kernel (wip)
ggerganov Feb 1, 2024
41d136b
Merge branch 'master' into gg/flash-attn
ggerganov Feb 1, 2024
56e45a2
metal : optimize softmax for C > 32
ggerganov Feb 1, 2024
cda5a60
metal : optimize softmax
ggerganov Feb 1, 2024
c6769b9
tests : minor fix
ggerganov Feb 1, 2024
db1f3c4
cuda : avoid zeroing fragments
ggerganov Feb 1, 2024
12eaa22
tests : update dims
ggerganov Feb 2, 2024
b68a112
cuda : fix __hisinf() result check
ggerganov Feb 2, 2024
b150abe
cuda : avoid warp_reduce for smax
ggerganov Feb 3, 2024
7c34655
cuda : use int instead of int64_t
ggerganov Feb 3, 2024
1f8a592
cuda : make loops use the same loop values
ggerganov Feb 3, 2024
92472ea
cuda : unroll some of the loops
ggerganov Feb 3, 2024
c51f27c
cuda : avoid __hisinf branches
ggerganov Feb 3, 2024
b958151
cuda : use half2 in softmax
ggerganov Feb 3, 2024
a7b4715
cuda : switch to 1 warp for bs > 16
ggerganov Feb 3, 2024
3b1c4e7
cuda : speed-up reduce part of the kernel
ggerganov Feb 3, 2024
5b263dd
cuda : unroll Q*K^T loop
ggerganov Feb 3, 2024
e04ff39
cuda : fix -INF block check
ggerganov Feb 3, 2024
cfd9732
cuda : simplify softmax
ggerganov Feb 3, 2024
ef68fac
cuda : fix matrix names
ggerganov Feb 3, 2024
1846e92
cuda : minor
ggerganov Feb 4, 2024
6875997
Merge branch 'master' into gg/flash-attn
ggerganov Feb 12, 2024
31109ca
Merge branch 'master' into gg/flash-attn
ggerganov Feb 19, 2024
f249c99
llama : adapt to F16 KQ_pos
ggerganov Feb 19, 2024
02a645e
Merge branch 'master' into gg/flash-attn
ggerganov Mar 3, 2024
6aefd11
llama : adapt new models to F16 KQ_mask
ggerganov Mar 3, 2024
e307882
Merge branch 'master' into gg/flash-attn
ggerganov Mar 4, 2024
58c7f61
ggml : fix F16 store (ARM NEON)
ggerganov Mar 4, 2024
9495d39
Merge branch 'master' into gg/flash-attn
ggerganov Mar 22, 2024
3a468e6
llama : fix type of KQ_mask and KQ_pos
ggerganov Mar 22, 2024
0953212
ggml : fix CPU soft_max
ggerganov Mar 22, 2024
e425810
tests : add hs=256
ggerganov Mar 24, 2024
013721d
Merge branch 'master' into gg/flash-attn
ggerganov Mar 27, 2024
6be02b5
cuda : fix build
ggerganov Mar 27, 2024
57c03b7
metal : improve perf via smaller int registers
ggerganov Mar 28, 2024
3e318e7
Merge branch 'master' into gg/flash-attn
ggerganov Mar 28, 2024
08e69c5
cuda : adapt soft_max to F16 mask and pos
ggerganov Mar 28, 2024
75aa7b4
CUDA: faster FlashAttention, kernel for bs == 1
JohannesGaessler Mar 29, 2024
d59ac67
16 cols for Phi-2
JohannesGaessler Mar 30, 2024
81da919
no vec for hs, no hs==256 ncols==32 for Volta
JohannesGaessler Mar 30, 2024
269374e
adjust kernel selection logic
JohannesGaessler Mar 31, 2024
cca6d02
4 warps, 256 stride for all D
JohannesGaessler Mar 31, 2024
68d793b
no ncols == 64
JohannesGaessler Apr 1, 2024
3f777ac
Multiple parallel blocks for batch size 1
JohannesGaessler Apr 1, 2024
e1ecd3b
fix compile warnings
JohannesGaessler Apr 2, 2024
bb0d51a
fix excessive KQ_b loads
JohannesGaessler Apr 2, 2024
c63dfdf
fix cmake build
JohannesGaessler Apr 2, 2024
ee19a4a
fix KV cache padding, NaN from INFINITY (#6438)
JohannesGaessler Apr 2, 2024
89961de
Merge branch 'master' into gg/flash-attn
ggerganov Apr 5, 2024
2c41180
Merge branch 'master' into gg/flash-attn
ggerganov Apr 17, 2024
599ce84
llama : flash_attn cparam + fix defrag
ggerganov Apr 17, 2024
4053857
server: support flash_attn param
phymbert Apr 17, 2024
5668c79
server: bench: enable flash_attn param
phymbert Apr 17, 2024
34f93bb
CUDA: refactor host code, dyn. par. blocks
JohannesGaessler Apr 9, 2024
6a3b842
fix flash_attn_vec_f16 race condition
JohannesGaessler Apr 13, 2024
ef9e159
flush softmax exp below threshold to 0
JohannesGaessler Apr 15, 2024
a5b0e2d
store temp KQ in registers
JohannesGaessler Apr 16, 2024
0bc67dd
Calculate KQ as FP32 if KQV has GGML_PREC_F32
JohannesGaessler Apr 16, 2024
2f538b9
Add __hgt2_mask implementation for CUDA 11
JohannesGaessler Apr 17, 2024
87968de
fix KQ FP32 precision fpr parallel_blocks > 1
JohannesGaessler Apr 17, 2024
260cdb2
llama-bench : add -fa,--flash-attn arg
ggerganov Apr 18, 2024
105332c
metal : add BS=1 kernel for flash attention (#6508)
ggerganov Apr 18, 2024
fa9e8c6
Merge branch 'master' into gg/flash-attn
ggerganov Apr 18, 2024
c16a7c2
metal : use F32 attention accumulators
ggerganov Apr 18, 2024
9ca8698
batched-bench : add fattn arg
ggerganov Apr 18, 2024
74d57f9
llama : simplify llama_build_kv_store
ggerganov Apr 19, 2024
1db66c1
Merge branch 'master' into gg/flash-attn
ggerganov Apr 19, 2024
e32b281
llama : adapt build_olmo to changes
ggerganov Apr 19, 2024
703c6e6
ggml : fix arm fp16 store on windows
ggerganov Apr 19, 2024
97eaece
metal : clean-up
ggerganov Apr 19, 2024
1a88565
metal : clean-up kernel code
ggerganov Apr 19, 2024
bc34616
metal : minor
ggerganov Apr 19, 2024
29f6ad8
Merge branch 'master' into gg/flash-attn
ggerganov Apr 19, 2024
5294542
tests : remove benchmarks
ggerganov Apr 19, 2024
3badef1
ggml : fix avx512 const correctness
ggerganov Apr 19, 2024
871fcb6
ggml : fix soft_max with bias on CPU
ggerganov Apr 19, 2024
a39217d
common : print --flash-attn in help
ggerganov Apr 22, 2024
cb76d74
ggml : fix num dimensions in ggml_flash_attn_ext
ggerganov Apr 22, 2024
c11d05f
llama : force disable flash attention for incompatible models
ggerganov Apr 22, 2024
f725ca9
ggml : ggml_soft_max support F16/F32 mask/pos
ggerganov Apr 22, 2024
5408d55
cuda : uint -> uint32_t
ggerganov Apr 22, 2024
c70bfd7
cuda : "constexpr dim3" -> "const dim3"
ggerganov Apr 22, 2024
c129369
cuda : try to fix __hgt2_mask
ggerganov Apr 22, 2024
3864eea
ggml : add TODO's for F16/F32 mask/pos support in other backends
ggerganov Apr 23, 2024
78d363b
llama : replace bool need_kq_pos with use_alibi
ggerganov Apr 23, 2024
19e8982
llama : prep ALiBi support for BERT models
ggerganov Apr 23, 2024
56657e5
llama : fix n_batch requirements
ggerganov Apr 23, 2024
d228bf8
cont
ggerganov Apr 23, 2024
751591d
server : add help for --flash-attn arg
ggerganov Apr 23, 2024
8937ec5
Merge branch 'master' into gg/flash-attn
ggerganov Apr 24, 2024
ce281b9
llama : disable FA for AMD
ggerganov Apr 24, 2024
1f77f49
Merge branch 'master' into gg/flash-attn
ggerganov Apr 25, 2024
ff2c64a
tests : remove TMP_ATTN_BENCH
ggerganov Apr 25, 2024
cb3547a
Merge branch 'master' into gg/flash-attn
ggerganov Apr 25, 2024
1fd5bc3
llama : support save/load state with FA enabled
ggerganov Apr 25, 2024
09d0381
Merge branch 'master' into gg/flash-attn
ggerganov Apr 25, 2024
ac1c6d9
ci : add CUDA save-load-state tests
ggerganov Apr 25, 2024
c225609
llama : llama_kv_cache_clear zeroes data + fix save-load seq
ggerganov Apr 25, 2024
bab346b
llama : fix copy-paste errors, add TODO
ggerganov Apr 25, 2024
0fc5c5e
llama : disallow incompatible states
ggerganov Apr 25, 2024
1e590ac
llama : update llama_state_get_size after v_trans field
ggerganov Apr 25, 2024
4f4c024
metal : remove tmp log
ggerganov Apr 25, 2024
9e38760
llama : add static reminder for llama_state_get_size
ggerganov Apr 25, 2024
a1616e9
Merge branch 'master' into gg/flash-attn
ggerganov Apr 29, 2024
ca0275c
Merge branch 'master' into gg/flash-attn
ggerganov Apr 29, 2024
e180fcd
metal : fix max nsg
ggerganov Apr 30, 2024
c240ae2
ci : fix arg order
ggerganov Apr 30, 2024

Files changed

8 changes: 6 additions & 2 deletions ci/run.sh
@@ -336,7 +336,8 @@ function gg_run_open_llama_3b_v2 {

(time ./bin/imatrix --model ${model_f16} -f ${wiki_test_60} -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log

(time ./bin/save-load-state --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/save-load-state --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/save-load-state -fa --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log

function check_ppl {
qnt="$1"
@@ -517,7 +518,10 @@ function gg_run_open_llama_7b_v2 {

(time ./bin/imatrix --model ${model_f16} -f ${wiki_test} -t 1 -ngl 999 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log

(time ./bin/save-load-state --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/save-load-state -ngl 10 --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/save-load-state -fa -ngl 10 --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/save-load-state -ngl 99 --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/save-load-state -fa -ngl 99 --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log

function check_ppl {
qnt="$1"
7 changes: 7 additions & 0 deletions common/common.cpp
@@ -948,6 +948,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.cont_batching = true;
return true;
}
if (arg == "-fa" || arg == "--flash-attn") {
params.flash_attn = true;
return true;
}
if (arg == "--color") {
params.use_color = true;
return true;
@@ -1494,6 +1498,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" -ns N, --sequences N number of sequences to decode (default: %d)\n", params.n_sequences);
printf(" -ps N, --p-split N speculative decoding split probability (default: %.1f)\n", (double)params.p_split);
printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
printf(" -fa, --flash-attn enable Flash Attention (default: %s)\n", params.flash_attn ? "enabled" : "disabled");
printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA. see examples/llava/README.md\n");
printf(" --image IMAGE_FILE path to an image file. use with multimodal models. Specify multiple times for batching\n");
if (llama_supports_mlock()) {
@@ -1866,6 +1871,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.cb_eval = params.cb_eval;
cparams.cb_eval_user_data = params.cb_eval_user_data;
cparams.offload_kqv = !params.no_kv_offload;
cparams.flash_attn = params.flash_attn;

cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
@@ -2703,6 +2709,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "seed: %u # default: -1 (random seed)\n", params.seed);
fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false");
fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false");
fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false");
fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);

const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices());
1 change: 1 addition & 0 deletions common/common.h
@@ -148,6 +148,7 @@ struct gpt_params {
bool multiline_input = false; // reverse the usage of `\`
bool simple_io = false; // improves compatibility with subprocesses and limited consoles
bool cont_batching = true; // insert new sequences for decoding on-the-fly
bool flash_attn = false; // flash attention

bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
bool ignore_eos = false; // ignore generated EOS tokens
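Taken together, the common.cpp and common.h changes above add the `-fa` / `--flash-attn` flag, which simply flips the new `flash_attn` boolean and forwards it into `llama_context_params`. Below is a minimal sketch of what this exposes to API users, assuming the llama.h C API of this period; the model path and context size are placeholders, not values from the PR:

```cpp
// Minimal sketch, assuming the llama.h API at the time of this PR: opting into the
// new code path only requires setting the flash_attn context parameter.
#include "llama.h"

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_load_model_from_file("model.gguf", mparams); // placeholder path

    llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx      = 4096;  // placeholder context size
    cparams.flash_attn = true;  // field added by this PR; defaults to false

    llama_context * ctx = llama_new_context_with_model(model, cparams);
    if (ctx == NULL) {
        return 1; // context creation failed
    }

    // ... tokenize, llama_decode(), sampling as usual ...

    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}
```

Per the "llama : force disable flash attention for incompatible models" commit, the flag is a request rather than a guarantee: it is force-disabled for models the PR marks as incompatible.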
28 changes: 17 additions & 11 deletions examples/batched-bench/batched-bench.cpp
@@ -32,7 +32,7 @@ int main(int argc, char ** argv) {
gpt_params params;

if (argc == 1 || argv[1][0] == '-') {
printf("usage: %s MODEL_PATH [N_KV_MAX] [N_BATCH] [N_UBATCH] [IS_PP_SHARED] [NGL] <PP> <TG> <PL>\n" , argv[0]);
printf("usage: %s MODEL_PATH [N_KV_MAX] [N_BATCH] [N_UBATCH] [FATTN] [IS_PP_SHARED] [NGL] <PP> <TG> <PL>\n" , argv[0]);
printf(" <PP>, <TG> and PL are comma-separated lists of numbers without spaces\n\n");
printf(" example: %s ggml-model-f16.gguf 2048 2048 512 0 999 128,256,512 128,256 1,2,4,8,16,32\n\n", argv[0]);
return 1 ;
@@ -41,6 +41,7 @@ int main(int argc, char ** argv) {
int n_kv_max = 2048;
int n_batch = 2048;
int n_ubatch = 512;
bool flash_attn = false;
int is_pp_shared = 0;
int n_gpu_layers = 0;

@@ -66,23 +67,27 @@ int main(int argc, char ** argv) {
}

if (argc >= 6) {
is_pp_shared = std::atoi(argv[5]);
flash_attn = std::atoi(argv[5]);
}

if (argc >= 7) {
n_gpu_layers = std::atoi(argv[6]);
is_pp_shared = std::atoi(argv[6]);
}

if (argc >= 8) {
n_pp = parse_list(argv[7]);
n_gpu_layers = std::atoi(argv[7]);
}

if (argc >= 9) {
n_tg = parse_list(argv[8]);
n_pp = parse_list(argv[8]);
}

if (argc >= 10) {
n_pl = parse_list(argv[9]);
n_tg = parse_list(argv[9]);
}

if (argc >= 11) {
n_pl = parse_list(argv[10]);
}

// init LLM
@@ -108,10 +113,11 @@

llama_context_params ctx_params = llama_context_default_params();

ctx_params.seed = 1234;
ctx_params.n_ctx = n_kv_max;
ctx_params.n_batch = n_batch;
ctx_params.n_ubatch = n_ubatch;
ctx_params.seed = 1234;
ctx_params.n_ctx = n_kv_max;
ctx_params.n_batch = n_batch;
ctx_params.n_ubatch = n_ubatch;
ctx_params.flash_attn = flash_attn;

ctx_params.n_threads = params.n_threads;
ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
@@ -169,7 +175,7 @@ int main(int argc, char ** argv) {
}

LOG_TEE("\n");
LOG_TEE("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, n_batch, n_ubatch, is_pp_shared, n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
LOG_TEE("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, flash_attn = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, n_batch, n_ubatch, flash_attn, is_pp_shared, n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
LOG_TEE("\n");

LOG_TEE("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s");
30 changes: 27 additions & 3 deletions examples/llama-bench/llama-bench.cpp
@@ -174,6 +174,7 @@ struct cmd_params {
std::vector<llama_split_mode> split_mode;
std::vector<int> main_gpu;
std::vector<bool> no_kv_offload;
std::vector<bool> flash_attn;
std::vector<std::vector<float>> tensor_split;
std::vector<bool> use_mmap;
std::vector<bool> embeddings;
@@ -195,6 +196,7 @@ static const cmd_params cmd_params_defaults = {
/* split_mode */ {LLAMA_SPLIT_MODE_LAYER},
/* main_gpu */ {0},
/* no_kv_offload */ {false},
/* flash_attn */ {false},
/* tensor_split */ {std::vector<float>(llama_max_devices(), 0.0f)},
/* use_mmap */ {true},
/* embeddings */ {false},
@@ -220,6 +222,7 @@ static void print_usage(int /* argc */, char ** argv) {
printf(" -sm, --split-mode <none|layer|row> (default: %s)\n", join(transform_to_str(cmd_params_defaults.split_mode, split_mode_str), ",").c_str());
printf(" -mg, --main-gpu <i> (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str());
printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str());
printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str());
printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str());
printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str());
printf(" -ts, --tensor-split <ts0/ts1/..> (default: 0)\n");
@@ -393,6 +396,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
}
auto p = split<bool>(argv[i], split_delim);
params.no_kv_offload.insert(params.no_kv_offload.end(), p.begin(), p.end());
} else if (arg == "-fa" || arg == "--flash-attn") {
if (++i >= argc) {
invalid_param = true;
break;
}
auto p = split<bool>(argv[i], split_delim);
params.flash_attn.insert(params.flash_attn.end(), p.begin(), p.end());
} else if (arg == "-mmp" || arg == "--mmap") {
if (++i >= argc) {
invalid_param = true;
@@ -477,6 +487,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
if (params.split_mode.empty()) { params.split_mode = cmd_params_defaults.split_mode; }
if (params.main_gpu.empty()) { params.main_gpu = cmd_params_defaults.main_gpu; }
if (params.no_kv_offload.empty()){ params.no_kv_offload = cmd_params_defaults.no_kv_offload; }
if (params.flash_attn.empty()) { params.flash_attn = cmd_params_defaults.flash_attn; }
if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; }
if (params.use_mmap.empty()) { params.use_mmap = cmd_params_defaults.use_mmap; }
if (params.embeddings.empty()) { params.embeddings = cmd_params_defaults.embeddings; }
@@ -498,6 +509,7 @@ struct cmd_params_instance {
llama_split_mode split_mode;
int main_gpu;
bool no_kv_offload;
bool flash_attn;
std::vector<float> tensor_split;
bool use_mmap;
bool embeddings;
@@ -532,6 +544,7 @@ struct cmd_params_instance {
cparams.type_k = type_k;
cparams.type_v = type_v;
cparams.offload_kqv = !no_kv_offload;
cparams.flash_attn = flash_attn;
cparams.embeddings = embeddings;

return cparams;
@@ -554,6 +567,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
for (const auto & tk : params.type_k)
for (const auto & tv : params.type_v)
for (const auto & nkvo : params.no_kv_offload)
for (const auto & fa : params.flash_attn)
for (const auto & nt : params.n_threads) {
for (const auto & n_prompt : params.n_prompt) {
if (n_prompt == 0) {
@@ -572,6 +586,7 @@
/* .split_mode = */ sm,
/* .main_gpu = */ mg,
/* .no_kv_offload= */ nkvo,
/* .flash_attn = */ fa,
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
@@ -596,6 +611,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .split_mode = */ sm,
/* .main_gpu = */ mg,
/* .no_kv_offload= */ nkvo,
/* .flash_attn = */ fa,
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
@@ -633,6 +649,7 @@ struct test {
llama_split_mode split_mode;
int main_gpu;
bool no_kv_offload;
bool flash_attn;
std::vector<float> tensor_split;
bool use_mmap;
bool embeddings;
@@ -657,6 +674,7 @@ struct test {
split_mode = inst.split_mode;
main_gpu = inst.main_gpu;
no_kv_offload = inst.no_kv_offload;
flash_attn = inst.flash_attn;
tensor_split = inst.tensor_split;
use_mmap = inst.use_mmap;
embeddings = inst.embeddings;
@@ -731,7 +749,7 @@ struct test {
"n_batch", "n_ubatch",
"n_threads", "type_k", "type_v",
"n_gpu_layers", "split_mode",
"main_gpu", "no_kv_offload",
"main_gpu", "no_kv_offload", "flash_attn",
"tensor_split", "use_mmap", "embeddings",
"n_prompt", "n_gen", "test_time",
"avg_ns", "stddev_ns",
@@ -753,7 +771,7 @@ struct test {
}
if (field == "cuda" || field == "opencl" || field == "vulkan" || field == "kompute" || field == "metal" ||
field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" ||
field == "use_mmap" || field == "embeddings") {
field == "flash_attn" || field == "use_mmap" || field == "embeddings") {
return BOOL;
}
if (field == "avg_ts" || field == "stddev_ts") {
@@ -787,7 +805,7 @@ struct test {
std::to_string(n_batch), std::to_string(n_ubatch),
std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v),
std::to_string(n_gpu_layers), split_mode_str(split_mode),
std::to_string(main_gpu), std::to_string(no_kv_offload),
std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn),
tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings),
std::to_string(n_prompt), std::to_string(n_gen), test_time,
std::to_string(avg_ns()), std::to_string(stdev_ns()),
@@ -955,6 +973,9 @@ struct markdown_printer : public printer {
if (field == "no_kv_offload") {
return "nkvo";
}
if (field == "flash_attn") {
return "fa";
}
if (field == "use_mmap") {
return "mmap";
}
Expand Down Expand Up @@ -1001,6 +1022,9 @@ struct markdown_printer : public printer {
if (params.no_kv_offload.size() > 1 || params.no_kv_offload != cmd_params_defaults.no_kv_offload) {
fields.emplace_back("no_kv_offload");
}
if (params.flash_attn.size() > 1 || params.flash_attn != cmd_params_defaults.flash_attn) {
fields.emplace_back("flash_attn");
}
if (params.tensor_split.size() > 1 || params.tensor_split != cmd_params_defaults.tensor_split) {
fields.emplace_back("tensor_split");
}
1 change: 1 addition & 0 deletions examples/server/bench/bench.py
@@ -268,6 +268,7 @@ def start_server_background(args):
server_args.extend(['--defrag-thold', "0.1"])
server_args.append('--cont-batching')
server_args.append('--metrics')
server_args.append('--flash-attn')
server_args.extend(['--log-format', "text"])
args = [str(arg) for arg in [server_path, *server_args]]
print(f"bench: starting server with: {' '.join(args)}")
3 changes: 3 additions & 0 deletions examples/server/server.cpp
@@ -2377,6 +2377,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
printf(" --embeddings enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel);
printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: enabled)\n");
printf(" -fa, --flash-attn enable Flash Attention (default: %s)\n", params.flash_attn ? "enabled" : "disabled");
printf(" -spf FNAME, --system-prompt-file FNAME\n");
printf(" set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications.\n");
printf(" -ctk TYPE, --cache-type-k TYPE\n");
@@ -2742,6 +2743,8 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
params.embedding = true;
} else if (arg == "-cb" || arg == "--cont-batching") {
params.cont_batching = true;
} else if (arg == "-fa" || arg == "--flash-attn") {
params.flash_attn = true;
} else if (arg == "-np" || arg == "--parallel") {
if (++i >= argc) {
invalid_param = true;
6 changes: 6 additions & 0 deletions ggml-cuda.cu
@@ -14,6 +14,7 @@
#include "ggml-cuda/cpy.cuh"
#include "ggml-cuda/diagmask.cuh"
#include "ggml-cuda/dmmv.cuh"
#include "ggml-cuda/fattn.cuh"
#include "ggml-cuda/getrows.cuh"
#include "ggml-cuda/im2col.cuh"
#include "ggml-cuda/mmq.cuh"
@@ -140,6 +141,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
info.devices[id].cc = 100*prop.major + 10*prop.minor;
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
info.devices[id].smpb = prop.sharedMemPerBlock;
info.devices[id].nsm = prop.multiProcessorCount;
}

for (int id = 0; id < info.device_count; ++id) {
@@ -2290,6 +2292,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_ARGSORT:
ggml_cuda_op_argsort(ctx, dst);
break;
case GGML_OP_FLASH_ATTN_EXT:
ggml_cuda_flash_attn_ext(ctx, dst);
break;
default:
return false;
}
@@ -2564,6 +2569,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_ARANGE:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_LEAKY_RELU:
case GGML_OP_FLASH_ATTN_EXT:
return true;
default:
return false;
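The ggml-cuda.cu hunks above only wire the new GGML_OP_FLASH_ATTN_EXT op into the CUDA backend (device info, op dispatch, and supports_op); the fused kernels themselves sit behind the new ggml-cuda/fattn.cuh include, with the CPU and Metal implementations added in the commits listed earlier. At the graph level, the PR replaces the separate K·Q, scale+mask+softmax, and V·KQ nodes with one fused node. The sketch below is illustrative only: it assumes the `ggml_flash_attn_ext` signature introduced here is `(ctx, q, k, v, mask, scale)`, with an F32 query, F16 K/V, and a padded F16 mask, as the commit messages suggest; the shapes and the helper name are placeholders:

```cpp
// Hypothetical sketch -- the signature and tensor layouts are assumptions drawn from
// this PR's commit messages, not copied from the headers. The real graph-building code
// in llama.cpp additionally handles GQA broadcasting and the reshapes/permutes around
// the attention node.
#include "ggml.h"
#include <math.h>

static struct ggml_tensor * build_attn_fused(
        struct ggml_context * ctx,
        struct ggml_tensor  * q,        // F32 query,       assumed [head_dim, n_tokens, n_head]
        struct ggml_tensor  * k,        // F16 key cache,   assumed [head_dim, n_kv, n_head_kv]
        struct ggml_tensor  * v,        // F16 value cache, assumed [head_dim, n_kv, n_head_kv]
        struct ggml_tensor  * kq_mask,  // F16 mask, padded as per the soft_max/flash_attn commits
        int                   head_dim) {
    const float kq_scale = 1.0f/sqrtf((float) head_dim);

    // one fused node in place of mul_mat(K,Q) -> scale + mask + softmax -> mul_mat(V,KQ)
    return ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale);
}
```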