Skip to content

Commit 0acb224

Browse files
committed
feat: add HybridCheckpointCache detection support for recurrent/hybrid/SWA models
- Introduce ctx_checkpoints parameter (default 16)
- Detect recurrent / hybrid / n_swa > 0 models in __init__
- Automatically use HybridCheckpointCache when a hybrid architecture is detected
- Properly close and clear HybridCheckpointCache in __del__

Signed-off-by: JamePeng <jame_peng@sina.com>
1 parent 29b9522 commit 0acb224

File tree

1 file changed

+35
-4
lines changed

1 file changed

+35
-4
lines changed

llama_cpp/llama.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,11 @@
3434
from .llama_grammar import LlamaGrammar
3535
from .llama_cache import (
3636
BaseLlamaCache,
37-
LlamaCache, # type: ignore
38-
LlamaDiskCache, # type: ignore
39-
LlamaRAMCache, # type: ignore
40-
LlamaTrieCache, # type: ignore
37+
LlamaCache, # type: ignore
38+
LlamaDiskCache, # type: ignore
39+
LlamaRAMCache, # type: ignore
40+
LlamaTrieCache, # type: ignore
41+
HybridCheckpointCache, # type: ignore
4142
)
4243
from .llama_tokenizer import BaseLlamaTokenizer, LlamaTokenizer
4344
import llama_cpp.llama_cpp as llama_cpp
@@ -109,6 +110,8 @@ def __init__(
109110
op_offload: Optional[bool] = None,
110111
swa_full: Optional[bool] = None,
111112
kv_unified: Optional[bool] = None,
113+
# HybridCheckpointCache Params
114+
ctx_checkpoints: int = 16,
112115
# Sampling Params
113116
last_n_tokens_size: int = 64,
114117
# LoRA Params
@@ -197,6 +200,7 @@ def __init__(
197200
op_offload: whether to offload host tensor operations to device
198201
swa_full: whether to use full-size SWA cache
199202
kv_unified: use single unified KV buffer for the KV cache of all sequences
203+
ctx_checkpoints: max number of context checkpoints to create per slot (default: 16)[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)
200204
last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
201205
lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
202206
lora_path: Path to a LoRA file to apply to the model.
@@ -466,6 +470,26 @@ def __init__(
466470
)
467471
)
468472

473+
# Hybrid architecture detection
474+
_is_recurrent = self._model.is_recurrent()
475+
_is_hybrid = self._model.is_hybrid()
476+
_n_swa = self._model.n_swa()
477+
# checkpoints are created only if:
478+
# - the model uses SWA and we are not using `swa_full`
479+
# - the model architecture is marked as recurrent or hybrid
480+
self.is_hybrid = _is_recurrent or _is_hybrid or (_n_swa > 0 and not swa_full)
481+
482+
if self.is_hybrid:
483+
if self.verbose:
484+
print(f"Llama.__init__: Hybrid/Recurrent model detected."
485+
f"(is_recurrent: {_is_recurrent}, is_hybrid: {_is_hybrid}, n_swa: {_n_swa}), swa_full: {swa_full}. "
486+
f" Enabling HybridCheckpointCache(ctx_checkpoints={ctx_checkpoints}).",
487+
file=sys.stderr)
488+
self.ctx_checkpoints = ctx_checkpoints
489+
self._hybrid_cache_mgr = HybridCheckpointCache(self._ctx.ctx, max_checkpoints=self.ctx_checkpoints, verbose=self.verbose)
490+
else:
491+
self._hybrid_cache_mgr = None
492+
469493
self._batch = self._stack.enter_context(
470494
contextlib.closing(
471495
internals.LlamaBatch(
@@ -634,13 +658,18 @@ def close(self) -> None:
634658
self._candidates.close()
635659
self._candidates = None
636660

661+
if getattr(self, "_hybrid_cache_mgr", None) is not None and hasattr(self._hybrid_cache_mgr, "close"):
662+
self._hybrid_cache_mgr.close()
663+
self._hybrid_cache_mgr = None
664+
637665
if hasattr(self, "chat_handler") and hasattr(self.chat_handler, "close"):
638666
self.chat_handler.close()
639667

640668
self.model_params =None
641669
self.context_params = None
642670
self.chat_handler = None
643671
self.input_ids = None
672+
self.metadata = None
644673
self.scores = None
645674
self.tokenizer_ = None
646675

@@ -1099,6 +1128,8 @@ def adapter(token_data_array: llama_cpp.llama_token_data_array):
10991128
# No prefix matched. Completely clear the KV cache to prevent context poisoning.
11001129
self.n_tokens = 0
11011130
self._ctx.memory_clear(True)
1131+
if self.is_hybrid and self._hybrid_cache_mgr is not None:
1132+
self._hybrid_cache_mgr.clear()
11021133

11031134
# Reset the model state
11041135
if reset:

0 commit comments

Comments
 (0)