forked from codertimo/BERT-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvocab.py
189 lines (151 loc) · 6.64 KB
/
vocab.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
import pickle
import tqdm
from collections import Counter
class TorchVocab(object):
    """Defines a vocabulary object that will be used to numericalize a field.

    Attributes:
        freqs: The counter passed to the constructor (stored as-is, not
            copied), holding the frequencies of tokens in the data used to
            build the Vocab.
        stoi: A plain dict mapping token strings to numerical identifiers.
        itos: A list of token strings indexed by their numerical identifiers.
        vectors: None unless ``vectors`` was passed to the constructor.
    """

    def __init__(self, counter, max_size=None, min_freq=1, specials=('<pad>', '<oov>'),
                 vectors=None, unk_init=None, vectors_cache=None):
        """Create a Vocab object from a collections.Counter.

        Arguments:
            counter: collections.Counter (or any token -> count mapping with
                ``copy``/``pop``/``items``) holding the frequencies of each
                value found in the data. It is not modified.
            max_size: The maximum number of non-special tokens to keep, or
                None for no maximum. Special tokens do not count toward this
                limit. Default: None.
            min_freq: The minimum frequency needed to include a token in the
                vocabulary. Values less than 1 are treated as 1. Default: 1.
            specials: Sequence of special tokens placed, in order, at the
                start of the vocabulary before any counted token.
                Default: ('<pad>', '<oov>').
            vectors: Pretrained vectors to attach. This is forwarded to
                ``self.load_vectors``, which this class does not define.
            unk_init: Forwarded to ``load_vectors``. It must be None when
                ``vectors`` is None.
            vectors_cache: Forwarded to ``load_vectors`` as ``cache``. It must
                be None when ``vectors`` is None.
        """
        self.freqs = counter
        counter = counter.copy()
        min_freq = max(min_freq, 1)

        self.itos = list(specials)
        # Special tokens always occupy the leading slots, so they are excluded
        # from the frequency-ordered pass. pop(..., None) tolerates absent keys
        # for plain dicts as well as Counters.
        for tok in specials:
            counter.pop(tok, None)

        # max_size limits only regular tokens; convert it to a cap on len(itos).
        max_size = None if max_size is None else max_size + len(self.itos)

        # Sort by frequency (descending) with ties broken alphabetically: the
        # first sort orders by token, the second (stable) sort by count.
        words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0])
        words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)

        for word, freq in words_and_frequencies:
            if freq < min_freq or len(self.itos) == max_size:
                break
            self.itos.append(word)

        # stoi is simply a reverse dict for itos.
        self.stoi = {tok: i for i, tok in enumerate(self.itos)}

        self.vectors = None
        if vectors is not None:
            # NOTE(review): load_vectors is not defined on this class; unless a
            # subclass provides it, this raises AttributeError.
            self.load_vectors(vectors, unk_init=unk_init, cache=vectors_cache)
        else:
            assert unk_init is None and vectors_cache is None

    def __eq__(self, other):
        """Vocabularies are equal when freqs, stoi, itos and vectors all match.

        Defining __eq__ without __hash__ leaves instances unhashable.
        """
        if not isinstance(other, TorchVocab):
            # Let Python fall back to its default comparison instead of
            # failing on a missing attribute.
            return NotImplemented
        if self.freqs != other.freqs:
            return False
        if self.stoi != other.stoi:
            return False
        if self.itos != other.itos:
            return False
        if self.vectors != other.vectors:
            return False
        return True

    def __len__(self):
        """Return the number of tokens, special tokens included."""
        return len(self.itos)

    def vocab_rerank(self):
        """Rebuild ``stoi`` from the current order of ``itos``."""
        self.stoi = {word: i for i, word in enumerate(self.itos)}

    def extend(self, v, sort=False):
        """Append the tokens of vocabulary ``v`` that are not already present.

        Arguments:
            v: Another vocabulary exposing ``itos``.
            sort: If true, append new tokens in sorted order; otherwise use
                ``v.itos`` order.
        """
        words = sorted(v.itos) if sort else v.itos
        for w in words:
            if w not in self.stoi:
                self.itos.append(w)
                self.stoi[w] = len(self.itos) - 1
class Vocab(TorchVocab):
    """TorchVocab with five fixed special tokens and pickle save/load helpers.

    Special-token ids follow the order of the specials list passed to
    TorchVocab: <pad>=0, <unk>=1, <eos>=2, <sos>=3, <mask>=4. Subclasses are
    expected to implement ``to_seq`` and ``from_seq``.
    """

    def __init__(self, counter, max_size=None, min_freq=1):
        """Build the vocabulary from ``counter``.

        Arguments:
            counter: Token -> frequency mapping (see TorchVocab).
            max_size: Maximum number of non-special tokens, or None.
            min_freq: Minimum frequency for a token to be kept.
        """
        # These ids must stay in sync with the order of the specials list below.
        self.pad_index = 0
        self.unk_index = 1
        self.eos_index = 2
        self.sos_index = 3
        self.mask_index = 4
        super().__init__(counter, specials=["<pad>", "<unk>", "<eos>", "<sos>", "<mask>"],
                         max_size=max_size, min_freq=min_freq)

    def to_seq(self, sentece, seq_len, with_eos=False, with_sos=False) -> list:
        """Convert a sentence into a list of token ids; subclasses must override.

        The misspelled parameter name ``sentece`` is kept so that existing
        keyword callers keep working.

        Raises:
            NotImplementedError: always, because the base class has no tokenization.
        """
        raise NotImplementedError("to_seq must be implemented by a Vocab subclass")

    def from_seq(self, seq, join=False, with_pad=False):
        """Convert a list of token ids back into tokens; subclasses must override.

        Raises:
            NotImplementedError: always, because the base class has no detokenization.
        """
        raise NotImplementedError("from_seq must be implemented by a Vocab subclass")

    @staticmethod
    def load_vocab(vocab_path: str) -> 'Vocab':
        """Load a vocabulary pickled by ``save_vocab``.

        Security: pickle.load can execute arbitrary code. Only load files from
        trusted sources.
        """
        with open(vocab_path, "rb") as f:
            return pickle.load(f)

    def save_vocab(self, vocab_path):
        """Pickle this vocabulary object to ``vocab_path``."""
        with open(vocab_path, "wb") as f:
            pickle.dump(self, f)
