-
Notifications
You must be signed in to change notification settings - Fork 1
/
model.py
executable file
·111 lines (71 loc) · 4.13 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import torch
from torch.nn import Sequential, Linear, ReLU, Sigmoid, Tanh, Dropout, Upsample
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.nn import NNConv
from torch_geometric.nn import GCNConv
from torch_geometric.nn import BatchNorm
import numpy as np
from torch_geometric.data import Data
from torch.autograd import Variable
from config import N_TARGET_NODES_F, N_SOURCE_NODES_F,N_TARGET_NODES,N_SOURCE_NODES
class Aligner(torch.nn.Module):
    """Graph aligner operating in the source-resolution domain.

    Three edge-conditioned NNConv layers (each followed by BatchNorm and a
    sigmoid) form an hourglass: N_SOURCE_NODES -> 1 -> N_SOURCE_NODES
    features per node.  The final output is averaged with the first layer's
    activations (a skip connection), yielding an
    (num_nodes, N_SOURCE_NODES) matrix.

    Input: a torch_geometric ``Data`` object exposing ``x``,
    ``pos_edge_index`` and ``edge_attr`` (scalar edge features — the edge
    networks map 1 -> in*out weights, as NNConv requires).
    """

    def __init__(self):
        super(Aligner, self).__init__()
        # Named `edge_nn*` to avoid shadowing the module-level
        # `torch.nn as nn` import (the original reused the name `nn`).
        edge_nn1 = Sequential(Linear(1, N_SOURCE_NODES * N_SOURCE_NODES), ReLU())
        self.conv1 = NNConv(N_SOURCE_NODES, N_SOURCE_NODES, edge_nn1,
                            aggr='mean', root_weight=True, bias=True)
        self.conv11 = BatchNorm(N_SOURCE_NODES, eps=1e-03, momentum=0.1,
                                affine=True, track_running_stats=True)

        # Bottleneck: compress every node to a single feature.
        edge_nn2 = Sequential(Linear(1, N_SOURCE_NODES), ReLU())
        self.conv2 = NNConv(N_SOURCE_NODES, 1, edge_nn2,
                            aggr='mean', root_weight=True, bias=True)
        self.conv22 = BatchNorm(1, eps=1e-03, momentum=0.1,
                                affine=True, track_running_stats=True)

        # Expand back to N_SOURCE_NODES features per node.
        edge_nn3 = Sequential(Linear(1, N_SOURCE_NODES), ReLU())
        self.conv3 = NNConv(1, N_SOURCE_NODES, edge_nn3,
                            aggr='mean', root_weight=True, bias=True)
        self.conv33 = BatchNorm(N_SOURCE_NODES, eps=1e-03, momentum=0.1,
                                affine=True, track_running_stats=True)

    def forward(self, data):
        """Return the aligned (num_nodes, N_SOURCE_NODES) embedding."""
        x, edge_index, edge_attr = data.x, data.pos_edge_index, data.edge_attr
        # torch.sigmoid replaces the deprecated F.sigmoid (same function).
        x1 = torch.sigmoid(self.conv11(self.conv1(x, edge_index, edge_attr)))
        x1 = F.dropout(x1, training=self.training)
        x2 = torch.sigmoid(self.conv22(self.conv2(x1, edge_index, edge_attr)))
        x2 = F.dropout(x2, training=self.training)
        x3 = torch.sigmoid(self.conv33(self.conv3(x2, edge_index, edge_attr)))
        # Skip connection: average the expanded features with the first
        # layer's output.  Equivalent to the original cat/slice/average —
        # both halves of the concatenation were exactly x3 and x1.
        return (x3 + x1) / 2
