Skip to content

Commit

Permalink
Merge pull request Diego999#23 from pbloem/master
Browse files Browse the repository at this point in the history
Update to pytorch 1.0
  • Loading branch information
Diego999 authored Jul 10, 2019
2 parents 3301fd6 + a96fc3c commit e6a8fa5
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
4 changes: 3 additions & 1 deletion layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ def __init__(self, in_features, out_features, dropout, alpha, concat=True):
self.special_spmm = SpecialSpmm()

def forward(self, input, adj):
dv = 'cuda' if input.is_cuda else 'cpu'

N = input.size()[0]
edge = adj.nonzero().t()

Expand All @@ -112,7 +114,7 @@ def forward(self, input, adj):
assert not torch.isnan(edge_e).any()
# edge_e: E

e_rowsum = self.special_spmm(edge, edge_e, torch.Size([N, N]), torch.ones(size=(N,1)).cuda())
e_rowsum = self.special_spmm(edge, edge_e, torch.Size([N, N]), torch.ones(size=(N,1), device=dv))
# e_rowsum: N x 1

edge_e = self.dropout(edge_e)
Expand Down
10 changes: 5 additions & 5 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,13 @@ def train(epoch):
loss_val = F.nll_loss(output[idx_val], labels[idx_val])
acc_val = accuracy(output[idx_val], labels[idx_val])
print('Epoch: {:04d}'.format(epoch+1),
'loss_train: {:.4f}'.format(loss_train.data[0]),
'acc_train: {:.4f}'.format(acc_train.data[0]),
'loss_val: {:.4f}'.format(loss_val.data[0]),
'acc_val: {:.4f}'.format(acc_val.data[0]),
'loss_train: {:.4f}'.format(loss_train.data.item()),
'acc_train: {:.4f}'.format(acc_train.data.item()),
'loss_val: {:.4f}'.format(loss_val.data.item()),
'acc_val: {:.4f}'.format(acc_val.data.item()),
'time: {:.4f}s'.format(time.time() - t))

return loss_val.data[0]
return loss_val.data.item()


def compute_test():
Expand Down

0 comments on commit e6a8fa5

Please sign in to comment.