Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed May 10, 2022
1 parent 3ff0f1a commit de4f8f7
Showing 1 changed file with 34 additions and 37 deletions.
71 changes: 34 additions & 37 deletions examples/graphmask_explainer.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import torch
import torch_geometric.transforms as T
from torch_geometric.datasets import Entities, TUDataset, Planetoid
from torch_geometric.loader import DataLoader
from torch_geometric.nn import global_add_pool, global_mean_pool
from torch_geometric.nn import GraphMaskExplainer
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch_geometric.nn.conv import GCNConv, GATConv, FastRGCNConv

import torch_geometric.transforms as T
from torch_geometric.datasets import Entities, Planetoid, TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import (
GraphMaskExplainer,
global_add_pool,
global_mean_pool,
)
from torch_geometric.nn.conv import FastRGCNConv, GATConv, GCNConv

if __name__ == "__main__":
# GCN Node Classification=================================================
Expand Down Expand Up @@ -44,22 +47,20 @@ def forward(self, x, edge_index, edge_weight):
loss = F.nll_loss(log_logits[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
print('Epoch {:03d}'.format(epoch),
'train_loss: {:.4f}'.format(loss))
print('Epoch {:03d}'.format(epoch), 'train_loss: {:.4f}'.format(loss))
print("Optimization Finished!")

# [1, 6, 3, 5, 12]
explainer = GraphMaskExplainer(2, model_to_explain=model, epochs=4,
allow_multiple_explanations=True)
edge_y = torch.randint(low=0, high=30, size=(data.edge_index.size(1),))
edge_y = torch.randint(low=0, high=30, size=(data.edge_index.size(1), ))
feat_mask = explainer.train_node_explainer([6, 12], data.x,
data.edge_index,
edge_weight=data.edge_weight)
edge_mask = explainer.explain_node([6, 12], data.x, data.edge_index)
ax, G = explainer.visualize_subgraph([6, 12],
data.edge_index, edge_mask,
y=data.y,
edge_y=edge_y, node_alpha=feat_mask)
ax, G = explainer.visualize_subgraph([6, 12], data.edge_index, edge_mask,
y=data.y, edge_y=edge_y,
node_alpha=feat_mask)
plt.show()

# GAT Node Classification=================================================
Expand Down Expand Up @@ -98,21 +99,20 @@ def forward(self, x, edge_index):
loss = F.nll_loss(log_logits[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
print('Epoch {:03d}'.format(epoch),
'train_loss: {:.4f}'.format(loss))
print('Epoch {:03d}'.format(epoch), 'train_loss: {:.4f}'.format(loss))
print("Optimization Finished!")

explainer = GraphMaskExplainer(2, model_to_explain=model, epochs=4,
allow_multiple_explanations=True,
layer_type='GAT')
edge_y = torch.randint(low=0, high=30, size=(data.edge_index.size(1),))
feat_mask = explainer.train_node_explainer([1, 6, 3, 5, 12],
data.x, data.edge_index)
edge_y = torch.randint(low=0, high=30, size=(data.edge_index.size(1), ))
feat_mask = explainer.train_node_explainer([1, 6, 3, 5, 12], data.x,
data.edge_index)
edge_mask = explainer.explain_node([1, 6, 3, 5, 12], data.x,
data.edge_index)
ax, G = explainer.visualize_subgraph([1, 6, 3, 5, 12],
data.edge_index, edge_mask, y=data.y,
edge_y=edge_y, node_alpha=feat_mask)
ax, G = explainer.visualize_subgraph([1, 6, 3, 5, 12], data.edge_index,
edge_mask, y=data.y, edge_y=edge_y,
node_alpha=feat_mask)
plt.show()

# R-GCN Node Classification===============================================
Expand All @@ -122,7 +122,7 @@ def forward(self, x, edge_index):
dataset[0].train_y, dataset[0].test_y, \
dataset[0].train_idx, dataset[0].test_idx, dataset[0].num_nodes
x = torch.randn(8285, 4)
edge_y = torch.randint(low=0, high=30, size=(edge_index.size(1),))
edge_y = torch.randint(low=0, high=30, size=(edge_index.size(1), ))

class Net(torch.nn.Module):
def __init__(self):
Expand Down Expand Up @@ -153,28 +153,25 @@ def forward(self, x, edge_index, edge_type):
loss = F.nll_loss(log_logits[train_mask], train_labels)
loss.backward()
optimizer.step()
print('Epoch {:03d}'.format(epoch),
'train_loss: {:.4f}'.format(loss))
print('Epoch {:03d}'.format(epoch), 'train_loss: {:.4f}'.format(loss))
print("Optimization Finished!")

explainer = GraphMaskExplainer(3, epochs=4, model_to_explain=model,
allow_multiple_explanations=True,
layer_type='FastRGCN')
feat_mask = explainer.train_node_explainer([1, 5, 12], x, edge_index,
edge_type=edge_type)
edge_mask = explainer.explain_node([1, 5, 12], x,
edge_index)
ax, G = explainer.visualize_subgraph([1, 5, 12],
edge_index,
edge_mask, y=None,
edge_y=edge_y, node_alpha=feat_mask)
edge_mask = explainer.explain_node([1, 5, 12], x, edge_index)
ax, G = explainer.visualize_subgraph([1, 5, 12], edge_index, edge_mask,
y=None, edge_y=edge_y,
node_alpha=feat_mask)
plt.show()

# GCN Graph Classification================================================
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')
loader = DataLoader(dataset, batch_size=100, shuffle=True)
edge_y = torch.randint(low=0, high=30,
size=(dataset[4].edge_index.size(1),))
size=(dataset[4].edge_index.size(1), ))

class Net(torch.nn.Module):
def __init__(self):
Expand Down Expand Up @@ -231,7 +228,7 @@ def forward(self, x, edge_index, batch, edge_weight):
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')
loader = DataLoader(dataset, batch_size=100, shuffle=True)
edge_y = torch.randint(low=0, high=30,
size=(dataset[4].edge_index.size(1),))
size=(dataset[4].edge_index.size(1), ))

class Net(torch.nn.Module):
def __init__(self):
Expand Down Expand Up @@ -283,9 +280,9 @@ def forward(self, x, edge_index, batch):
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')
loader = DataLoader(dataset, batch_size=100, shuffle=True)
edge_type = torch.randint(low=0, high=90,
size=(dataset[4].edge_index.size(1),))
size=(dataset[4].edge_index.size(1), ))
edge_y = torch.randint(low=0, high=30,
size=(dataset[4].edge_index.size(1),))
size=(dataset[4].edge_index.size(1), ))

class Net(torch.nn.Module):
def __init__(self):
Expand Down Expand Up @@ -315,8 +312,8 @@ def forward(self, x, edge_index, batch, edge_type):
for data in loader:
model.train()
optimizer.zero_grad()
data.edge_type = torch.randint(
low=0, high=90, size=(data.edge_index.size(1),))
data.edge_type = torch.randint(low=0, high=90,
size=(data.edge_index.size(1), ))
output = model(data.x, data.edge_index, data.batch, data.edge_type)
loss = F.nll_loss(output, data.y)
loss.backward()
Expand Down

0 comments on commit de4f8f7

Please sign in to comment.