Skip to content

Commit 77146b9

Browse files
committed
Replace token function with vocab function
1 parent 7a626ba commit 77146b9

File tree

2 files changed

+67
-34
lines changed

2 files changed

+67
-34
lines changed

llama_cpp/_internals.py

Lines changed: 42 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -112,48 +112,69 @@ def get_tensor(self, name: str) -> ctypes.c_void_p:
112112
# Vocab
113113

114114
def token_get_text(self, token: int) -> str:
115-
return llama_cpp.llama_token_get_text(self.vocab, token).decode("utf-8")
115+
return llama_cpp.llama_vocab_get_text(self.vocab, token).decode("utf-8")
116116

117117
def token_get_score(self, token: int) -> float:
118-
return llama_cpp.llama_token_get_score(self.vocab, token)
118+
return llama_cpp.llama_vocab_get_score(self.vocab, token)
119119

120120
def token_get_attr(self, token: int) -> int:
121-
return llama_cpp.llama_token_get_attr(self.vocab, token)
121+
return llama_cpp.llama_vocab_get_attr(self.vocab, token)
122+
123+
def token_is_eog(self, token: int) -> bool:
124+
return llama_cpp.llama_vocab_is_eog(self.vocab, token)
125+
126+
def token_is_control(self, token: int) -> bool:
127+
return llama_cpp.llama_vocab_is_control(self.vocab, token)
122128

123129
# Special tokens
124130

125131
def token_bos(self) -> int:
126-
return llama_cpp.llama_token_bos(self.vocab)
132+
return llama_cpp.llama_vocab_bos(self.vocab)
127133

128134
def token_eos(self) -> int:
129-
return llama_cpp.llama_token_eos(self.vocab)
135+
return llama_cpp.llama_vocab_eos(self.vocab)
130136

131-
def token_cls(self) -> int:
132-
return llama_cpp.llama_token_cls(self.vocab)
137+
def token_eot(self) -> int:
138+
return llama_cpp.llama_vocab_eot(self.vocab)
133139

134140
def token_sep(self) -> int:
135-
return llama_cpp.llama_token_sep(self.vocab)
141+
return llama_cpp.llama_vocab_sep(self.vocab)
136142

137143
def token_nl(self) -> int:
138-
return llama_cpp.llama_token_nl(self.vocab)
144+
return llama_cpp.llama_vocab_nl(self.vocab)
139145

140-
def token_prefix(self) -> int:
141-
raise NotImplementedError("token_prefix is not implemented in llama.cpp")
146+
def token_pad(self) -> int:
147+
return llama_cpp.llama_vocab_pad(self.vocab)
142148

143-
def token_middle(self) -> int:
144-
raise NotImplementedError("token_middle is not implemented in llama.cpp")
149+
def token_cls(self) -> int:
150+
return llama_cpp.llama_vocab_cls(self.vocab)
145151

146-
def token_suffix(self) -> int:
147-
raise NotImplementedError("token_suffix is not implemented in llama.cpp")
152+
def token_fim_pre(self) -> int:
153+
return llama_cpp.llama_vocab_fim_pre(self.vocab)
148154

149-
def token_eot(self) -> int:
150-
return llama_cpp.llama_token_eot(self.vocab)
155+
def token_fim_suf(self) -> int:
156+
return llama_cpp.llama_vocab_fim_suf(self.vocab)
157+
158+
def token_fim_mid(self) -> int:
159+
return llama_cpp.llama_vocab_fim_mid(self.vocab)
160+
161+
def token_fim_pad(self) -> int:
162+
return llama_cpp.llama_vocab_fim_pad(self.vocab)
163+
164+
def token_fim_rep(self) -> int:
165+
return llama_cpp.llama_vocab_fim_rep(self.vocab)
166+
167+
def token_fim_sep(self) -> int:
168+
return llama_cpp.llama_vocab_fim_sep(self.vocab)
169+
170+
def get_add_bos(self) -> bool:
171+
return llama_cpp.llama_vocab_get_add_bos(self.vocab)
151172

