From de4f8f72bac3af244e3eb2ccb0ed67fa9a3e5f62 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 May 2022 17:56:55 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/graphmask_explainer.py | 71 ++++++++++++++++----------------- 1 file changed, 34 insertions(+), 37 deletions(-) diff --git a/examples/graphmask_explainer.py b/examples/graphmask_explainer.py index 3eaab454760d..7c5679683d0f 100644 --- a/examples/graphmask_explainer.py +++ b/examples/graphmask_explainer.py @@ -1,13 +1,16 @@ -import torch -import torch_geometric.transforms as T -from torch_geometric.datasets import Entities, TUDataset, Planetoid -from torch_geometric.loader import DataLoader -from torch_geometric.nn import global_add_pool, global_mean_pool -from torch_geometric.nn import GraphMaskExplainer import matplotlib.pyplot as plt +import torch import torch.nn.functional as F -from torch_geometric.nn.conv import GCNConv, GATConv, FastRGCNConv +import torch_geometric.transforms as T +from torch_geometric.datasets import Entities, Planetoid, TUDataset +from torch_geometric.loader import DataLoader +from torch_geometric.nn import ( + GraphMaskExplainer, + global_add_pool, + global_mean_pool, +) +from torch_geometric.nn.conv import FastRGCNConv, GATConv, GCNConv if __name__ == "__main__": # GCN Node Classification================================================= @@ -44,22 +47,20 @@ def forward(self, x, edge_index, edge_weight): loss = F.nll_loss(log_logits[data.train_mask], data.y[data.train_mask]) loss.backward() optimizer.step() - print('Epoch {:03d}'.format(epoch), - 'train_loss: {:.4f}'.format(loss)) + print('Epoch {:03d}'.format(epoch), 'train_loss: {:.4f}'.format(loss)) print("Optimization Finished!") # [1, 6, 3, 5, 12] explainer = GraphMaskExplainer(2, model_to_explain=model, epochs=4, allow_multiple_explanations=True) - edge_y = torch.randint(low=0, high=30, size=(data.edge_index.size(1),)) + edge_y = torch.randint(low=0, high=30, size=(data.edge_index.size(1), )) feat_mask = explainer.train_node_explainer([6, 12], data.x, data.edge_index, edge_weight=data.edge_weight) edge_mask = explainer.explain_node([6, 12], data.x, data.edge_index) - ax, G = explainer.visualize_subgraph([6, 12], - data.edge_index, edge_mask, - y=data.y, - edge_y=edge_y, node_alpha=feat_mask) + ax, G = explainer.visualize_subgraph([6, 12], data.edge_index, edge_mask, + y=data.y, edge_y=edge_y, + node_alpha=feat_mask) plt.show() # GAT Node Classification================================================= @@ -98,21 +99,20 @@ def forward(self, x, edge_index): loss = F.nll_loss(log_logits[data.train_mask], data.y[data.train_mask]) loss.backward() optimizer.step() - print('Epoch {:03d}'.format(epoch), - 'train_loss: {:.4f}'.format(loss)) + print('Epoch {:03d}'.format(epoch), 'train_loss: {:.4f}'.format(loss)) print("Optimization Finished!") explainer = GraphMaskExplainer(2, model_to_explain=model, epochs=4, allow_multiple_explanations=True, layer_type='GAT') - edge_y = torch.randint(low=0, high=30, size=(data.edge_index.size(1),)) - feat_mask = explainer.train_node_explainer([1, 6, 3, 5, 12], - data.x, data.edge_index) + edge_y = torch.randint(low=0, high=30, size=(data.edge_index.size(1), )) + feat_mask = explainer.train_node_explainer([1, 6, 3, 5, 12], data.x, + data.edge_index) edge_mask = explainer.explain_node([1, 6, 3, 5, 12], data.x, data.edge_index) - ax, G = explainer.visualize_subgraph([1, 6, 3, 5, 12], - data.edge_index, edge_mask, y=data.y, - edge_y=edge_y, node_alpha=feat_mask) + ax, G = explainer.visualize_subgraph([1, 6, 3, 5, 12], data.edge_index, + edge_mask, y=data.y, edge_y=edge_y, + node_alpha=feat_mask) plt.show() # R-GCN Node Classification=============================================== @@ -122,7 +122,7 @@ def forward(self, x, edge_index): dataset[0].train_y, dataset[0].test_y, \ dataset[0].train_idx, dataset[0].test_idx, dataset[0].num_nodes x = torch.randn(8285, 4) - edge_y = torch.randint(low=0, high=30, size=(edge_index.size(1),)) + edge_y = torch.randint(low=0, high=30, size=(edge_index.size(1), )) class Net(torch.nn.Module): def __init__(self): @@ -153,8 +153,7 @@ def forward(self, x, edge_index, edge_type): loss = F.nll_loss(log_logits[train_mask], train_labels) loss.backward() optimizer.step() - print('Epoch {:03d}'.format(epoch), - 'train_loss: {:.4f}'.format(loss)) + print('Epoch {:03d}'.format(epoch), 'train_loss: {:.4f}'.format(loss)) print("Optimization Finished!") explainer = GraphMaskExplainer(3, epochs=4, model_to_explain=model, @@ -162,19 +161,17 @@ def forward(self, x, edge_index, edge_type): layer_type='FastRGCN') feat_mask = explainer.train_node_explainer([1, 5, 12], x, edge_index, edge_type=edge_type) - edge_mask = explainer.explain_node([1, 5, 12], x, - edge_index) - ax, G = explainer.visualize_subgraph([1, 5, 12], - edge_index, - edge_mask, y=None, - edge_y=edge_y, node_alpha=feat_mask) + edge_mask = explainer.explain_node([1, 5, 12], x, edge_index) + ax, G = explainer.visualize_subgraph([1, 5, 12], edge_index, edge_mask, + y=None, edge_y=edge_y, + node_alpha=feat_mask) plt.show() # GCN Graph Classification================================================ dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES') loader = DataLoader(dataset, batch_size=100, shuffle=True) edge_y = torch.randint(low=0, high=30, - size=(dataset[4].edge_index.size(1),)) + size=(dataset[4].edge_index.size(1), )) class Net(torch.nn.Module): def __init__(self): @@ -231,7 +228,7 @@ def forward(self, x, edge_index, batch, edge_weight): dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES') loader = DataLoader(dataset, batch_size=100, shuffle=True) edge_y = torch.randint(low=0, high=30, - size=(dataset[4].edge_index.size(1),)) + size=(dataset[4].edge_index.size(1), )) class Net(torch.nn.Module): def __init__(self): @@ -283,9 +280,9 @@ def forward(self, x, edge_index, batch): dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES') loader = DataLoader(dataset, batch_size=100, shuffle=True) edge_type = torch.randint(low=0, high=90, - size=(dataset[4].edge_index.size(1),)) + size=(dataset[4].edge_index.size(1), )) edge_y = torch.randint(low=0, high=30, - size=(dataset[4].edge_index.size(1),)) + size=(dataset[4].edge_index.size(1), )) class Net(torch.nn.Module): def __init__(self): @@ -315,8 +312,8 @@ def forward(self, x, edge_index, batch, edge_type): for data in loader: model.train() optimizer.zero_grad() - data.edge_type = torch.randint( - low=0, high=90, size=(data.edge_index.size(1),)) + data.edge_type = torch.randint(low=0, high=90, + size=(data.edge_index.size(1), )) output = model(data.x, data.edge_index, data.batch, data.edge_type) loss = F.nll_loss(output, data.y) loss.backward()