Commit c03ce22

fix(hybrid): implement N-1 checkpointing to support 1-token rollbacks
Forces an N-1 state snapshot during prompt prefilling for hybrid models. This ensures the engine can safely perform a 1-token rollback to refresh logits upon 100% cache matches (e.g., changing seeds on identical prompts), preventing RNN state desyncs and empty outputs.
1 parent b342f70 commit c03ce22

File tree

1 file changed (+21, −1 lines)


llama_cpp/llama.py

Lines changed: 21 additions & 1 deletion
@@ -1306,7 +1306,27 @@ def adapter(token_data_array: llama_cpp.llama_token_data_array):
         try:
             while True:
                 if len(tokens) > 0:
-                    self.eval(tokens)
+                    # For hybrid models processing a prompt (len > 1), force an N-1 checkpoint
+                    # to safely allow 1-token rollbacks (e.g., for seed changes on 100% prompt matches).
+                    if self.is_hybrid and self._hybrid_cache_mgr is not None and len(tokens) > 1:
+                        body_tokens = tokens[:-1]
+                        last_token = [tokens[-1]]
+
+                        # 1. Evaluate up to N-1
+                        self.eval(body_tokens)
+
+                        # 2. Save the N-1 state snapshot
+                        current_history = self._input_ids[:self.n_tokens].tolist()
+                        self._hybrid_cache_mgr.save_checkpoint(
+                            current_pos=self.n_tokens,
+                            tokens=current_history,
+                            seq_id=0
+                        )
+                        # 3. Evaluate the final token to refresh logits
+                        self.eval(last_token)
+                    else:
+                        # Standard evaluation or single-token generation step
+                        self.eval(tokens)
             while sample_idx < self.n_tokens:
                 token = self._sampling_ctx.sample(self._ctx, idx=-1)
                 self._sampling_ctx.accept(token, False if grammar is None else True)
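The idea behind the patch can be sketched standalone. Unlike a transformer KV cache, recurrent (RNN/hybrid) state cannot be truncated to an arbitrary position after the fact, so the engine snapshots state *before* the last prompt token; on a 100% prompt match it can then restore that snapshot and re-evaluate only the final token to refresh logits. The `Checkpoint` and `HybridCacheManager` classes below are a minimal, hypothetical illustration (only `save_checkpoint` mirrors the patch; `rollback_one` and the serialized-state handling are assumptions, not the library's API):

```python
from dataclasses import dataclass, field

@dataclass
class Checkpoint:
    pos: int            # number of tokens evaluated when the snapshot was taken (N-1)
    tokens: list[int]   # token history the snapshot covers
    state: bytes        # opaque serialized recurrent state

@dataclass
class HybridCacheManager:
    checkpoints: dict[int, Checkpoint] = field(default_factory=dict)

    def save_checkpoint(self, current_pos, tokens, state, seq_id=0):
        # One snapshot per sequence, taken during prefill at position N-1.
        self.checkpoints[seq_id] = Checkpoint(current_pos, list(tokens), state)

    def rollback_one(self, prompt, seq_id=0):
        """Return the checkpoint iff it covers exactly prompt[:-1], else None."""
        cp = self.checkpoints.get(seq_id)
        if cp is not None and cp.tokens == list(prompt[:-1]):
            return cp
        return None

mgr = HybridCacheManager()
prompt = [101, 7, 42, 9]

# Prefill: the engine would eval(prompt[:-1]), snapshot, then eval([prompt[-1]]).
mgr.save_checkpoint(current_pos=len(prompt) - 1, tokens=prompt[:-1], state=b"rnn-state@N-1")

# Same prompt arrives again (100% cache match, e.g. a new seed): roll back one token.
cp = mgr.rollback_one(prompt)
assert cp is not None and cp.pos == 3
# The engine would now restore cp.state and eval([prompt[-1]]) to get fresh logits.
```

Without the N-1 snapshot, a full-prompt match leaves nothing to evaluate, so the sampler sees stale (or no) logits, which is the RNN-state desync and empty-output failure the commit fixes.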