152-
def add_bos_token(self) -> bool:
153-
return llama_cpp.llama_add_bos_token(self.vocab)
173+
def get_add_eos(self) -> bool:
174+
return llama_cpp.llama_vocab_get_add_eos(self.vocab)
154175

155-
def add_eos_token(self) -> bool:
156-
return llama_cpp.llama_add_eos_token(self.vocab)
176+
def get_add_sep(self) -> bool:
177+
return llama_cpp.llama_vocab_get_add_sep(self.vocab)
157178

158179
# Tokenization
159180

llama_cpp/llama.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1189,29 +1189,29 @@ def _create_completion(
11891189

11901190
completion_id: str = f"cmpl-{str(uuid.uuid4())}"
11911191
created: int = int(time.time())
1192-
bos_token_id: int = self.token_bos()
1193-
cls_token_id: int = self._model.token_cls()
1192+
bos_token_id: int = self._model.token_bos()
1193+
eos_token_id: int = self._model.token_eos()
11941194
sep_token_id: int = self._model.token_sep()
1195-
prefix_token_id: int = 0 # self._model.token_prefix() # TODO: Fix
1196-
middle_token_id: int = 0 # self._model.token_middle() # TODO: Fix
1197-
suffix_token_id: int = 0 # self._model.token_suffix() # TODO: Fix
1195+
prefix_token_id: int = self._model.token_fim_pre()
1196+
middle_token_id: int = self._model.token_fim_mid()
1197+
suffix_token_id: int = self._model.token_fim_suf()
11981198
add_space_prefix: bool = (
11991199
self.metadata.get("tokenizer.ggml.add_space_prefix", "true") == "true"
12001200
)
1201-
bos_tokens: List[int] = [cls_token_id if cls_token_id != -1 else bos_token_id]
1201+
bos_tokens: List[int] = [bos_token_id]
12021202
eos_tokens: List[int] = [
1203-
sep_token_id if sep_token_id != -1 else self.token_eos()
1203+
sep_token_id if self._model.get_add_sep() else eos_token_id
12041204
]
12051205

12061206
if (
12071207
(isinstance(prompt, list) and suffix is None)
1208-
or not self._model.add_bos_token()
1208+
or not self._model.get_add_bos()
12091209
or bos_tokens[:1] == [-1]
12101210
):
12111211
bos_tokens = []
12121212

12131213
if (isinstance(prompt, list) and suffix is None) or (
1214-
not self._model.add_eos_token() and sep_token_id == -1
1214+
not self._model.get_add_eos() and not self._model.get_add_sep()
12151215
):
12161216
eos_tokens = []
12171217

@@ -2294,18 +2294,30 @@ def tokenizer(self) -> LlamaTokenizer:
22942294
"""Return the llama tokenizer for this model."""
22952295
return LlamaTokenizer(self)
22962296

2297+
def token_bos(self) -> int:
2298+
"""Return the beginning-of-sequence token."""
2299+
return self._model.token_bos()
2300+
22972301
def token_eos(self) -> int:
22982302
"""Return the end-of-sequence token."""
22992303
return self._model.token_eos()
23002304

2301-
def token_bos(self) -> int:
2302-
"""Return the beginning-of-sequence token."""
2303-
return self._model.token_bos()
2305+
def token_eot(self) -> int:
2306+
"""Return the end-of-turn token."""
2307+
return self._model.token_eot()
2308+
2309+
def token_sep(self) -> int:
2310+
"""Return the sentence-separator token."""
2311+
return self._model.token_sep()
23042312

23052313
def token_nl(self) -> int:
2306-
"""Return the newline token."""
2314+
"""Return the next-line token."""
23072315
return self._model.token_nl()
23082316

2317+
def token_pad(self) -> int:
2318+
"""Return the padding token."""
2319+
return self._model.token_pad()
2320+
23092321
def pooling_type(self) -> str:
23102322
"""Return the pooling type."""
23112323
return self._ctx.pooling_type()

0 commit comments

Comments
 (0)