Skip to content

Commit f37ebcf

Browse files
committed
add tiktoken to eval
1 parent 9a4c1f6 commit f37ebcf

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

examples/models/llama2/eval_llama_lib.py

Lines changed: 16 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -6,16 +6,20 @@
66

77

88
import argparse
9-
from typing import Optional
9+
10+
from typing import Optional, Union
1011

1112
import lm_eval
1213
import torch
1314

15+
from executorch.examples.models.llama2.tokenizer.tiktoken import Tokenizer as Tiktoken
16+
from executorch.examples.models.llama2.tokenizer.tokenizer import Tokenizer
17+
1418
from lm_eval.api.model import LM
1519
from lm_eval.evaluator import evaluate
1620
from lm_eval.models.huggingface import HFLM as eval_wrapper
1721
from lm_eval.tasks import get_task_dict
18-
from sentencepiece import SentencePieceProcessor
22+
1923
from torch import nn
2024

2125
from .builder import LlamaEdgeManager
@@ -33,7 +37,7 @@ class GPTFastEvalWrapper(eval_wrapper):
3337
def __init__(
3438
self,
3539
model: nn.Module,
36-
tokenizer: SentencePieceProcessor,
40+
tokenizer: Union[Tokenizer, Tiktoken],
3741
max_seq_length: Optional[int] = None,
3842
):
3943
super().__init__()
@@ -46,7 +50,7 @@ def __init__(
4650

4751
@property
4852
def eot_token_id(self):
49-
return self._tokenizer.eos_id()
53+
return self._tokenizer.eos_id
5054

5155
@property
5256
def max_length(self):
@@ -65,7 +69,7 @@ def device(self):
6569
return self._device
6670

6771
def tok_encode(self, string: str, **kwargs):
68-
tokens = [self._tokenizer.bos_id()] + self._tokenizer.encode(string)
72+
tokens = self._tokenizer.encode(string, bos=True, eos=False)
6973
encoded = torch.tensor(tokens, dtype=torch.int, device=self.device)
7074
# encoded is a pytorch tensor, but some internal logic in the
7175
# eval harness expects it to be a list instead
@@ -93,7 +97,7 @@ class ETEagerEvalWrapper(GPTFastEvalWrapper):
9397
def __init__(
9498
self,
9599
model: str,
96-
tokenizer: SentencePieceProcessor,
100+
tokenizer: Union[Tokenizer, Tiktoken],
97101
max_seq_length: Optional[int] = None,
98102
):
99103
super().__init__(None, tokenizer, max_seq_length)
@@ -120,7 +124,7 @@ class ETRunnerEvalWrapper(GPTFastEvalWrapper):
120124
def __init__(
121125
self,
122126
model: str,
123-
tokenizer: SentencePieceProcessor,
127+
tokenizer: Union[Tokenizer, Tiktoken],
124128
tokenizer_bin: str,
125129
max_seq_length: Optional[int] = None,
126130
):
@@ -183,7 +187,11 @@ def gen_eval_wrapper(
183187
Returns:
184188
eval_wrapper (LM): A wrapper interface for the lm-evaluation-harness library.
185189
"""
186-
tokenizer = SentencePieceProcessor(model_file=str(args.tokenizer_path))
190+
try:
191+
tokenizer = Tokenizer(model_path=str(args.tokenizer_path))
192+
except Exception:
193+
print("Using Tiktokenizer")
194+
tokenizer = Tiktoken(model_path=str(args.tokenizer_path))
187195

188196
# ExecuTorch Binary Evaluation
189197
if (model := args.pte) is not None:

0 commit comments

Comments
 (0)