my_skipgram.py
import torch
import torch.nn as nn
from text_processing import get_nlp_pipeline, word_tokenization
# Custom Skip-gram
def get_word2idx(vocab):
    # Map every word to an integer index; index 0 is reserved for 'PAD'.
    word_to_ix = {'PAD': 0}
    for idx, word in enumerate(vocab):
        word_to_ix[word] = idx + 1
    return word_to_ix

def get_idx2word(vocab):
    # Inverse mapping: integer index back to the word, with 0 -> 'PAD'.
    ix_to_word = {0: 'PAD'}
    for idx, word in enumerate(vocab):
        ix_to_word[idx + 1] = word
    return ix_to_word

def build_input_output(_input_text, window):
    # For every centre word, collect the window*2 surrounding words,
    # padding with 'PAD' where the window runs past the sentence boundary.
    io_pair = []
    for _input_tokens in _input_text:
        for idx, word in enumerate(_input_tokens):
            context = []
            start = idx - window
            end = idx + window + 1
            for cur in range(start, end):
                if cur < 0:
                    context.append('PAD')
                elif cur >= len(_input_tokens):
                    context.append('PAD')
                elif cur != idx:
                    context.append(_input_tokens[cur])
            io_pair.append([word, context])
    return io_pair

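# Illustration of build_input_output: for _input_text = [['we', 'like', 'nlp']]
# and window = 1, the returned pairs are
#   [['we',   ['PAD',  'like']],
#    ['like', ['we',   'nlp' ]],
#    ['nlp',  ['like', 'PAD' ]]]
# so every centre word is paired with exactly window*2 context slots.
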
def make_context_vector(context, word_to_ix):
    # Look up a single word and wrap its index in a 0-d LongTensor for nn.Embedding.
    return torch.tensor(word_to_ix[context], dtype=torch.long)

class custom_skipgram(torch.nn.Module):
    def __init__(self, vocab_size, hidden_dim, embedding_dim, window):
        super(custom_skipgram, self).__init__()
        # lookup table: one embedding vector per vocabulary index
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.window = window
        # projection layer
        self.linear1 = nn.Linear(embedding_dim, hidden_dim)
        self.activation_function1 = nn.ReLU()
        # output layer: one vocab-sized block of logits per context position (window*2 of them)
        self.linear2 = nn.Linear(hidden_dim, vocab_size * window * 2)
        self.activation_function2 = nn.LogSoftmax(dim=-1)

    def forward(self, _inputs):
        embeds = self.embeddings(_inputs)
        out = self.linear1(embeds)
        projected_out = self.activation_function1(out)
        output = self.linear2(projected_out)
        # reshape to (window*2, vocab_size) before the log-softmax so each
        # context slot gets its own normalised distribution over the vocabulary
        log_probs = self.activation_function2(output.view(self.window * 2, -1))
        return projected_out, log_probs

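# A minimal shape check (illustrative values, not part of the module): with
# vocab_size=V and window=2, a single word index yields window*2 = 4 rows of
# log-probabilities, one distribution over the vocabulary per context slot.
#
#   model = custom_skipgram(vocab_size=V, hidden_dim=64, embedding_dim=32, window=2)
#   hidden, log_probs = model(torch.tensor(3, dtype=torch.long))
#   log_probs.shape   # torch.Size([4, V])
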
def build_vocab(text_list, selected_nlp_pipeline, nlp_pipeline):
    # Tokenize every text; vocab is the concatenation of all tokens (duplicates included),
    # input_tokens keeps one token list per text.
    input_tokens = []
    vocab = []
    for _text in text_list:
        tokenized_text = word_tokenization(_text, selected_nlp_pipeline, nlp_pipeline)
        vocab += tokenized_text
        input_tokens.append(tokenized_text)
    return vocab, input_tokens

def train_custom_skipgram_model(model, input_tokens, window, word_to_ix):
    input_output_context = build_input_output(input_tokens, window)
    loss_function = nn.NLLLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
    for epoch in range(10):
        total_loss = 0
        # each pair is [centre word, list of window*2 context words]
        for center, context_words in input_output_context:
            center_vector = make_context_vector(center, word_to_ix)
            _, log_probs = model(center_vector)
            target = torch.tensor([word_to_ix[t] for t in context_words])
            total_loss += loss_function(log_probs, target)
        # optimize at the end of each epoch
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
    return model

def get_custom_word_embeddings(model, cur_text, selected_nlp_pipeline, nlp_pipeline, word_to_ix):
    tks = word_tokenization(cur_text, selected_nlp_pipeline, nlp_pipeline)
    # the first registered parameter is the nn.Embedding weight matrix
    embeddings = list(model.parameters())[0]
    embeddings = embeddings.cpu().detach()
    result = []
    for tk in tks:
        result.append(embeddings[word_to_ix[tk]])
    return result

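# A minimal end-to-end sketch, assuming the companion text_processing module where
# get_nlp_pipeline(name) returns the tokenizer object that word_tokenization expects;
# the pipeline name 'nltk', the corpus, and all hyperparameters below are illustrative
# assumptions, not part of the original module.
if __name__ == "__main__":
    corpus = ["we like natural language processing",
              "skip gram predicts context words from a centre word"]
    nlp_name = "nltk"                        # assumed pipeline identifier
    nlp = get_nlp_pipeline(nlp_name)         # assumed to return the tokenizer object
    window = 2

    vocab, input_tokens = build_vocab(corpus, nlp, nlp_name)
    word_to_ix = get_word2idx(vocab)

    model = custom_skipgram(vocab_size=len(vocab) + 1,   # +1 for the 'PAD' index 0
                            hidden_dim=64,
                            embedding_dim=32,
                            window=window)
    model = train_custom_skipgram_model(model, input_tokens, window, word_to_ix)

    # one embedding vector per token of the first sentence
    for token, vector in zip(word_tokenization(corpus[0], nlp, nlp_name),
                             get_custom_word_embeddings(model, corpus[0], nlp, nlp_name, word_to_ix)):
        print(token, vector.shape)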