model.py
# -*- coding: utf-8 -*-
"""
Created on 4/4/2019
@author: RuihongQiu
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, GatedGraphConv, SAGEConv


class Embedding2Score(nn.Module):
    def __init__(self, hidden_size):
        super(Embedding2Score, self).__init__()
        self.hidden_size = hidden_size
        self.W_1 = nn.Linear(self.hidden_size, self.hidden_size)
        self.W_2 = nn.Linear(self.hidden_size, self.hidden_size)
        self.q = nn.Linear(self.hidden_size, 1)
        self.W_3 = nn.Linear(2 * self.hidden_size, self.hidden_size)

    def forward(self, session_embedding, all_item_embedding, batch):
        sections = tuple(torch.bincount(batch).cpu().numpy())
        v_i = torch.split(session_embedding, sections)  # split the batched x back into per-session graphs G_i
        # Repeat the last node's embedding |V|_i times within each graph.
        v_n_repeat = tuple(nodes[-1].view(1, -1).repeat(nodes.shape[0], 1) for nodes in v_i)

        # Eq. (6): attention weight of every node w.r.t. the last node, shape |V|_i x 1
        alpha = self.q(torch.sigmoid(self.W_1(torch.cat(v_n_repeat, dim=0)) + self.W_2(session_embedding)))
        s_g_whole = alpha * session_embedding  # |V|_i x hidden_size
        s_g_split = torch.split(s_g_whole, sections)  # split the weighted embeddings back into graphs G_i
        s_g = tuple(torch.sum(embeddings, dim=0).view(1, -1) for embeddings in s_g_split)

        # Eq. (7): fuse the local (last-node) and global (attention-pooled) session embeddings.
        v_n = tuple(nodes[-1].view(1, -1) for nodes in v_i)
        s_h = self.W_3(torch.cat((torch.cat(v_n, dim=0), torch.cat(s_g, dim=0)), dim=1))

        # Eq. (8): score every candidate item against the session representation.
        z_i_hat = torch.mm(s_h, all_item_embedding.weight.transpose(1, 0))
        return z_i_hat
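
# A minimal shape sketch for Embedding2Score (illustrative only; the toy
# tensors below are stand-ins, not data from this repo):
#
#     scorer = Embedding2Score(hidden_size=4)
#     items = nn.Embedding(10, 4)            # catalogue of 10 items
#     nodes = torch.randn(5, 4)              # 5 node embeddings across 2 graphs
#     batch = torch.tensor([0, 0, 0, 1, 1])  # graph id of each node
#     scores = scorer(nodes, items, batch)   # -> shape [2, 10]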


class GNNModel(nn.Module):
    """
    Args:
        hidden_size: the number of units in a hidden layer.
        n_node: the number of items in the whole item set, for the embedding layer.
    """
    def __init__(self, hidden_size, n_node):
        super(GNNModel, self).__init__()
        self.hidden_size, self.n_node = hidden_size, n_node
        self.embedding = nn.Embedding(self.n_node, self.hidden_size)
        self.gat1 = GATConv(self.hidden_size, self.hidden_size, heads=8, negative_slope=0.2)
        self.gat2 = GATConv(8 * self.hidden_size, self.hidden_size, heads=1, negative_slope=0.2)
        self.sage1 = SAGEConv(self.hidden_size, self.hidden_size)
        self.sage2 = SAGEConv(self.hidden_size, self.hidden_size)
        self.gated = GatedGraphConv(self.hidden_size, num_layers=2)
        self.e2s = Embedding2Score(self.hidden_size)
        self.loss_function = nn.CrossEntropyLoss()
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)
    def forward(self, data):
        # Item ids in data.x are 1-based; shift to 0-based for the embedding lookup.
        x, edge_index, batch, edge_attr = data.x - 1, data.edge_index, data.batch, data.edge_attr
        # print(torch.sum(data.sequence_len), data.sequence.shape[0])  # sequence is the concatenation of all sequences in the batch
        # print(x.squeeze(dim=-1).unique().shape, x.shape)  # x: total number of edges in the batch
        embedding = self.embedding(x).squeeze()

        # GGNN variant:
        # hidden = self.gated(embedding, edge_index)

        # GAT variant (active): 8 attention heads, then a single-head output layer.
        hidden = F.relu(self.gat1(embedding, edge_index))
        hidden = self.gat2(hidden, edge_index)

        # GraphSAGE variant:
        # hidden1 = F.relu(self.sage1(embedding, edge_index))
        # hidden2 = F.relu(self.sage2(hidden1, edge_index))
        return self.e2s(hidden, self.embedding, batch)
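

if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the original training pipeline):
    # build a toy two-session batch and check the shape of the score matrix.
    # hidden_size, n_node, and the item ids below are arbitrary stand-ins.
    from torch_geometric.data import Data

    model = GNNModel(hidden_size=16, n_node=50)
    x = torch.tensor([[1], [2], [3], [4], [5]])        # 1-based item ids, one per node
    edge_index = torch.tensor([[0, 1, 3], [1, 2, 4]])  # edges 1->2->3 and 4->5
    edge_attr = torch.ones(edge_index.size(1), 1)      # dummy edge features (unused by forward)
    batch = torch.tensor([0, 0, 0, 1, 1])              # two sessions: 3 nodes + 2 nodes
    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, batch=batch)
    print(model(data).shape)                           # expected: torch.Size([2, 50])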