submit GCNII #43

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 1 commit into base: main
examples/GCNII/cora/model.py (104 additions, 0 deletions)
@@ -0,0 +1,104 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter

class GraphConvolution(nn.Module):
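    """A single GCNII layer: an initial residual connection to the first
    hidden representation h0 plus an identity mapping on the weights."""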

    def __init__(self, in_features, out_features, residual=False, variant=False):
        super(GraphConvolution, self).__init__()
        self.variant = variant
        if self.variant:
            # The variant concatenates [hi, h0], so the input width doubles.
            self.in_features = 2 * in_features
        else:
            self.in_features = in_features

        self.out_features = out_features
        self.residual = residual
        self.weight = Parameter(torch.FloatTensor(self.in_features, self.out_features))
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.out_features)
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, input, adj, h0, lamda, alpha, l):
        # theta = log(lamda / l + 1): the identity-mapping strength, which
        # decays with the layer index l.
        theta = math.log(lamda / l + 1)
        hi = torch.spmm(adj, input)
        if self.variant:
            support = torch.cat([hi, h0], 1)
            r = (1 - alpha) * hi + alpha * h0
        else:
            # Initial residual: mix the propagated features with h0.
            support = (1 - alpha) * hi + alpha * h0
            r = support
        # Identity mapping: interpolate between the transformed and the
        # untransformed representation.
        output = theta * torch.mm(support, self.weight) + (1 - theta) * r
        if self.residual:
            output = output + input
        return output

class GCNII(nn.Module):
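    """GCNII for full-batch node classification; returns log-softmax scores."""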
    def __init__(self, nfeat, nlayers, nhidden, nclass, dropout, lamda, alpha, variant):
        super(GCNII, self).__init__()
        self.convs = nn.ModuleList()
        for _ in range(nlayers):
            self.convs.append(GraphConvolution(nhidden, nhidden, variant=variant))
        self.fcs = nn.ModuleList()
        self.fcs.append(nn.Linear(nfeat, nhidden))
        self.fcs.append(nn.Linear(nhidden, nclass))
        # Separate parameter groups so the optimizer can apply different
        # weight decay to the conv weights and the dense layers.
        self.params1 = list(self.convs.parameters())
        self.params2 = list(self.fcs.parameters())
        self.act_fn = nn.ReLU()
        self.dropout = dropout
        self.alpha = alpha
        self.lamda = lamda

    def forward(self, x, adj):
        _layers = []
        x = F.dropout(x, self.dropout, training=self.training)
        layer_inner = self.act_fn(self.fcs[0](x))
        # Keep the first hidden representation as h0 for every layer.
        _layers.append(layer_inner)
        for i, con in enumerate(self.convs):
            layer_inner = F.dropout(layer_inner, self.dropout, training=self.training)
            layer_inner = self.act_fn(con(layer_inner, adj, _layers[0], self.lamda, self.alpha, i + 1))
        layer_inner = F.dropout(layer_inner, self.dropout, training=self.training)
        layer_inner = self.fcs[-1](layer_inner)
        return F.log_softmax(layer_inner, dim=1)

class GCNIIppi(nn.Module):
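    """GCNII variant for multi-label graphs (e.g. PPI): residual layers
    and a sigmoid output instead of log-softmax."""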
    def __init__(self, nfeat, nlayers, nhidden, nclass, dropout, lamda, alpha, variant):
        super(GCNIIppi, self).__init__()
        self.convs = nn.ModuleList()
        for _ in range(nlayers):
            self.convs.append(GraphConvolution(nhidden, nhidden, variant=variant, residual=True))
        self.fcs = nn.ModuleList()
        self.fcs.append(nn.Linear(nfeat, nhidden))
        self.fcs.append(nn.Linear(nhidden, nclass))
        self.act_fn = nn.ReLU()
        self.sig = nn.Sigmoid()
        self.dropout = dropout
        self.alpha = alpha
        self.lamda = lamda

    def forward(self, x, adj):
        _layers = []
        x = F.dropout(x, self.dropout, training=self.training)
        layer_inner = self.act_fn(self.fcs[0](x))
        _layers.append(layer_inner)
        for i, con in enumerate(self.convs):
            layer_inner = F.dropout(layer_inner, self.dropout, training=self.training)
            layer_inner = self.act_fn(con(layer_inner, adj, _layers[0], self.lamda, self.alpha, i + 1))
        layer_inner = F.dropout(layer_inner, self.dropout, training=self.training)
        layer_inner = self.sig(self.fcs[-1](layer_inner))
        return layer_inner


if __name__ == '__main__':
    pass
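
For reference, GraphConvolution.forward implements the GCNII propagation rule from the paper (theta in the code plays the role of the paper's beta_l; the nonlinearity sigma is applied by the calling model, not inside the layer):

    H^{(l+1)} = \sigma\Big( \big( (1-\alpha)\,\tilde{P} H^{(l)} + \alpha H^{(0)} \big) \big( (1-\beta_l) I_n + \beta_l W^{(l)} \big) \Big),
    \qquad \beta_l = \log\Big( \frac{\lambda}{l} + 1 \Big)

where \tilde{P} is the normalized adjacency (adj) and H^{(0)} is h0, the first hidden representation. A minimal smoke test of the model (the shapes and hyperparameters below are illustrative, not taken from the PR):

    import torch
    from model import GCNII

    x = torch.randn(5, 10)          # toy graph: 5 nodes, 10 input features
    adj = torch.eye(5).to_sparse()  # toy sparse adjacency
    model = GCNII(nfeat=10, nlayers=2, nhidden=16, nclass=3,
                  dropout=0.5, lamda=0.5, alpha=0.1, variant=False)
    out = model(x, adj)             # (5, 3) log-probabilities
    print(out.shape)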
5 binary files not shown.
examples/GCNII/cora/train.py (141 additions, 0 deletions)
@@ -0,0 +1,141 @@
# GCNII
# Paper: Simple and Deep Graph Convolutional Networks
# arXiv: https://arxiv.org/abs/2007.02133
# Accuracy: 78 (Cora)
# Runtime: 117.3718s (single 6 GB GPU)
# Usage: python train.py

from __future__ import division
from __future__ import print_function
import time
import random
import argparse
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from utils import *
from model import *
import sys
sys.path.append("../../../rllm/dataloader")
from load_data import load_data
import uuid

# Training settings
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=42, help='Random seed.')
parser.add_argument('--epochs', type=int, default=1500, help='Number of epochs to train.')
parser.add_argument('--lr', type=float, default=0.01, help='Learning rate.')
parser.add_argument('--wd1', type=float, default=0.01, help='Weight decay (L2 penalty) for the conv parameters.')
parser.add_argument('--wd2', type=float, default=5e-4, help='Weight decay (L2 penalty) for the dense parameters.')
parser.add_argument('--layer', type=int, default=64, help='Number of layers.')
parser.add_argument('--hidden', type=int, default=64, help='Hidden dimensions.')
parser.add_argument('--dropout', type=float, default=0.6, help='Dropout rate (1 - keep probability).')
parser.add_argument('--patience', type=int, default=100, help='Patience for early stopping.')
parser.add_argument('--data', default='cora', help='Dataset.')
parser.add_argument('--dev', type=int, default=0, help='Device id.')
parser.add_argument('--alpha', type=float, default=0.05, help='alpha_l.')
parser.add_argument('--lamda', type=float, default=0.5, help='Lambda.')
parser.add_argument('--variant', action='store_true', default=False, help='Use the GCNII* variant.')
parser.add_argument('--test', action='store_true', default=False, help='Evaluate on the test set.')
args = parser.parse_args()
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)

# Load data
cudaid = "cuda:" + str(args.dev)
device = torch.device(cudaid)
data, adj, features, labels, idx_train, idx_val, idx_test = load_data('cora', device=device)
labels = labels.argmax(dim=-1)
features = features.to(device)
adj = adj.to(device)
checkpt_file = 'pretrained/' + uuid.uuid4().hex + '.pt'
print(cudaid, checkpt_file)

model = GCNII(nfeat=features.shape[1],
              nlayers=args.layer,
              nhidden=args.hidden,
              nclass=int(labels.max()) + 1,
              dropout=args.dropout,
              lamda=args.lamda,
              alpha=args.alpha,
              variant=args.variant).to(device)

# The paper applies different weight-decay strengths to the convolution
# weights (wd1) and to the dense input/output layers (wd2).
optimizer = optim.Adam([
    {'params': model.params1, 'weight_decay': args.wd1},
    {'params': model.params2, 'weight_decay': args.wd2},
], lr=args.lr)

def train():
    model.train()
    optimizer.zero_grad()
    output = model(features, adj)
    acc_train = accuracy(output[idx_train], labels[idx_train].to(device))
    loss_train = F.nll_loss(output[idx_train], labels[idx_train].to(device))
    loss_train.backward()
    optimizer.step()
    return loss_train.item(), acc_train.item()


def validate():
    model.eval()
    with torch.no_grad():
        output = model(features, adj)
        loss_val = F.nll_loss(output[idx_val], labels[idx_val].to(device))
        acc_val = accuracy(output[idx_val], labels[idx_val].to(device))
        return loss_val.item(), acc_val.item()

def test():
    model.load_state_dict(torch.load(checkpt_file))
    model.eval()
    with torch.no_grad():
        output = model(features, adj)
        loss_test = F.nll_loss(output[idx_test], labels[idx_test].to(device))
        acc_test = accuracy(output[idx_test], labels[idx_test].to(device))
        return loss_test.item(), acc_test.item()

# Early stopping: keep the checkpoint with the best validation loss and
# stop after args.patience epochs without improvement.
t_total = time.time()
bad_counter = 0
best = float('inf')
best_epoch = 0
acc = 0
for epoch in range(args.epochs):
    loss_tra, acc_tra = train()
    loss_val, acc_val = validate()
    print('Epoch:{:04d}'.format(epoch + 1),
          'train',
          'loss:{:.3f}'.format(loss_tra),
          'acc:{:.2f}'.format(acc_tra * 100),
          '| val',
          'loss:{:.3f}'.format(loss_val),
          'acc:{:.2f}'.format(acc_val * 100))
    if loss_val < best:
        best = loss_val
        best_epoch = epoch
        acc = acc_val
        torch.save(model.state_dict(), checkpt_file)
        bad_counter = 0
    else:
        bad_counter += 1

    if bad_counter == args.patience:
        break

if args.test:
    acc = test()[1]

print("Train cost: {:.4f}s".format(time.time() - t_total))
print('Best epoch: {}'.format(best_epoch))
print("Test" if args.test else "Val", "acc.:{:.1f}".format(acc * 100))





