
Add support for new KV Cache Offloading API #995

Closed
abetlen opened this issue Dec 11, 2023 · 2 comments

Comments

abetlen (Owner) commented Dec 11, 2023

Source ggerganov/llama.cpp#4309
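
For reference, here is what using this could look like from the Python side once the param is plumbed through (a minimal sketch; the model path is hypothetical, and offload_kqv is assumed to mirror the upstream llama_context_params field of the same name):

    from llama_cpp import Llama

    # offload_kqv=False keeps the KV cache in host RAM even when layers
    # are offloaded to the GPU; True moves the cache to the GPU as well.
    llm = Llama(
        model_path="./models/llama-2-70b.Q4_K_M.gguf",  # hypothetical path
        n_gpu_layers=-1,
        offload_kqv=True,
    )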

Ph0rk0z commented Dec 14, 2023

I forced on the context param to offload the kv_cache. The logs print that it's using the F16 cache, that 81 layers are offloaded, and the memory sizes for K and V, so I assume the KV cache "should" be offloaded. Speeds on 70b are now half and prompt processing is less than 1/4 with no context. Womp womp.

Ph0rk0z commented Dec 14, 2023

Ok, I got it working once I moved the KV fields (type_k, type_v, offload_kqv) to their proper place in the struct:

    import ctypes
    from ctypes import c_bool, c_float, c_int, c_int8, c_uint32

    # Mirror of llama.cpp's llama_context_params: field order and types
    # must match the C header exactly, or every later field is misread.
    class llama_context_params(ctypes.Structure):
        _fields_ = [
            ("seed", c_uint32),
            ("n_ctx", c_uint32),
            ("n_batch", c_uint32),
            ("n_threads", c_uint32),
            ("n_threads_batch", c_uint32),
            ("rope_scaling_type", c_int8),
            ("rope_freq_base", c_float),
            ("rope_freq_scale", c_float),
            ("yarn_ext_factor", c_float),
            ("yarn_attn_factor", c_float),
            ("yarn_beta_fast", c_float),
            ("yarn_beta_slow", c_float),
            ("yarn_orig_ctx", c_uint32),
            ("type_k", c_int),
            ("type_v", c_int),
            ("mul_mat_q", c_bool),
            ("logits_all", c_bool),
            ("embedding", c_bool),
            ("offload_kqv", c_bool),
        ]
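
Field position matters because ctypes derives each field's byte offset from declaration order. A quick way to sanity-check the layout against the C header (a sketch using the class above):

    # Print the computed layout; each offset must match llama.cpp's
    # struct llama_context_params or the C side reads the wrong bytes.
    print(ctypes.sizeof(llama_context_params))
    for name, _ctype in llama_context_params._fields_:
        print(name, getattr(llama_context_params, name).offset)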

brandonrobertz added a commit to brandonrobertz/llama-cpp-python that referenced this issue Dec 17, 2023
This addresses two issues:

 - abetlen#995 which just requests to add the KV cache offloading param
 - abetlen#1006 a NULL ptr exception when using the embeddings (introduced by
   leaving f16_kv in the fields struct)
brandonrobertz added a commit to brandonrobertz/llama-cpp-python that referenced this issue Dec 17, 2023
F16_KV appears to have been removed here: ggerganov/llama.cpp@af99c6f

This addresses two issues:

 - abetlen#995 which just requests to add the KV cache offloading param
 - abetlen#1006 a NULL ptr exception when using the embeddings (introduced by
   leaving f16_kv in the fields struct)
abetlen pushed a commit that referenced this issue Dec 18, 2023
F16_KV appears to have been removed here: ggerganov/llama.cpp@af99c6f

This addresses two issues:

 - #995 which just requests to add the KV cache offloading param
 - #1006 a NULL ptr exception when using the embeddings (introduced by
   leaving f16_kv in the fields struct)
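
To illustrate the f16_kv bug those commits fix: leaving a removed field in the ctypes mirror shifts every later field's byte offset, so values like offload_kqv land at the wrong bytes and the C side reads garbage. A minimal sketch with stand-in fields (not the full struct):

    import ctypes
    from ctypes import c_bool, c_int

    class Correct(ctypes.Structure):
        _fields_ = [("type_k", c_int), ("offload_kqv", c_bool)]

    class Stale(ctypes.Structure):
        # f16_kv was removed from the C header but left here,
        # pushing offload_kqv from offset 4 to offset 8.
        _fields_ = [("f16_kv", c_bool), ("type_k", c_int), ("offload_kqv", c_bool)]

    print(Correct.offload_kqv.offset)  # 4
    print(Stale.offload_kqv.offset)    # 8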
abetlen closed this as completed Dec 21, 2023