-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathStatistic.py
92 lines (70 loc) · 2.45 KB
/
Statistic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import numpy as np
import os
from sklearn.decomposition import PCA
from StorageHandler import StorageHandler
from gensim.scripts.glove2word2vec import glove2word2vec
from gensim.test.utils import datapath, get_tmpfile
from gensim.models import KeyedVectors
import plotly.graph_objs as go
import pickle
class Statistic:
    """Plotting helpers for word embeddings grouped into named categories."""

    @staticmethod
    def __load_words(words, words_dict):
        """Map each word to its index in ``words_dict``.

        Args:
            words: Iterable of word strings; surrounding whitespace is
                stripped before the lookup.
            words_dict: Mapping from word to integer index.

        Returns:
            List of indexes, in the same order as ``words``.

        Raises:
            KeyError: If a stripped word is not a key of ``words_dict``.
        """
        return [words_dict[word.strip()] for word in words]

    @staticmethod
    def __split_known_words(words, words_dict):
        """Strip ``words`` and partition them by membership in ``words_dict``.

        Args:
            words: Iterable of word strings.
            words_dict: Mapping whose keys are the known words.

        Returns:
            Tuple ``(known, missing)`` of lists of stripped words, each
            preserving the input order.
        """
        known, missing = [], []
        for word in words:
            label = word.strip()
            (known if label in words_dict else missing).append(label)
        return known, missing

    @staticmethod
    def plot_haidt_embeddings(model: dict, data: dict):
        """Project the embeddings to 3D with PCA and show one trace per category.

        PCA is fitted on every vector in ``model``; only the words listed in
        ``data`` are drawn. Words that are not keys of ``model`` are reported
        on stdout and left out of the plot instead of aborting it with a
        ``KeyError``.

        Args:
            model: Mapping from word to its embedding vector.
            data: Mapping from category name to a list of words to plot
                under that category.
        """
        print("[PCA] Fitting model")
        # Row i of the PCA output corresponds to the i-th key of ``model``.
        words_index = {word: row for row, word in enumerate(model)}
        word_vectors = list(model.values())
        pca = PCA(n_components=3, random_state=42)
        components = pca.fit_transform(word_vectors)

        traces = []
        print("[PCA] Building 3D embedding space")
        for haidt, group_words in data.items():
            known, missing = Statistic.__split_known_words(group_words, words_index)
            if missing:
                print(f"[PCA] Words not found in model for '{haidt}': {missing}")
            indexes = Statistic.__load_words(known, words_index)
            # The trace's position in the figure is passed as its marker color.
            # NOTE(review): confirm plotly interprets a bare integer here as
            # intended (distinct color per category).
            color = len(traces)
            trace = go.Scatter3d(
                x=components[indexes, 0],
                y=components[indexes, 1],
                z=components[indexes, 2],
                # Labels are the stripped words so they line up one-to-one
                # with the plotted points.
                text=known,
                name=haidt,
                textposition="top center",
                textfont_size=20,
                mode='markers+text',
                marker={
                    'size': 10,
                    'opacity': 0.8,
                    'color': color,
                },
            )
            traces.append(trace)

        layout = go.Layout(
            margin={'l': 0, 'r': 0, 'b': 0, 't': 0},
            showlegend=True,
            legend=dict(
                x=1,
                y=0.5,
                font=dict(
                    family="Courier New",
                    size=25,
                    color="black",
                ),
            ),
            font=dict(
                family=" Courier New ",
                size=15,
            ),
            autosize=False,
            width=1500,
            height=1000,
        )
        plot_figure = go.Figure(data=traces, layout=layout)
        plot_figure.show()
if __name__ == "__main__":
    import sys

    # plot_haidt_embeddings takes (model, data); calling it with only the
    # category dict raises TypeError, so the embedding dict must be loaded
    # first.
    # NOTE(review): this assumes the embeddings are stored as a pickled
    # {word: vector} dict whose path is given on the command line — swap in
    # the project's real loader if they live elsewhere.
    if len(sys.argv) != 2:
        sys.exit(f"usage: {sys.argv[0]} <pickled-model-path>")
    # pickle can execute arbitrary code while loading; only open trusted files.
    with open(sys.argv[1], "rb") as model_file:
        model = pickle.load(model_file)

    data = {
        'Harm': ['kill', 'fist', 'accident'],
        # NOTE(review): 'felling' may be a typo for 'feeling' — confirm.
        'Care': ['important', 'love', 'felling'],
    }
    Statistic.plot_haidt_embeddings(model, data)