class Generator(nn.Module):
    """Super-resolution generator: source-domain node features to a
    target-domain connectome.

    Two active NNConv layers lift node features from N_SOURCE_NODES to
    N_TARGET_NODES channels; the output is ``x.t() @ x``, a symmetric
    (N_TARGET_NODES, N_TARGET_NODES) matrix interpreted as the predicted
    target adjacency/connectivity matrix.

    Input: a torch_geometric ``Data`` object exposing ``x``,
    ``pos_edge_index`` and ``edge_attr``.
    """

    def __init__(self):
        super(Generator, self).__init__()
        # Named `edge_nn*` to avoid shadowing the module-level
        # `torch.nn as nn` import (the original reused the name `nn`).
        edge_nn1 = Sequential(Linear(1, N_SOURCE_NODES * N_SOURCE_NODES), ReLU())
        self.conv1 = NNConv(N_SOURCE_NODES, N_SOURCE_NODES, edge_nn1,
                            aggr='mean', root_weight=True, bias=True)
        self.conv11 = BatchNorm(N_SOURCE_NODES, eps=1e-03, momentum=0.1,
                                affine=True, track_running_stats=True)

        # NOTE(review): conv2/conv22 are never called in forward() (the call
        # was commented out in the original).  They are kept so that
        # state_dicts saved from the original model still load; confirm
        # before deleting.
        edge_nn2 = Sequential(Linear(1, N_TARGET_NODES * N_SOURCE_NODES), ReLU())
        self.conv2 = NNConv(N_TARGET_NODES, N_SOURCE_NODES, edge_nn2,
                            aggr='mean', root_weight=True, bias=True)
        self.conv22 = BatchNorm(N_SOURCE_NODES, eps=1e-03, momentum=0.1,
                                affine=True, track_running_stats=True)

        edge_nn3 = Sequential(Linear(1, N_TARGET_NODES * N_SOURCE_NODES), ReLU())
        self.conv3 = NNConv(N_SOURCE_NODES, N_TARGET_NODES, edge_nn3,
                            aggr='mean', root_weight=True, bias=True)
        self.conv33 = BatchNorm(N_TARGET_NODES, eps=1e-03, momentum=0.1,
                                affine=True, track_running_stats=True)

    def forward(self, data):
        """Return the predicted (N_TARGET_NODES, N_TARGET_NODES) matrix."""
        x, edge_index, edge_attr = data.x, data.pos_edge_index, data.edge_attr
        # torch.sigmoid replaces the deprecated F.sigmoid (same function).
        x1 = torch.sigmoid(self.conv11(self.conv1(x, edge_index, edge_attr)))
        x1 = F.dropout(x1, training=self.training)
        x3 = torch.sigmoid(self.conv33(self.conv3(x1, edge_index, edge_attr)))
        x3 = F.dropout(x3, training=self.training)
        # Gram matrix of the node embeddings -> symmetric connectivity map.
        return torch.matmul(x3.t(), x3)
class Discriminator(torch.nn.Module):
    """Graph discriminator: scores nodes of a target-resolution graph.

    Two GCN layers map N_TARGET_NODES input channels down to a single
    sigmoid-activated score per node (values in (0, 1)).

    Input: a torch_geometric ``Data`` object with ``x`` and
    ``pos_edge_index``; ``edge_attr`` is read but unused by GCNConv here.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        # NOTE(review): cached=True makes GCNConv reuse the normalized
        # edge_index from the *first* forward pass — only correct if every
        # input graph shares the same topology.  Confirm against the
        # training pipeline; otherwise set cached=False.
        self.conv1 = GCNConv(N_TARGET_NODES, N_TARGET_NODES, cached=True)
        self.conv2 = GCNConv(N_TARGET_NODES, 1, cached=True)

    def forward(self, data):
        """Return a (num_nodes, 1) tensor of per-node realness scores."""
        x, edge_index, edge_attr = data.x, data.pos_edge_index, data.edge_attr
        # Drop any singleton leading dimension so x is (num_nodes, channels).
        x = torch.squeeze(x)
        # torch.sigmoid replaces the deprecated F.sigmoid (same function).
        x1 = torch.sigmoid(self.conv1(x, edge_index))
        x1 = F.dropout(x1, training=self.training)
        x2 = torch.sigmoid(self.conv2(x1, edge_index))
        return x2