# synonym.py
from time import time

import spacy
import torch
from sklearn.cluster import KMeans

from dataset import dataset
from path import Path
from utils import *  # presumably provides dev(), used below

def get_synonym(DATASET, m=5):
    '''
    Given a dataset name, find the m nearest neighbours (by spaCy vector
    distance) of every word in its vocabulary and cache the result as a
    (vocab_size, m) tensor of token ids.
    Words without a spaCy vector get a row of -1 placeholders.
    '''
    start_time = time()
    _dataset = dataset(DATASET, True)
    vocab = _dataset.vocab
    words_number = len(vocab)
    alphabet = list(range(words_number))
    ori_words = vocab.lookup_tokens(alphabet)
    nlp = spacy.load('en_core_web_lg')
    vector_size = torch.tensor(nlp('Hello').vector).shape[0]
    vectors = torch.zeros((words_number, vector_size))
    has_vector = []
    # Collect a spaCy vector for every vocabulary word, remembering which
    # words actually have one.
    for idx, word in enumerate(ori_words):
        _word = nlp(word)
        if _word.has_vector and _word.vector_norm:
            vectors[idx] = torch.from_numpy(_word.vector)
            has_vector.append(True)
        else:
            has_vector.append(False)
    assert len(has_vector) == words_number
    vectors = vectors.to(dev())  # .to() returns a new tensor; it must be reassigned
    all_synonym = []
    for idx, word in enumerate(ori_words):
        if has_vector[idx]:
            synonym = []
            vec = vectors[idx]
            # Squared Euclidean distance from this word to every other word.
            diff = vectors - vec
            diff = (diff * diff).sum(dim=1)
            _, indices = torch.sort(diff)  # avoid shadowing the built-in sorted()
            index = 0
            # Keep the m closest words that themselves have a vector (the word
            # itself, at distance 0, is kept as its own first neighbour).
            while len(synonym) < m:
                if has_vector[indices[index]]:
                    synonym.append(indices[index].item())
                index += 1
        else:
            synonym = [-1] * m
        synonym = torch.tensor(synonym).to(dev())
        all_synonym.append(synonym)
        if idx % 1000 == 0:
            print(f'{idx} words checked.')
    all_synonym = torch.stack(all_synonym).to(dev())
    torch.save(all_synonym, Path + DATASET + '_synonym.pth')
    print(f'{DATASET} synonym ready. Use time: {time()-start_time:.1f}s')
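
# Example usage, a minimal sketch (the 'news' name, the token_id variable and
# the vocab lookup below are illustrative assumptions, not fixed by this file):
#
#   get_synonym('news', m=5)
#   synonyms = torch.load(Path + 'news_synonym.pth')  # (vocab_size, m) token ids
#   ids = synonyms[token_id]                          # neighbours of one token
#   if (ids >= 0).all():                              # -1 rows mark vectorless words
#       print(vocab.lookup_tokens(ids.tolist()))
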
def get_alphabet(DATASET, CLUSTER):
    '''
    Cluster the spaCy vectors of a dataset's vocabulary into CLUSTER groups
    with k-means and cache the token-id -> cluster-id mapping.
    Words without a spaCy vector each get their own id at or above CLUSTER.
    '''
    start_time = time()
    _dataset = dataset(DATASET, True)
    vocab = _dataset.vocab
    words_number = len(vocab)
    ori_alphabet = list(range(words_number))
    ori_words = vocab.lookup_tokens(ori_alphabet)
    nlp = spacy.load('en_core_web_lg')
    vectors = []
    has_vector = []
    for idx, word in enumerate(ori_words):
        _word = nlp(word)
        if _word.has_vector and _word.vector_norm:
            vectors.append(torch.from_numpy(_word.vector))
            has_vector.append(True)
        else:
            has_vector.append(False)
    assert len(has_vector) == words_number
    vectors = torch.stack(vectors).numpy()
    print(f'vectors ready. Use time: {time()-start_time:.1f}s')
    current_time = time()
    kmeans = KMeans(n_clusters=CLUSTER).fit(vectors)
    print(f'kmeans ready. Use time: {time()-current_time:.1f}s')
    alphabet = []
    index = 0   # position in the (denser) array of clustered vectors
    offset = 0  # running count of words without a vector
    for idx, word in enumerate(ori_words):
        if has_vector[idx]:
            alphabet.append(kmeans.labels_[index])
            index += 1
        else:
            # Each vectorless word becomes its own singleton cluster.
            alphabet.append(CLUSTER + offset)
            offset += 1
    alphabet = torch.tensor(alphabet)
    torch.save(alphabet, Path + DATASET + '_alphabet.pth')
    print(f'{DATASET} alphabet ready. Use time: {time()-start_time:.1f}s')
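
# Example usage, a sketch under the same assumptions as above: after
# get_alphabet runs, alphabet[token_id] is that token's cluster id, and
# ids >= CLUSTER belong to the private singleton clusters of vectorless words.
#
#   get_alphabet('news', CLUSTER=100)                  # cluster count is illustrative
#   alphabet = torch.load(Path + 'news_alphabet.pth')  # (vocab_size,) cluster ids
#   shared = alphabet[vocab['good']] == alphabet[vocab['great']]  # may share a cluster
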
if __name__ == '__main__':
    alphabet = torch.load(Path + 'news_alphabet.pth')
    _dataset = dataset('news', True)
    vocab = _dataset.vocab
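    # Hedged sketch: the commented calls below (with illustrative parameters)
    # would rebuild the cached files that this block inspects.
    # get_alphabet('news', CLUSTER=100)
    # get_synonym('news', m=5)
    print(f'news alphabet loaded: {len(vocab)} tokens, '
          f'max cluster id {int(alphabet.max())}')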