model.py
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from inputdata import Options, scorefunction

class skipgram(nn.Module):
    """Skip-gram model trained with negative sampling (word2vec)."""

    def __init__(self, vocab_size, embedding_dim):
        super(skipgram, self).__init__()
        # u: input ("center word") embeddings; v: output ("context word") embeddings.
        self.u_embeddings = nn.Embedding(vocab_size, embedding_dim, sparse=True)
        self.v_embeddings = nn.Embedding(vocab_size, embedding_dim, sparse=True)
        self.embedding_dim = embedding_dim
        self.init_emb()

    def init_emb(self):
        # Input embeddings start uniform in [-0.5/dim, 0.5/dim]; output embeddings start at zero.
        initrange = 0.5 / self.embedding_dim
        self.u_embeddings.weight.data.uniform_(-initrange, initrange)
        self.v_embeddings.weight.data.zero_()

    def forward(self, u_pos, v_pos, v_neg, batch_size):
        # Positive pairs: log sigmoid of the center/context dot product.
        embed_u = self.u_embeddings(u_pos)          # (batch, dim)
        embed_v = self.v_embeddings(v_pos)          # (batch, dim)
        score = torch.mul(embed_u, embed_v)
        score = torch.sum(score, dim=1)
        log_target = F.logsigmoid(score).squeeze()

        # Negative samples: dot each center word against its k sampled contexts,
        # sum the scores, and take log sigmoid of the negation.
        neg_embed_v = self.v_embeddings(v_neg)      # (batch, k, dim)
        neg_score = torch.bmm(neg_embed_v, embed_u.unsqueeze(2)).squeeze()
        neg_score = torch.sum(neg_score, dim=1)
        sum_log_sampled = F.logsigmoid(-1 * neg_score).squeeze()

        # Negate and average the objective over the batch.
        loss = log_target + sum_log_sampled
        return -1 * loss.sum() / batch_size

    def input_embeddings(self):
        return self.u_embeddings.weight.data.cpu().numpy()

    def save_embedding(self, file_name, id2word):
        # Write one "word dim_1 ... dim_n" line per vocabulary entry, using the
        # input embeddings; id2word is called to map an index back to its word.
        embeds = self.u_embeddings.weight.data
        with open(file_name, 'w') as fo:
            for idx in range(len(embeds)):
                word = id2word(idx)
                embed = ' '.join(str(x) for x in embeds[idx].tolist())
                fo.write(word + ' ' + embed + '\n')
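
A minimal usage sketch, not part of model.py: it builds the model on a toy vocabulary, runs one hand-made training step, and dumps the vectors. The tensor shapes follow forward()'s signature; the vocabulary size, learning rate, optimizer choice (SparseAdam, since the embedding layers are sparse), and the lambda passed as id2word are illustrative assumptions, not taken from this repository.

import torch

from model import skipgram  # assumes this script sits next to model.py

vocab_size, embedding_dim = 10, 8
model = skipgram(vocab_size, embedding_dim)
# SparseAdam handles the sparse gradients produced by nn.Embedding(sparse=True).
optimizer = torch.optim.SparseAdam(model.parameters(), lr=0.01)

# Two (center, context) positive pairs and five negative samples per pair.
u_pos = torch.tensor([1, 2])                     # center-word indices, shape (batch,)
v_pos = torch.tensor([3, 4])                     # context-word indices, shape (batch,)
v_neg = torch.tensor([[5, 6, 7, 8, 9],
                      [0, 5, 6, 7, 8]])          # negative indices, shape (batch, k)

optimizer.zero_grad()
loss = model(u_pos, v_pos, v_neg, batch_size=2)
loss.backward()
optimizer.step()

# save_embedding calls id2word(idx), so any callable mapping index -> word works.
model.save_embedding('toy_vectors.txt', lambda i: 'w%d' % i)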