# topical_tokenizers.py
import spacy
from transformers import GPT2Tokenizer, GPT2TokenizerFast
from collections import defaultdict
import pickle
from os import path
from gensim.parsing.preprocessing import STOPWORDS
from gensim.utils import simple_preprocess


class Tokenizer:
    """Common interface shared by the tokenizers below."""

    def __init__(self):
        pass

    def tokenize(self, text):
        raise NotImplementedError

    def encode(self, text):
        raise NotImplementedError

    def save_dict(self):
        raise NotImplementedError


class SimpleTokenizer(Tokenizer):
    def __init__(self, dict_dir):
        super().__init__()
        self.dict_dir = dict_dir  # kept for interface parity; nothing is persisted here

    def tokenize(self, text):
        # simple_preprocess lowercases and strips punctuation; then drop gensim stopwords
        return [token for token in simple_preprocess(text) if token not in STOPWORDS]
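
# Usage sketch (illustrative only, not part of the original script): SimpleTokenizer
# never touches dict_dir and simply returns lowercased, stop-word-filtered tokens.
# The sample sentence is made up for demonstration:
#   simple = SimpleTokenizer(dict_dir=".")
#   simple.tokenize("The quick brown fox jumps over the lazy dog")
#   # -> roughly ['quick', 'brown', 'fox', 'jumps', 'lazy', 'dog']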


class SpacyTokenizer(Tokenizer):
    """Tokenizes with spaCy and incrementally builds an id -> token dictionary."""

    def __init__(self, dict_dir, preprocess=False):
        super().__init__()
        self.nlp = spacy.load("en_core_web_sm")
        self.dict_dir = dict_dir
        self.preprocess = preprocess
        self._dictionary = defaultdict()  # behaves like a plain dict: id -> token
        self.i = 0

    def _dictionary_exist(self):
        return path.isfile(path.join(self.dict_dir, "dict.p"))

    @property
    def dictionary(self):
        # Prefer the persisted dictionary if one exists; otherwise return the
        # in-memory mapping built so far.
        if self._dictionary_exist():
            with open(path.join(self.dict_dir, "dict.p"), 'rb') as f:
                self._dictionary = pickle.load(f)
        return self._dictionary

    def tokenize(self, text):
        if self.preprocess:
            docs_tokens = [token.text.lower() for token in self.nlp(text) if not token.is_stop]
        else:
            docs_tokens = [token.text for token in self.nlp(text)]
        if not self._dictionary_exist():
            for token in docs_tokens:
                # _dictionary maps id -> token, so membership is checked on the values
                if token not in self._dictionary.values():
                    self._dictionary[self.i] = token
                    self.i += 1
        return docs_tokens

    def save_dict(self):
        if self._dictionary:
            with open(path.join(self.dict_dir, "dict.p"), 'wb') as f:
                pickle.dump(self._dictionary, f)
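
# Usage sketch (illustrative only): building and persisting a dictionary with
# SpacyTokenizer. The directory "/tmp/spacy_dict" is a placeholder (it must already
# exist), not a path used by the project.
#   spacy_tok = SpacyTokenizer(dict_dir="/tmp/spacy_dict", preprocess=True)
#   tokens = spacy_tok.tokenize("Topical language generation steers text toward a topic.")
#   spacy_tok.save_dict()                # writes dict.p under dict_dir
#   id_to_token = spacy_tok.dictionary   # reloads dict.p now that it exists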


class TransformerGPT2Tokenizer(Tokenizer):
    def __init__(self, cached_dir):
        super().__init__()
        model_name_or_path = "gpt2"  # vocabulary of 50257 tokens
        tokenizer_class = GPT2TokenizerFast
        # tokenizer_class = GPT2Tokenizer  # slow (pure-Python) alternative
        self.tokenizer = tokenizer_class.from_pretrained(model_name_or_path, cache_dir=cached_dir)

    @property
    def dictionary(self):
        # return self.tokenizer.__dict__
        return self.tokenizer.decoder

    def tokenize(self, text):
        # return self.tokenizer.encode(text, add_special_tokens=False)
        return self.tokenizer.tokenize(text)

    def decode(self, ids):
        return self.tokenizer.decode(ids)

    def encode(self, text):
        return self.tokenizer.encode(text)

    def save_dict(self):
        # the pretrained GPT-2 vocabulary is fixed, so there is nothing to persist
        pass


if __name__ == "__main__":
    tokenizer = TransformerGPT2Tokenizer(cached_dir="/home/rohola/codes/topical_language_generation/caches/")
    tokens = tokenizer.tokenize("this is a test")
    ids = tokenizer.encode("this is a test")
    print(ids)
    text = tokenizer.decode(ids)
    print(text)