# Building Vocab with text files
class WordVocab(Vocab):
    """Vocab built by counting whitespace-separated words in an iterable of texts."""

    def __init__(self, texts, max_size=None, min_freq=1):
        """Count the words in ``texts`` and build the vocabulary.

        Arguments:
            texts: Iterable (e.g. an open text file) whose items are either a
                list of already-tokenized words or a string. Strings are split
                on any run of whitespace: spaces, tabs and newlines.
            max_size: Maximum number of non-special tokens, or None.
            min_freq: Minimum frequency for a word to be kept.
        """
        print("Building Vocab")
        counter = Counter()
        for line in tqdm.tqdm(texts):
            if isinstance(line, list):
                words = line
            else:
                # str.split() with no argument splits on every whitespace kind.
                # Tab-separated fields (e.g. "sent_a\tsent_b") therefore stay
                # separate words instead of being glued together at the tab.
                words = line.split()
            counter.update(words)
        super().__init__(counter, max_size=max_size, min_freq=min_freq)

    def to_seq(self, sentence, seq_len=None, with_eos=False, with_sos=False, with_len=False):
        """Map a sentence to token ids, with optional <eos>/<sos> and padding.

        Arguments:
            sentence: A string (split on whitespace) or a list of tokens.
                Unknown tokens map to ``unk_index``.
            seq_len: If given, pad with ``pad_index`` or truncate to exactly
                this length. Truncation happens after <eos>/<sos> are added,
                so it can cut off the <eos> id.
            with_eos: Append ``eos_index``.
            with_sos: Prepend ``sos_index``.
            with_len: Also return the length before padding/truncation.

        Returns:
            The id list, or ``(ids, length_before_pad_or_truncate)`` when
            ``with_len`` is true.
        """
        if isinstance(sentence, str):
            sentence = sentence.split()

        seq = [self.stoi.get(word, self.unk_index) for word in sentence]

        if with_eos:
            seq += [self.eos_index]  # end-of-sentence id (eos_index, not 1)
        if with_sos:
            seq = [self.sos_index] + seq

        origin_seq_len = len(seq)

        if seq_len is None:
            pass
        elif len(seq) <= seq_len:
            seq += [self.pad_index for _ in range(seq_len - len(seq))]
        else:
            seq = seq[:seq_len]

        return (seq, origin_seq_len) if with_len else seq

    def from_seq(self, seq, join=False, with_pad=False):
        """Convert token ids back into tokens.

        Ids outside ``[0, len(itos))`` are rendered as ``"<id>"``.

        Arguments:
            seq: Iterable of integer ids.
            join: If true, return a single space-joined string.
            with_pad: NOTE(review): when True, pad ids are *dropped*; when
                False, they are kept. The name reads inverted relative to this
                behavior, but it is preserved for existing callers.
        """
        words = [self.itos[idx]
                 # Guard against negative ids, which would otherwise silently
                 # index from the end of itos.
                 if 0 <= idx < len(self.itos)
                 else "<%d>" % idx
                 for idx in seq
                 if not with_pad or idx != self.pad_index]
        return " ".join(words) if join else words

    @staticmethod
    def load_vocab(vocab_path: str) -> 'WordVocab':
        """Load a WordVocab pickled by ``save_vocab``.

        Security: pickle.load can execute arbitrary code. Only load files from
        trusted sources.
        """
        with open(vocab_path, "rb") as f:
            return pickle.load(f)
def build(argv=None):
    """Command-line entry point: build a WordVocab from a corpus file and pickle it.

    Arguments:
        argv: Optional list of command-line arguments, excluding the program
            name. When None, argparse reads ``sys.argv[1:]``, so running the
            script directly behaves as before.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Build a word vocabulary from a text corpus.")
    parser.add_argument("-c", "--corpus_path", required=True, type=str,
                        help="path to the input text corpus")
    parser.add_argument("-o", "--output_path", required=True, type=str,
                        help="where to write the pickled vocabulary")
    parser.add_argument("-s", "--vocab_size", type=int, default=None,
                        help="maximum number of non-special tokens (default: unlimited)")
    parser.add_argument("-e", "--encoding", type=str, default="utf-8",
                        help="corpus file encoding")
    parser.add_argument("-m", "--min_freq", type=int, default=1,
                        help="minimum word frequency to keep a token")
    args = parser.parse_args(argv)

    with open(args.corpus_path, "r", encoding=args.encoding) as f:
        vocab = WordVocab(f, max_size=args.vocab_size, min_freq=args.min_freq)

    print("VOCAB SIZE:", len(vocab))
    vocab.save_vocab(args.output_path)
if __name__ == "__main__":
    # Run as a script: build the vocabulary and write it to disk.
    build()