-
Notifications
You must be signed in to change notification settings - Fork 1
/
alphabet.py
63 lines (54 loc) · 1.88 KB
/
alphabet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# -*- coding: utf-8 -*-
"""
Author: bin zhou
Date: 2020-01-13
"""
class Alphabet(object):
    """Two-way mapping between instances and 1-based integer indices.

    Instances are stored in insertion order in ``instances`` and their
    indices in ``instance2index``. Index 0 is never assigned to an instance
    (``default_index`` is 0 and ``next_index`` starts at 1).

    Args:
        name: Identifier stored on the alphabet.
        label: When False, ``'</unk>'`` is registered as the first entry
            (index 1) at construction time. When True, no unknown entry is
            added.
        keep_growing: Whether ``get_index`` may register unseen instances.
    """

    def __init__(self, name, label=False, keep_growing=True):
        self.name = name
        self.UNKNOWN = '</unk>'
        self.label = label
        self.instance2index = {}
        self.instances = []
        self.keep_growing = keep_growing
        self.default_index = 0
        self.next_index = 1
        # Put '</unk>' first in the alphabet (non-label alphabets only).
        if not self.label:
            self.add(self.UNKNOWN)

    def add(self, instance):
        """Register ``instance`` under the next free index; no-op if present."""
        if instance not in self.instance2index:
            self.instances.append(instance)
            self.instance2index[instance] = self.next_index
            self.next_index += 1

    def size(self):
        """Return the number of stored instances plus one (for index 0)."""
        return len(self.instances) + 1

    def get_index(self, instance, mode='train'):
        """Return the index of ``instance``, registering it when allowed.

        Args:
            instance: The instance to look up.
            mode: With ``'no_train'`` on a label alphabet, an unseen instance
                switches ``keep_growing`` back on (it stays on after the
                call) so that the new label receives its own index.

        Returns:
            The existing index; or a newly assigned index when the alphabet
            is growing; or the index of ``'</unk>'`` when it is closed.

        Raises:
            KeyError: The instance is unseen, the alphabet is closed and it
                has no ``'</unk>'`` entry to fall back to.
        """
        if instance in self.instance2index:
            return self.instance2index[instance]

        # How to handle a label that shows up for the first time outside
        # of training: re-open the alphabet and resync the index counter.
        if mode == 'no_train' and self.label:
            self.keep_growing = True
            self.next_index = len(self.instances) + 1

        if self.keep_growing:
            index = self.next_index
            self.add(instance)
            return index

        if self.UNKNOWN not in self.instance2index:
            # Same exception type as a plain failed lookup of '</unk>', but
            # the message names the instance that was actually missing.
            raise KeyError(
                'Alphabet %r is closed and has no %r entry; '
                'cannot index unseen instance %r'
                % (self.name, self.UNKNOWN, instance))
        return self.instance2index[self.UNKNOWN]

    def get_instance(self, index):
        """Return the instance stored at the 1-based ``index``.

        Index 0 yields the first instance for a label alphabet and ``None``
        otherwise. Any other index outside ``1..len(instances)`` (including
        negative values) prints a warning and yields the first instance.
        """
        if index == 0:
            if self.label:
                return self.instances[0]
            # First index is occupied by the wildcard element.
            return None
        # Explicit range check: a negative index must not wrap around to the
        # end of the list and silently return an unrelated instance.
        if 1 <= index <= len(self.instances):
            return self.instances[index - 1]
        print('WARNING:Alphabet get_instance ,unknown instance, return the first label.')
        return self.instances[0]

    def iteritems(self):
        """Return the ``(instance, index)`` items view of the mapping."""
        return self.instance2index.items()

    def close(self):
        """Stop registering unseen instances in ``get_index``."""
        self.keep_growing = False