Skip to content

Commit 246e3da

Browse files
committed
Fix some crash bug!
1 parent 92178db commit 246e3da

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

llama_cpp/_internals.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -522,6 +522,7 @@ def set_batch(self, batch: Sequence[int], n_past: int):
522522
self.batch.pos[i] = n_past + i
523523
self.batch.seq_id[i][0] = 0
524524
self.batch.n_seq_id[i] = 1
525+
self.batch.logits[n_tokens - 1] = True
525526

526527
def add_sequence(self, batch: Sequence[int], seq_id: int):
527528
n_tokens = len(batch)
@@ -533,6 +534,7 @@ def add_sequence(self, batch: Sequence[int], seq_id: int):
533534
self.batch.pos[j] = i
534535
self.batch.seq_id[j][0] = seq_id
535536
self.batch.n_seq_id[j] = 1
537+
self.batch.logits[n_tokens - 1] = True
536538

537539

538540
class LlamaTokenDataArray:
@@ -983,7 +985,7 @@ def get_seed(self) -> int:
983985
assert self.sampler is not None
984986
return llama_cpp.llama_sampler_get_seed(self.sampler)
985987

986-
def sample(self, ctx: LlamaContext, idx: int) -> int:
988+
def sample(self, ctx: LlamaContext, idx: ctypes.c_int32) -> ctypes.c_int32:
987989
assert self.sampler is not None
988990
assert ctx.ctx is not None
989991
return llama_cpp.llama_sampler_sample(self.sampler, ctx.ctx, idx)

llama_cpp/llama_cpp.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -2681,9 +2681,9 @@ def llama_batch_get_one(
26812681
"llama_batch_init", [ctypes.c_int32, ctypes.c_int32, ctypes.c_int32], llama_batch
26822682
)
26832683
def llama_batch_init(
2684-
n_tokens: Union[ctypes.c_int32, int],
2685-
embd: Union[ctypes.c_int32, int],
2686-
n_seq_max: Union[ctypes.c_int32, int],
2684+
n_tokens: ctypes.c_int32,
2685+
embd: ctypes.c_int32,
2686+
n_seq_max: ctypes.c_int32,
26872687
/,
26882688
) -> llama_batch:
26892689
"""Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
@@ -2872,10 +2872,10 @@ def llama_get_logits(ctx: llama_context_p, /) -> CtypesArray[ctypes.c_float]:
28722872
ctypes.POINTER(ctypes.c_float),
28732873
)
28742874
def llama_get_logits_ith(
2875-
ctx: llama_context_p, i: Union[ctypes.c_int32, int], /
2876-
) -> CtypesArray[ctypes.c_float]:
2875+
ctx: llama_context_p, i: ctypes.c_int32, /
2876+
) -> ctypes.POINTER(ctypes.c_float):
28772877
"""Logits for the ith token. Equivalent to:
2878-
llama_get_logits(ctx) + i*n_vocab"""
2878+
llama_get_logits(ctx) + ctx->output_ids[i]*n_vocab"""
28792879
...
28802880

28812881

0 commit comments

Comments
 (0)