|
34 | 34 | from .llama_grammar import LlamaGrammar |
35 | 35 | from .llama_cache import ( |
36 | 36 | BaseLlamaCache, |
37 | | - LlamaCache, # type: ignore |
38 | | - LlamaDiskCache, # type: ignore |
39 | | - LlamaRAMCache, # type: ignore |
40 | | - LlamaTrieCache, # type: ignore |
| 37 | + LlamaCache, # type: ignore |
| 38 | + LlamaDiskCache, # type: ignore |
| 39 | + LlamaRAMCache, # type: ignore |
| 40 | + LlamaTrieCache, # type: ignore |
| 41 | + HybridCheckpointCache, # type: ignore |
41 | 42 | ) |
42 | 43 | from .llama_tokenizer import BaseLlamaTokenizer, LlamaTokenizer |
43 | 44 | import llama_cpp.llama_cpp as llama_cpp |
@@ -109,6 +110,8 @@ def __init__( |
109 | 110 | op_offload: Optional[bool] = None, |
110 | 111 | swa_full: Optional[bool] = None, |
111 | 112 | kv_unified: Optional[bool] = None, |
| 113 | + # HybridCheckpointCache Params |
| 114 | + ctx_checkpoints: int = 16, |
112 | 115 | # Sampling Params |
113 | 116 | last_n_tokens_size: int = 64, |
114 | 117 | # LoRA Params |
@@ -197,6 +200,7 @@ def __init__( |
197 | 200 | op_offload: whether to offload host tensor operations to device |
198 | 201 | swa_full: whether to use full-size SWA cache |
199 | 202 | kv_unified: use single unified KV buffer for the KV cache of all sequences |
 | 203 | + ctx_checkpoints: max number of context checkpoints to create per slot (default: 16) [(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)
200 | 204 | last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque. |
201 | 205 | lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model. |
202 | 206 | lora_path: Path to a LoRA file to apply to the model. |
@@ -466,6 +470,26 @@ def __init__( |
466 | 470 | ) |
467 | 471 | ) |
468 | 472 |
|
| 473 | + # Hybrid architecture detection |
| 474 | + _is_recurrent = self._model.is_recurrent() |
| 475 | + _is_hybrid = self._model.is_hybrid() |
| 476 | + _n_swa = self._model.n_swa() |
 | 477 | + # checkpoints are created only when any of the following holds:
 | 478 | + # - the model uses SWA and we are not using `swa_full`
 | 479 | + # - the model architecture is marked as recurrent or hybrid
| 480 | + self.is_hybrid = _is_recurrent or _is_hybrid or (_n_swa > 0 and not swa_full) |
| 481 | + |
| 482 | + if self.is_hybrid: |
| 483 | + if self.verbose: |
 | 484 | + print(f"Llama.__init__: Hybrid/Recurrent model detected "
 | 485 | + f"(is_recurrent: {_is_recurrent}, is_hybrid: {_is_hybrid}, n_swa: {_n_swa}, swa_full: {swa_full}). "
 | 486 | + f"Enabling HybridCheckpointCache(ctx_checkpoints={ctx_checkpoints}).",
 | 487 | + file=sys.stderr)
| 488 | + self.ctx_checkpoints = ctx_checkpoints |
| 489 | + self._hybrid_cache_mgr = HybridCheckpointCache(self._ctx.ctx, max_checkpoints=self.ctx_checkpoints, verbose=self.verbose) |
| 490 | + else: |
| 491 | + self._hybrid_cache_mgr = None |
| 492 | + |
469 | 493 | self._batch = self._stack.enter_context( |
470 | 494 | contextlib.closing( |
471 | 495 | internals.LlamaBatch( |
@@ -634,13 +658,18 @@ def close(self) -> None: |
634 | 658 | self._candidates.close() |
635 | 659 | self._candidates = None |
636 | 660 |
|
| 661 | + if getattr(self, "_hybrid_cache_mgr", None) is not None and hasattr(self._hybrid_cache_mgr, "close"): |
| 662 | + self._hybrid_cache_mgr.close() |
| 663 | + self._hybrid_cache_mgr = None |
| 664 | + |
637 | 665 | if hasattr(self, "chat_handler") and hasattr(self.chat_handler, "close"): |
638 | 666 | self.chat_handler.close() |
639 | 667 |
|
640 | 668 | self.model_params =None |
641 | 669 | self.context_params = None |
642 | 670 | self.chat_handler = None |
643 | 671 | self.input_ids = None |
| 672 | + self.metadata = None |
644 | 673 | self.scores = None |
645 | 674 | self.tokenizer_ = None |
646 | 675 |
|
@@ -1099,6 +1128,8 @@ def adapter(token_data_array: llama_cpp.llama_token_data_array): |
1099 | 1128 | # No prefix matched. Completely clear the KV cache to prevent context poisoning. |
1100 | 1129 | self.n_tokens = 0 |
1101 | 1130 | self._ctx.memory_clear(True) |
| 1131 | + if self.is_hybrid and self._hybrid_cache_mgr is not None: |
| 1132 | + self._hybrid_cache_mgr.clear() |
1102 | 1133 |
|
1103 | 1134 | # Reset the model state |
1104 | 1135 | if reset: |
|
0 commit comments