Commit 06f64ae

refactor(Llama): enhance error handling and cleanup in eval method
- Wrap `decode` in a try-except block to provide detailed error context (position, batch size) on failure (pattern sketched below).
- Capture and log the result of `memory_seq_rm` to assist in debugging KV cache issues.
- Add an early return for empty token lists.
- Refactor loop variables and state updates for better clarity.
- Remove dead code related to logits processing.
1 parent 6d31ab0 commit 06f64ae
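
For orientation, here is a minimal standalone sketch of the try-except-with-context pattern this commit introduces around `decode`. The `fake_decode` stub and the `eval_batches` harness are illustrative assumptions for this note only, not llama-cpp-python APIs.

# Minimal sketch of the "wrap decode and re-raise with context" pattern.
# fake_decode and eval_batches are illustrative stand-ins, not llama-cpp-python code.

def fake_decode(batch: list) -> None:
    # Stand-in for the low-level decode call; fails on an empty batch.
    if not batch:
        raise ValueError("empty batch")

def eval_batches(tokens: list, n_batch: int, start_pos: int = 0) -> int:
    # Early return for empty input, mirroring the commit's empty-token guard.
    if len(tokens) == 0:
        return start_pos
    current_pos = start_pos
    for i in range(0, len(tokens), n_batch):
        batch = tokens[i : min(len(tokens), i + n_batch)]
        try:
            fake_decode(batch)
        except Exception as e:
            # Re-raise with position and batch size so failures carry context.
            raise RuntimeError(
                f"Decode failed at pos {current_pos}. "
                f"Batch size: {len(batch)}. "
                f"Error: {e}."
            ) from e
        current_pos += len(batch)
    return current_pos

print(eval_batches(list(range(10)), n_batch=4))  # prints 10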

File tree: 1 file changed (+26, -18 lines)

llama_cpp/llama.py

Lines changed: 26 additions & 18 deletions
@@ -667,36 +667,44 @@ def eval(self, tokens: Sequence[int]):
         Args:
             tokens: The list of tokens to evaluate.
         """
-        self._ctx.memory_seq_rm(0, self.n_tokens, -1)
-        for i in range(0, len(tokens), self.n_batch):
-            batch = tokens[i : min(len(tokens), i + self.n_batch)]
+        if len(tokens) == 0:
+            return
+        n_eval = len(tokens)
+        current_pos = self.n_tokens
+
+        if self._ctx:
+            is_success = self._ctx.memory_seq_rm(0, current_pos, -1)
+
+        for i in range(0, n_eval, self.n_batch):
+            batch = tokens[i : min(n_eval, i + self.n_batch)]
             n_past = self.n_tokens
-            n_tokens = len(batch)
+            n_batch_tokens = len(batch)
             self._batch.set_batch(
                 batch=batch, n_past=n_past, logits_all=self._logits_all
             )
-            self._ctx.decode(self._batch)
+            try:
+                self._ctx.decode(self._batch)
+            except Exception as e:
+                raise RuntimeError(
+                    f"Decode Failed at Pos {current_pos}. "
+                    f"Batch size: {n_batch_tokens}. "
+                    f"Result of memory_seq_rm: {is_success}. "
+                    f"Error: {str(e)}."
+                ) from e
             # Save tokens
-            self.input_ids[n_past : n_past + n_tokens] = batch
+            self.input_ids[n_past : n_past + n_batch_tokens] = batch
             # Save logits
             if self._logits_all:
-                rows = n_tokens
+                rows = n_batch_tokens
                 cols = self._n_vocab
                 logits = np.ctypeslib.as_array(
                     self._ctx.get_logits(), shape=(rows * cols,)
                 )
-                self.scores[n_past : n_past + n_tokens, :].reshape(-1)[::] = logits
-            else:
-                # rows = 1
-                # cols = self._n_vocab
-                # logits = np.ctypeslib.as_array(
-                #     self._ctx.get_logits(), shape=(rows * cols,)
-                # )
-                # self.scores[n_past + n_tokens - 1, :].reshape(-1)[::] = logits
-                # NOTE: Now that sampling is done inside the sampler, logits are only needed for logprobs which requires logits_all
-                pass
+                self.scores[n_past : n_past + n_batch_tokens, :].reshape(-1)[::] = logits
+
             # Update n_tokens
-            self.n_tokens += n_tokens
+            current_pos += n_batch_tokens
+            self.n_tokens = current_pos
 
     def _init_sampler(
         self,
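
With this change, a caller can surface the added context when decoding fails. A hedged usage sketch (the model path and the failure itself are hypothetical):

from llama_cpp import Llama  # requires llama-cpp-python installed

llm = Llama(model_path="./models/model.gguf")  # hypothetical path
try:
    llm.eval(llm.tokenize(b"Hello, world"))
except RuntimeError as err:
    # The message now includes position, batch size, and the memory_seq_rm result.
    print(f"eval failed: {err}")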
