6
6
7
7
8
8
import argparse
9
- from typing import Optional
9
+
10
+ from typing import Optional , Union
10
11
11
12
import lm_eval
12
13
import torch
13
14
15
+ from executorch .examples .models .llama2 .tokenizer .tiktoken import Tokenizer as Tiktoken
16
+ from executorch .examples .models .llama2 .tokenizer .tokenizer import Tokenizer
17
+
14
18
from lm_eval .api .model import LM
15
19
from lm_eval .evaluator import evaluate
16
20
from lm_eval .models .huggingface import HFLM as eval_wrapper
17
21
from lm_eval .tasks import get_task_dict
18
- from sentencepiece import SentencePieceProcessor
22
+
19
23
from torch import nn
20
24
21
25
from .builder import LlamaEdgeManager
@@ -33,7 +37,7 @@ class GPTFastEvalWrapper(eval_wrapper):
33
37
def __init__ (
34
38
self ,
35
39
model : nn .Module ,
36
- tokenizer : SentencePieceProcessor ,
40
+ tokenizer : Union [ Tokenizer , Tiktoken ] ,
37
41
max_seq_length : Optional [int ] = None ,
38
42
):
39
43
super ().__init__ ()
@@ -46,7 +50,7 @@ def __init__(
46
50
47
51
@property
48
52
def eot_token_id (self ):
49
- return self ._tokenizer .eos_id ()
53
+ return self ._tokenizer .eos_id
50
54
51
55
@property
52
56
def max_length (self ):
@@ -65,7 +69,7 @@ def device(self):
65
69
return self ._device
66
70
67
71
def tok_encode (self , string : str , ** kwargs ):
68
- tokens = [ self ._tokenizer .bos_id ()] + self . _tokenizer . encode (string )
72
+ tokens = self ._tokenizer .encode (string , bos = True , eos = False )
69
73
encoded = torch .tensor (tokens , dtype = torch .int , device = self .device )
70
74
# encoded is a pytorch tensor, but some internal logic in the
71
75
# eval harness expects it to be a list instead
@@ -93,7 +97,7 @@ class ETEagerEvalWrapper(GPTFastEvalWrapper):
93
97
def __init__ (
94
98
self ,
95
99
model : str ,
96
- tokenizer : SentencePieceProcessor ,
100
+ tokenizer : Union [ Tokenizer , Tiktoken ] ,
97
101
max_seq_length : Optional [int ] = None ,
98
102
):
99
103
super ().__init__ (None , tokenizer , max_seq_length )
@@ -120,7 +124,7 @@ class ETRunnerEvalWrapper(GPTFastEvalWrapper):
120
124
def __init__ (
121
125
self ,
122
126
model : str ,
123
- tokenizer : SentencePieceProcessor ,
127
+ tokenizer : Union [ Tokenizer , Tiktoken ] ,
124
128
tokenizer_bin : str ,
125
129
max_seq_length : Optional [int ] = None ,
126
130
):
@@ -183,7 +187,11 @@ def gen_eval_wrapper(
183
187
Returns:
184
188
eval_wrapper (LM): A wrapper interface for the lm-evaluation-harness library.
185
189
"""
186
- tokenizer = SentencePieceProcessor (model_file = str (args .tokenizer_path ))
190
+ try :
191
+ tokenizer = Tokenizer (model_path = str (args .tokenizer_path ))
192
+ except Exception :
193
+ print ("Using Tiktokenizer" )
194
+ tokenizer = Tiktoken (model_path = str (args .tokenizer_path ))
187
195
188
196
# ExecuTorch Binary Evaluation
189
197
if (model := args .pte ) is not None :
0 commit comments