
Port of self extension to server #5104

Merged
merged 18 commits on Jan 27, 2024

Conversation

Maximilian-Winter
Contributor

Maximilian-Winter commented Jan 24, 2024

Hi, I ported the code for self-extension over to the server. I have tested it with information retrieval: I inserted out-of-context information into a ~6500-token text and it worked, at least with one slot. I tested multiple requests, one after the other, and it gives the same or similar results to main (I used a random seed). I'm not sure everything is correct for use with multiple slots, because I can't really test that on my machine.

I tested with solar-10.7b-instruct-v1.0.Q5_K_M.gguf (4096 trained context) and the settings -c 16384 and --grp-attn-n 4 --grp-attn-w 2048.
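For reference, a server invocation matching the settings above might look like the following minimal sketch; the model path and -ngl value are illustrative assumptions, not taken from this discussion:

# 4096-token trained context stretched to a 16384-token KV cache via grouped
# attention with group factor 4 and window 2048, as described above.
./server -m ./models/solar-10.7b-instruct-v1.0.Q5_K_M.gguf \
    -c 16384 --grp-attn-n 4 --grp-attn-w 2048 -ngl 33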

@Maximilian-Winter
Contributor Author

@ggerganov I haven't had the time to refactor the code into a llama API for it, as you mentioned in the issue. But I can do that.

@duykhanhbk

Thanks @Maximilian-Winter for your work. I'll check it.

@Maximilian-Winter
Contributor Author

Maximilian-Winter commented Jan 24, 2024

I have found a problem with prompt caching even when self-extend isn't enabled. Will fix it ASAP.

@x4080

x4080 commented Jan 24, 2024

@Maximilian-Winter cool

@duykhanhbk

I have found a problem with the KV cache, @Maximilian-Winter:
[image]

@vassioc

vassioc commented Jan 24, 2024

The server with self-extend, combined with prompt caching, is a use case for RAG with no need for semantic search and/or a vector store.
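As a rough sketch of that use case (the document text, question, and host/port are placeholders; the cache_prompt field asks the server to reuse the already-evaluated prompt prefix), the reference document can be embedded directly in the prompt and kept cached across follow-up questions:

# The long document forms a shared prefix; with cache_prompt enabled, subsequent
# requests that reuse the same prefix only need to evaluate the changed question.
curl -s http://127.0.0.1:8080/completion \
    -H 'Content-Type: application/json' \
    -d '{
          "prompt": "<long reference document>\n\nQuestion: <your question>\nAnswer:",
          "n_predict": 256,
          "cache_prompt": true
        }'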

ggerganov self-requested a review on January 25, 2024, 19:59
@ggerganov
Owner

I have found a problem with prompt caching even when self-extend isn't enabled. Will fix it ASAP.

What is the status of this?

@Maximilian-Winter
Contributor Author

@ggerganov Will fix this today, sorry for the delay.

@Maximilian-Winter
Contributor Author

@ggerganov Prompt caching should work as before without self extend.

@K-Mistele
Contributor

K-Mistele commented Jan 26, 2024

Before this gets merged can you update the server README / docs?
https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md

[Update] It would also be good to add lines in server_print_usage here that indicate (a) that the flags exist and (b) how to use them.

@Maximilian-Winter
Contributor Author

@K-Mistele Added descriptions to the README and server print usage, but I'm not sure if my descriptions are totally correct.

@Maximilian-Winter
Contributor Author

Maximilian-Winter commented Jan 27, 2024

OK, I broke something in self-extend. Will fix it now. Sorry.

@Maximilian-Winter
Contributor Author

Maximilian-Winter commented Jan 27, 2024

@ggerganov I double-checked everything, and now even prompt caching works with self-extend enabled. Maybe you can take a look at my last commit that added prompt caching and tell me if it is as you intended.

@Maximilian-Winter
Contributor Author

@K-Mistele I updated the descriptions to make them easier to read and consistent with the other parameters.

@ggerganov
Owner

You need to fill the context up. It is 32768, and you pass just 6852 tokens while generating 17 new tokens.

One way to test this is to set, for example, -c 512, then pass a prompt of ~500 tokens and let it generate some tokens. When it has generated 12 tokens, it will fill up the context, and it should apply the context shift to free space for the next tokens.
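A minimal sketch of that test, assuming a local server with default host/port and placeholder values (the model path is illustrative; PROMPT stands for a ~500-token string):

# Deliberately small KV cache so the shift has to trigger after a few generated tokens.
./server -m models/model.gguf -c 512 -ngl 33

# Send the ~500-token prompt and request more tokens than the remaining space;
# once the 512-token context fills up, the server should apply the context shift.
curl -s http://127.0.0.1:8080/completion \
    -H 'Content-Type: application/json' \
    -d "{\"prompt\": \"$PROMPT\", \"n_predict\": 64}"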

@Maximilian-Winter
Contributor Author

OK, I thought doing it like this would be enough: start llama.cpp/server -m neural-chat-7b-v3-3.Q8_0.gguf -c 0 -ngl 33 -b 1024 -t 8

@Maximilian-Winter
Contributor Author

@ggerganov Can you tell me if that is the correct way of doing this? I'm not able to trigger the error.

start llama.cpp/server -m neural-chat-7b-v3-3.Q8_0.gguf -c 6860 -ngl 33 -b 1024 -t 8

Available slots:
 -> Slot 0 - max context: 6860
{"timestamp":1706432998,"level":"INFO","function":"main","line":2546,"message":"model loaded"}
all slots are idle and system prompt is empty, clear the KV cache
slot 0 is processing [task id: 0]
slot 0 : in cache: 0 tokens | to process: 6852 tokens
slot 0 : kv cache rm - [0, end)
slot 0: context shift - n_keep = 0, n_left = 6858, n_discard = 3429

print_timings: prompt eval time =    6478.14 ms /  6852 tokens (    0.95 ms per token,  1057.71 tokens per second)
print_timings:        eval time =     726.26 ms /    17 runs   (   42.72 ms per token,    23.41 tokens per second)
print_timings:       total time =    7204.41 ms
slot 0 released (3440 tokens in cache)
{"timestamp":1706433011,"level":"INFO","function":"log_server_request","line":2366,"message":"request","remote_addr":"127.0.0.1","remote_port":64254,"status":200,"method":"POST","path":"/completion","params":{}}

@Green-Sky
Collaborator

@Maximilian-Winter The goal is to process a larger prompt than the value specified by -c, so I suggest using something small(er) like 1024.
BTW, I suggest a smaller, faster, still good model for testing, like stablelm-zephyr-3b.

BRB, git bisecting and making sure it's actually this PR.

@Green-Sky
Collaborator

Green-Sky commented Jan 28, 2024

with -c 512:

slot 0 is processing [task id: 55]
slot 0 : in cache: 493 tokens | to process: 1 tokens
slot 0 : kv cache rm - [493, end)
update_slots : failed to find free space in the KV cache, retrying with smaller n_batch = 256
update_slots : failed to find free space in the KV cache, retrying with smaller n_batch = 128
update_slots : failed to find free space in the KV cache, retrying with smaller n_batch = 64
update_slots : failed to find free space in the KV cache, retrying with smaller n_batch = 32
update_slots : failed to find free space in the KV cache, retrying with smaller n_batch = 16
update_slots : failed to find free space in the KV cache, retrying with smaller n_batch = 8
update_slots : failed to find free space in the KV cache, retrying with smaller n_batch = 4
update_slots : failed to find free space in the KV cache, retrying with smaller n_batch = 2
update_slots : failed to find free space in the KV cache, retrying with smaller n_batch = 1
update_slots : failed to decode the batch, n_batch = 1, ret = 1

(it hangs here)

this is what a working context shift looks like:

slot 0 is processing [task id: 1]
slot 0 : in cache: 493 tokens | to process: 1 tokens
slot 0 : kv cache rm - [493, end)
slot 0: context shift - n_keep = 0, n_left = 510, n_discard = 255

@Green-Sky
Collaborator

Green-Sky commented Jan 28, 2024

OK, sorry for accusing you. It is in fact not this PR!

a1d6df129bcd3d42cda38c09217d8d4ec4ea3bdd is the first bad commit
commit a1d6df129bcd3d42cda38c09217d8d4ec4ea3bdd
Author: 0cc4m <picard12@live.de>
Date:   Fri Jan 26 23:07:32 2024 +0100

    Add OpenCL add kernel (#5151)

    * Add OpenCL add kernel

    * Put add kernel into different string to stay within MSVC string length limit, disable float16 support due to bad results

 ggml-opencl.cpp | 87 +++++++++++++++++++++++++++++++++++++++++++++++++++++++--
 ggml-opencl.h   |  1 +
 ggml.c          | 11 ++++++++
 3 files changed, 96 insertions(+), 3 deletions(-)

a1d6df1 is the first bad commit

which makes no sense o.o

@Green-Sky
Collaborator

I am restarting the bisect and extending the testing period.

@Maximilian-Winter
Contributor Author

Maximilian-Winter commented Jan 28, 2024

@Green-Sky I only recently started using the llama.cpp server instead of the llama-cpp-python bindings, so most things are relatively new to me. I just saw that self-extend was added to main, looked at the code, and transferred the same logic to the server, because it looked like an easy change. (And it was relatively easy, after understanding the server processing pipeline.)
Let me know if you find the cause.

@Maximilian-Winter
Contributor Author

@Green-Sky Could you pinpoint the issue?

@Green-Sky
Collaborator

Green-Sky commented Jan 28, 2024

@Maximilian-Winter sorry for the delay.
I am still looking around, but the issue is definitely not caused by this PR. Since the tool/bot that I use to test the llama server uses the /health endpoint, I did not check back further, but even that far back it is broken.
However, the reason I thought the bug was maybe caused by this commit is that the bug actually changes how it shows itself, from a silent hang to an infinite-loop log-spam hang.

The bug:
After at least one successful cache context shift, it hangs (on the next one?) with the output:

update_slots : failed to find free space in the KV cache, retrying with smaller n_batch = 256
update_slots : failed to find free space in the KV cache, retrying with smaller n_batch = 128
update_slots : failed to find free space in the KV cache, retrying with smaller n_batch = 64
update_slots : failed to find free space in the KV cache, retrying with smaller n_batch = 32
update_slots : failed to find free space in the KV cache, retrying with smaller n_batch = 16
update_slots : failed to find free space in the KV cache, retrying with smaller n_batch = 8
update_slots : failed to find free space in the KV cache, retrying with smaller n_batch = 4
update_slots : failed to find free space in the KV cache, retrying with smaller n_batch = 2
update_slots : failed to find free space in the KV cache, retrying with smaller n_batch = 1
update_slots : failed to decode the batch, n_batch = 1, ret = 1

which leads me to believe that the context shift leaks context cache somehow.

I am currently looking into which commit causes the changed bug behavior, and after that I will skip the /health check, go back further, and figure out if the server was ever functional.

@Green-Sky
Collaborator

OK, it turns out I totally missed the two existing issues:
#4989
#4185

@ggerganov
Owner

OK, thanks. I've noticed these issues, but I haven't yet looked into what the root cause is.

@Maximilian-Winter
Contributor Author

@Green-Sky Thanks for checking that. I thought all day that I had made a crucial mistake when implementing self-extend. If I can help with anything, let me know!

@Green-Sky
Collaborator

Green-Sky commented Jan 28, 2024

I stand corrected; it looks like the "at least one successful shift" is not necessary, just more likely in practice.

Details
$ result/bin/llama-server -m models/zephyr-quiklang-3b-4k.Q8_0.gguf -ngl 99 -c 512
ggml_init_cublas: GGML_CUDA_FORCE_MMQ:   no
ggml_init_cublas: CUDA_USE_TENSOR_CORES: yes
ggml_init_cublas: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 2070, compute capability 7.5, VMM: yes
{"timestamp":1706462746,"level":"INFO","function":"main","line":2870,"message":"build info","build":0,"commit":"unknown"}
{"timestamp":1706462746,"level":"INFO","function":"main","line":2873,"message":"system info","n_threads":12,"n_threads_batch":-1,"total_threads":24,"system_info":"AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | "}

llama server listening at http://127.0.0.1:8080

{"timestamp":1706462746,"level":"INFO","function":"main","line":2977,"message":"HTTP server listening","port":"8080","hostname":"127.0.0.1"}
llama_model_loader: loaded meta data with 21 key-value pairs and 356 tensors from models/zephyr-quiklang-3b-4k.Q8_0.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = stablelm
llama_model_loader: - kv   1:                               general.name str              = source
llama_model_loader: - kv   2:                    stablelm.context_length u32              = 4096
llama_model_loader: - kv   3:                  stablelm.embedding_length u32              = 2560
llama_model_loader: - kv   4:                       stablelm.block_count u32              = 32
llama_model_loader: - kv   5:               stablelm.feed_forward_length u32              = 6912
llama_model_loader: - kv   6:              stablelm.rope.dimension_count u32              = 20
llama_model_loader: - kv   7:              stablelm.attention.head_count u32              = 32
llama_model_loader: - kv   8:             stablelm.use_parallel_residual bool             = true
llama_model_loader: - kv   9:      stablelm.attention.layer_norm_epsilon f32              = 0.000010
llama_model_loader: - kv  10:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  11:                      tokenizer.ggml.tokens arr[str,50304]   = ["<|endoftext|>", "<|padding|>", "!",...
llama_model_loader: - kv  12:                  tokenizer.ggml.token_type arr[i32,50304]   = [3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  13:                      tokenizer.ggml.merges arr[str,50009]   = ["Ġ Ġ", "Ġ t", "Ġ a", "h e", "i n...
llama_model_loader: - kv  14:                tokenizer.ggml.bos_token_id u32              = 0
llama_model_loader: - kv  15:                tokenizer.ggml.eos_token_id u32              = 0
llama_model_loader: - kv  16:            tokenizer.ggml.unknown_token_id u32              = 0
llama_model_loader: - kv  17:            tokenizer.ggml.padding_token_id u32              = 0
llama_model_loader: - kv  18:                    tokenizer.chat_template str              = {% for message in messages %}\n{% if m...
llama_model_loader: - kv  19:               general.quantization_version u32              = 2
llama_model_loader: - kv  20:                          general.file_type u32              = 7
llama_model_loader: - type  f32:  130 tensors
llama_model_loader: - type q8_0:  226 tensors
llm_load_vocab: mismatch in special tokens definition ( 31/50304 vs 52/50304 ).
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = stablelm
llm_load_print_meta: vocab type       = BPE
llm_load_print_meta: n_vocab          = 50304
llm_load_print_meta: n_merges         = 50009
llm_load_print_meta: n_ctx_train      = 4096
llm_load_print_meta: n_embd           = 2560
llm_load_print_meta: n_head           = 32
llm_load_print_meta: n_head_kv        = 32
llm_load_print_meta: n_layer          = 32
llm_load_print_meta: n_rot            = 20
llm_load_print_meta: n_embd_head_k    = 80
llm_load_print_meta: n_embd_head_v    = 80
llm_load_print_meta: n_gqa            = 1
llm_load_print_meta: n_embd_k_gqa     = 2560
llm_load_print_meta: n_embd_v_gqa     = 2560
llm_load_print_meta: f_norm_eps       = 1.0e-05
llm_load_print_meta: f_norm_rms_eps   = 0.0e+00
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: n_ff             = 6912
llm_load_print_meta: n_expert         = 0
llm_load_print_meta: n_expert_used    = 0
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 10000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_yarn_orig_ctx  = 4096
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: model type       = 3B
llm_load_print_meta: model ftype      = Q8_0
llm_load_print_meta: model params     = 2.80 B
llm_load_print_meta: model size       = 2.77 GiB (8.50 BPW)
llm_load_print_meta: general.name     = source
llm_load_print_meta: BOS token        = 0 '<|endoftext|>'
llm_load_print_meta: EOS token        = 0 '<|endoftext|>'
llm_load_print_meta: UNK token        = 0 '<|endoftext|>'
llm_load_print_meta: PAD token        = 0 '<|endoftext|>'
llm_load_print_meta: LF token         = 128 'Ä'
llm_load_tensors: ggml ctx size =    0.27 MiB
llm_load_tensors: offloading 32 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 33/33 layers to GPU
llm_load_tensors:        CPU buffer size =   130.49 MiB
llm_load_tensors:      CUDA0 buffer size =  2703.02 MiB
.............................................................................................
llama_new_context_with_model: n_ctx      = 512
llama_new_context_with_model: freq_base  = 10000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init:      CUDA0 KV buffer size =   160.00 MiB
llama_new_context_with_model: KV self size  =  160.00 MiB, K (f16):   80.00 MiB, V (f16):   80.00 MiB
llama_new_context_with_model:  CUDA_Host input buffer size   =     6.01 MiB
llama_new_context_with_model:      CUDA0 compute buffer size =   108.25 MiB
llama_new_context_with_model:  CUDA_Host compute buffer size =     5.00 MiB
llama_new_context_with_model: graph splits (measure): 3
Available slots:
 -> Slot 0 - max context: 512
{"timestamp":1706462747,"level":"INFO","function":"main","line":2998,"message":"model loaded"}
all slots are idle and system prompt is empty, clear the KV cache
{"timestamp":1706462782,"level":"INFO","function":"log_server_request","line":2811,"message":"request","remote_addr":"127.0.0.1","remote_port":34342,"status":200,"method":"GET","path":"/health","params":{}}
slot 0 is processing [task id: 0]
slot 0 : in cache: 0 tokens | to process: 492 tokens
slot 0 : kv cache rm - [0, end)

print_timings: prompt eval time =     232.47 ms /   492 tokens (    0.47 ms per token,  2116.38 tokens per second)
print_timings:        eval time =      55.13 ms /     3 runs   (   18.38 ms per token,    54.41 tokens per second)
print_timings:       total time =     287.60 ms
slot 0 released (495 tokens in cache)
{"timestamp":1706462782,"level":"INFO","function":"log_server_request","line":2811,"message":"request","remote_addr":"127.0.0.1","remote_port":34352,"status":200,"method":"POST","path":"/completion","params":{}}
{"timestamp":1706462802,"level":"INFO","function":"log_server_request","line":2811,"message":"request","remote_addr":"127.0.0.1","remote_port":55288,"status":200,"method":"GET","path":"/health","params":{}}
slot 0 is processing [task id: 1]
slot 0 : in cache: 492 tokens | to process: 0 tokens
slot 0 : kv cache rm - [492, end)
slot 0 : we have to evaluate at least 1 token to generate logits

print_timings: prompt eval time =      62.54 ms /     0 tokens (     inf ms per token,     0.00 tokens per second)
print_timings:        eval time =      55.80 ms /     3 runs   (   18.60 ms per token,    53.77 tokens per second)
print_timings:       total time =     118.34 ms
slot 0 released (495 tokens in cache)
{"timestamp":1706462802,"level":"INFO","function":"log_server_request","line":2811,"message":"request","remote_addr":"127.0.0.1","remote_port":55296,"status":200,"method":"POST","path":"/completion","params":{}}
{"timestamp":1706462829,"level":"INFO","function":"log_server_request","line":2811,"message":"request","remote_addr":"127.0.0.1","remote_port":45830,"status":200,"method":"GET","path":"/health","params":{}}
slot 0 is processing [task id: 2]
slot 0 : in cache: 492 tokens | to process: 0 tokens
slot 0 : kv cache rm - [492, end)
slot 0 : we have to evaluate at least 1 token to generate logits

print_timings: prompt eval time =      36.21 ms /     0 tokens (     inf ms per token,     0.00 tokens per second)
print_timings:        eval time =      55.48 ms /     3 runs   (   18.49 ms per token,    54.07 tokens per second)
print_timings:       total time =      91.69 ms
slot 0 released (495 tokens in cache)
{"timestamp":1706462829,"level":"INFO","function":"log_server_request","line":2811,"message":"request","remote_addr":"127.0.0.1","remote_port":45836,"status":200,"method":"POST","path":"/completion","params":{}}
{"timestamp":1706462857,"level":"INFO","function":"log_server_request","line":2811,"message":"request","remote_addr":"127.0.0.1","remote_port":47518,"status":200,"method":"GET","path":"/health","params":{}}
slot 0 is processing [task id: 3]
slot 0 : in cache: 492 tokens | to process: 0 tokens
slot 0 : kv cache rm - [492, end)
slot 0 : we have to evaluate at least 1 token to generate logits

print_timings: prompt eval time =      32.91 ms /     0 tokens (     inf ms per token,     0.00 tokens per second)
print_timings:        eval time =      55.60 ms /     3 runs   (   18.53 ms per token,    53.96 tokens per second)
print_timings:       total time =      88.51 ms
slot 0 released (495 tokens in cache)
{"timestamp":1706462857,"level":"INFO","function":"log_server_request","line":2811,"message":"request","remote_addr":"127.0.0.1","remote_port":47526,"status":200,"method":"POST","path":"/completion","params":{}}
{"timestamp":1706462896,"level":"INFO","function":"log_server_request","line":2811,"message":"request","remote_addr":"127.0.0.1","remote_port":49974,"status":200,"method":"GET","path":"/health","params":{}}
slot 0 is processing [task id: 4]
slot 0 : in cache: 492 tokens | to process: 0 tokens
slot 0 : kv cache rm - [492, end)
slot 0 : we have to evaluate at least 1 token to generate logits

print_timings: prompt eval time =      61.16 ms /     0 tokens (     inf ms per token,     0.00 tokens per second)
print_timings:        eval time =      54.84 ms /     3 runs   (   18.28 ms per token,    54.70 tokens per second)
print_timings:       total time =     116.00 ms
slot 0 released (495 tokens in cache)
{"timestamp":1706462896,"level":"INFO","function":"log_server_request","line":2811,"message":"request","remote_addr":"127.0.0.1","remote_port":49982,"status":200,"method":"POST","path":"/completion","params":{}}
{"timestamp":1706462927,"level":"INFO","function":"log_server_request","line":2811,"message":"request","remote_addr":"127.0.0.1","remote_port":59888,"status":200,"method":"GET","path":"/health","params":{}}
slot 0 is processing [task id: 5]
slot 0 : in cache: 492 tokens | to process: 0 tokens
slot 0 : kv cache rm - [492, end)
slot 0 : we have to evaluate at least 1 token to generate logits

print_timings: prompt eval time =      57.85 ms /     0 tokens (     inf ms per token,     0.00 tokens per second)
print_timings:        eval time =      69.49 ms /     4 runs   (   17.37 ms per token,    57.56 tokens per second)
print_timings:       total time =     127.35 ms
slot 0 released (496 tokens in cache)
{"timestamp":1706462927,"level":"INFO","function":"log_server_request","line":2811,"message":"request","remote_addr":"127.0.0.1","remote_port":59892,"status":200,"method":"POST","path":"/completion","params":{}}
{"timestamp":1706462927,"level":"INFO","function":"log_server_request","line":2811,"message":"request","remote_addr":"127.0.0.1","remote_port":54580,"status":200,"method":"GET","path":"/health","params":{}}
slot 0 is processing [task id: 6]
slot 0 : in cache: 495 tokens | to process: 1 tokens
slot 0 : kv cache rm - [495, end)
update_slots : failed to find free space in the KV cache, retrying with smaller n_batch = 256
update_slots : failed to find free space in the KV cache, retrying with smaller n_batch = 128
update_slots : failed to find free space in the KV cache, retrying with smaller n_batch = 64
update_slots : failed to find free space in the KV cache, retrying with smaller n_batch = 32
update_slots : failed to find free space in the KV cache, retrying with smaller n_batch = 16
update_slots : failed to find free space in the KV cache, retrying with smaller n_batch = 8
update_slots : failed to find free space in the KV cache, retrying with smaller n_batch = 4
update_slots : failed to find free space in the KV cache, retrying with smaller n_batch = 2
update_slots : failed to find free space in the KV cache, retrying with smaller n_batch = 1
update_slots : failed to decode the batch, n_batch = 1, ret = 1

@Green-Sky
Collaborator

Green-Sky commented Jan 28, 2024

OK, to no surprise, 48c857a (#5065) is the first commit with the newer, spammier version of the hang.

@K-Mistele
Contributor

@Maximilian-Winter It is actually a regression. When -c 0 is used, the context size (a legacy name; it should be called "KV cache size") is set equal to the training context size of the model:

cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;

When the KV cache becomes full, there is logic in server.cpp to apply a "context shift", which evicts the oldest tokens from the cache in order to free room for new tokens:

if (slot.is_processing() && slot.cache_tokens.size() >= (size_t) slot.n_ctx)
{
    // Shift context
    const int n_left    = slot.n_past - slot.params.n_keep - 1;
    const int n_discard = n_left / 2;

    LOG_TEE("slot %d: context shift - n_keep = %d, n_left = %d, n_discard = %d\n", slot.id, slot.params.n_keep, n_left, n_discard);

    llama_kv_cache_seq_rm   (ctx, slot.id, slot.params.n_keep + 1,             slot.params.n_keep + n_discard + 1);
    llama_kv_cache_seq_shift(ctx, slot.id, slot.params.n_keep + 1 + n_discard, slot.n_past, -n_discard);

    for (size_t i = slot.params.n_keep + 1 + n_discard; i < slot.cache_tokens.size(); i++)
    {
        slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
    }

    slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);

    slot.n_past -= n_discard;
    slot.truncated = true;

    LOG_VERBOSE("context shift", {
        { "n_ctx",  n_ctx },
        { "n_keep", params.n_keep },
        { "n_left", n_left },
    });
}

Somehow this functionality seems to have been affected, likely by the changes in this PR.

@Green-Sky If you could pinpoint the change that leads to the regression, that would be helpful. I'm looking at the code in this PR and cannot see how it would affect this, so maybe it was caused by earlier changes.

Okay, so this option does not do what I thought it did, then. Where can I learn more about how this works and what appropriate values for it would be? I have just been setting it to the model's training context size.

@Maximilian-Winter
Contributor Author

Maximilian-Winter commented Jan 28, 2024

@K-Mistele Allow me to cite ggerganov:
First, you set -c to the context that you want to achieve - let's say -c 8192.

Next, given that the original training context of the model is T (let's assume T = 2048), you want to set G >= 8192 / T, so in this case: --grp-attn-n 4 or --grp-attn-n 8.

The --grp-attn-w corresponds to W from the paper. I think the authors generally used 512, but I think you can go up to T/2 - so in this case --grp-attn-w 1024.

Additionally, G has to be multiple of W

Hope this helps.
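Applied to the numbers in that quote (training context T = 2048, target context 8192, hence G = 4 and W = 1024), the invocation would look roughly like this sketch; the model path is a placeholder:

# G = --grp-attn-n >= 8192 / 2048 = 4,  W = --grp-attn-w <= 2048 / 2 = 1024
./server -m ./models/model-2048-ctx.gguf -c 8192 --grp-attn-n 4 --grp-attn-w 1024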

@K-Mistele
Contributor


This is fantastic, thank you! Very helpful.

@x4080

x4080 commented Jan 29, 2024

Hi, I have an anecdote about how good this group attention extension is. Using a regular model with a large context (built-in context is 32k) on a text of maybe ~3000 tokens, it deteriorates when creating a summary, but with self-extension it can do it. So this is a real breakthrough for the open-source community.

@Green-Sky
Collaborator

I did more digging, and the first commit that should work with context shift appears to be working: 57dd55e

llama_print_timings:        load time =     707.10 ms
llama_print_timings:      sample time =      39.58 ms /     5 runs   (    7.92 ms per token,   126.32 tokens per second)
llama_print_timings: prompt eval time =     543.22 ms /   276 tokens (    1.97 ms per token,   508.08 tokens per second)
llama_print_timings:        eval time =      81.25 ms /     4 runs   (   20.31 ms per token,    49.23 tokens per second)
llama_print_timings:       total time =     675.54 ms
{"timestamp":1706529212,"level":"INFO","function":"log_server_request","line":1266,"message":"request","remote_addr":"127.0.0.1","remote_port":57614,"status":200,"method":"POST","path":"/completion","params":{}}
{"timestamp":1706529217,"level":"INFO","function":"nextToken","line":518,"message":"input truncated","n_ctx":512,"n_keep":0,"n_left":510}

llama_print_timings:        load time =     707.10 ms
llama_print_timings:      sample time =      11.79 ms /   338 runs   (    0.03 ms per token, 28680.53 tokens per second)
llama_print_timings: prompt eval time =       0.00 ms /     1 tokens (    0.00 ms per token,      inf tokens per second)
llama_print_timings:        eval time =    7378.22 ms /   338 runs   (   21.83 ms per token,    45.81 tokens per second)
llama_print_timings:       total time =    7477.77 ms
{"timestamp":1706529220,"level":"INFO","function":"log_server_request","line":1266,"message":"request","remote_addr":"127.0.0.1","remote_port":57618,"status":200,"method":"POST","path":"/completion","params":{}}
root ::= [I] [n] [s] [t] [r] [u] [c] [t] [B] [o] [t] | [p] [o] [t] [a] [t] [o]

llama_print_timings:        load time =     707.10 ms
llama_print_timings:      sample time =      40.12 ms /     5 runs   (    8.02 ms per token,   124.64 tokens per second)
llama_print_timings: prompt eval time =     635.53 ms /   363 tokens (    1.75 ms per token,   571.18 tokens per second)
llama_print_timings:        eval time =      86.38 ms /     4 runs   (   21.59 ms per token,    46.31 tokens per second)
llama_print_timings:       total time =     773.30 ms
{"timestamp":1706529235,"level":"INFO","function":"log_server_request","line":1266,"message":"request","remote_addr":"127.0.0.1","remote_port":38094,"status":200,"method":"POST","path":"/completion","params":{}}
{"timestamp":1706529239,"level":"INFO","function":"nextToken","line":518,"message":"input truncated","n_ctx":512,"n_keep":0,"n_left":510}

(modified log level of "input truncated")

I am not 100% sure the behavior is correct/as expected. It is kind of hard to tell.

@ggerganov
Owner

Could you give me some sample curl commands to reproduce the issue?

@ggerganov
Owner

@Maximilian-Winter @Green-Sky Please take a look at #5195 and see if the context shift issues are resolved and if self-extend still functions as expected

@K-Mistele
Contributor

K-Mistele commented Jan 31, 2024

What happens if I use a 4096-context model that has RoPE scaling built in, and I use -c to set a longer context in combination with the self-extend flags - are both RoPE scaling and self-extend applied, or just self-extend? TheBloke's quants on Hugging Face seem to imply that if you use -c when the model's RoPE params are included in the model file, you don't need to use the CLI flags for RoPE, since it will be applied automatically.

For example, using a command like ./server --model /Users/kyle/Documents/AI/models/optimal/nous-hermes-2-solar-10.7b.Q8_0.gguf --ctx-size 32768 --grp-attn-n 8 --grp-attn-w 1024 --n-gpu-layers -1 --port 8002 --host 0.0.0.0 with https://huggingface.co/TheBloke/Nous-Hermes-2-SOLAR-10.7B-GGUF (the model contains RoPE params).

The model card says,

change -c 4096 to the desired sequence length. For extended sequence models - eg 8K, 16K, 32K - the necessary RoPE scaling parameters are read from the GGUF file and set by llama.cpp automatically. Note that longer sequence lengths require much more resources, so you may need to reduce this value.

Does this mean both self-extend and RoPE scaling will be applied? What would the expected outcome be here? It seems like if I explicitly enable self-extend, then RoPE scaling should be disabled.

@cvhoang

cvhoang commented Feb 1, 2024

@Maximilian-Winter It looks like the logprobs returned by the server using self-extension are different (smaller). Is there a way to recalibrate the logprobs so that they are similar to the ones returned by the server without self-extension?

jordankanter pushed a commit to jordankanter/llama.cpp that referenced this pull request Feb 3, 2024
* Ported self extension to server example

* Update server.cpp

* Fixed prompt caching without self extend

* Update server.cpp

* Added description to server readme.

* Update server.cpp

* Update server.cpp

* Update server.cpp

* Update server.cpp

* Update README.md

* Changed descriptions

* server : formatting

* Update examples/server/server.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update examples/server/server.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update server.cpp

* Update server.cpp

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
hodlen pushed a commit to hodlen/llama.cpp that referenced this pull request Apr 1, 2024
@K-Mistele
Contributor

The author of the self-extend paper dropped what he considers to be a better version of the empirical formula in my Twitter replies; would it be possible to update the implementation?

https://x.com/serendip410/status/1782957763997401553

$(\frac{1}{2} \sim \frac{2}{3}) \times L > W + \frac{N-W}{G}$
