Fix norm in examples and dmon_pool (pyg-team#4959)
* Fix norm in examples and dmon_pool

* changelog

* update

Co-authored-by: Guohao Li <lighaime@gmail.com>
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
3 people authored Jul 12, 2022
1 parent bac7021 commit a8601aa
Showing 14 changed files with 27 additions and 48 deletions.
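All of the example changes below follow one pattern: `torch_geometric.nn.MLP` is switched from the older boolean `batch_norm` argument to the `norm` argument introduced by the `normalization_resolver` work. A minimal sketch of the mapping, based only on the keywords that appear in the diffs (`norm` takes a resolver name such as `"batch_norm"` or `None`; `plain_last` controls whether the final layer skips normalization, activation, and dropout):

```python
from torch_geometric.nn import MLP

# Old style (boolean flag):  MLP([16, 32, 8], batch_norm=True / False)
# New style (resolved by name, or disabled with None):
mlp_bn = MLP([16, 32, 8], norm="batch_norm")  # batch norm between layers
mlp_plain = MLP([16, 32, 8], norm=None)       # no normalization at all

# plain_last=False also applies norm/activation/dropout to the last layer,
# matching the hand-rolled Sequential MLPs replaced in the point_transformer
# examples below.
mlp_all = MLP([16, 32, 8], norm=None, plain_last=False)
```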
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -7,7 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Added
 - Added `unbatch_edge_index` functionality for splitting an `edge_index` tensor according to a `batch` vector ([#4903](https://github.com/pyg-team/pytorch_geometric/pull/4903))
 - Added node-wise normalization mode in `LayerNorm` ([#4944](https://github.com/pyg-team/pytorch_geometric/pull/4944))
-- Added support for `normalization_resolver` ([#4926](https://github.com/pyg-team/pytorch_geometric/pull/4926), [#4951](https://github.com/pyg-team/pytorch_geometric/pull/4951), [#4958](https://github.com/pyg-team/pytorch_geometric/pull/4958))
+- Added support for `normalization_resolver` ([#4926](https://github.com/pyg-team/pytorch_geometric/pull/4926), [#4951](https://github.com/pyg-team/pytorch_geometric/pull/4951), [#4958](https://github.com/pyg-team/pytorch_geometric/pull/4958), [#4959](https://github.com/pyg-team/pytorch_geometric/pull/4959))
 - Added notebook tutorial for `torch_geometric.nn.aggr` package to documentation ([#4927](https://github.com/pyg-team/pytorch_geometric/pull/4927))
 - Added support for `follow_batch` for lists or dictionaries of tensors ([#4837](https://github.com/pyg-team/pytorch_geometric/pull/4837))
 - Added `Data.validate()` and `HeteroData.validate()` functionality ([#4885](https://github.com/pyg-team/pytorch_geometric/pull/4885))
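The changelog folds this PR into the existing `normalization_resolver` entry. As a rough illustration of what that resolver does (the import path and call signature are assumptions based on the PyG 2.1-era `torch_geometric.nn.resolver` module, not part of this diff):

```python
from torch_geometric.nn.resolver import normalization_resolver

# Resolve a normalization layer from its name; extra arguments are assumed
# to be forwarded to the resolved layer's constructor.
norm = normalization_resolver('batch_norm', 64)  # a batch-norm module over 64 channels
```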
2 changes: 1 addition & 1 deletion examples/correct_and_smooth.py
@@ -15,7 +15,7 @@
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 model = MLP([dataset.num_features, 200, 200, dataset.num_classes], dropout=0.5,
-            batch_norm=True, act_first=True).to(device)
+            norm="batch_norm", act_first=True).to(device)
 optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
 criterion = torch.nn.CrossEntropyLoss()
 
3 changes: 1 addition & 2 deletions examples/dgcnn_classification.py
@@ -27,8 +27,7 @@ def __init__(self, out_channels, k=20, aggr='max'):
         self.conv2 = DynamicEdgeConv(MLP([2 * 64, 128]), k, aggr)
         self.lin1 = Linear(128 + 64, 1024)
 
-        self.mlp = MLP([1024, 512, 256, out_channels], dropout=0.5,
-                       batch_norm=False)
+        self.mlp = MLP([1024, 512, 256, out_channels], dropout=0.5, norm=None)
 
     def forward(self, data):
         pos, batch = data.pos, data.batch
2 changes: 1 addition & 1 deletion examples/dgcnn_segmentation.py
@@ -38,7 +38,7 @@ def __init__(self, out_channels, k=30, aggr='max'):
         self.conv3 = DynamicEdgeConv(MLP([2 * 64, 64, 64]), k, aggr)
 
         self.mlp = MLP([3 * 64, 1024, 256, 128, out_channels], dropout=0.5,
-                       batch_norm=False)
+                       norm=None)
 
     def forward(self, data):
         x, pos, batch = data.x, data.pos, data.batch
2 changes: 1 addition & 1 deletion examples/glnn.py
@@ -25,7 +25,7 @@
 gnn = GCN(dataset.num_node_features, hidden_channels=16,
           out_channels=dataset.num_classes, num_layers=2).to(device)
 mlp = MLP([dataset.num_node_features, 64, dataset.num_classes], dropout=0.5,
-          batch_norm=False).to(device)
+          norm=None).to(device)
 
 gnn_optimizer = torch.optim.Adam(gnn.parameters(), lr=0.01, weight_decay=5e-4)
 mlp_optimizer = torch.optim.Adam(mlp.parameters(), lr=0.01, weight_decay=5e-4)
2 changes: 1 addition & 1 deletion examples/mutag_gin.py
@@ -45,7 +45,7 @@ def __init__(self, in_channels, hidden_channels, out_channels, num_layers):
             in_channels = hidden_channels
 
         self.mlp = MLP([hidden_channels, hidden_channels, out_channels],
-                       batch_norm=False, dropout=0.5)
+                       norm=None, dropout=0.5)
 
     def forward(self, x, edge_index, batch):
         for conv in self.convs:
26 changes: 7 additions & 19 deletions examples/point_transformer_classification.py
@@ -2,18 +2,14 @@
 
 import torch
 import torch.nn.functional as F
-from torch.nn import BatchNorm1d as BN
-from torch.nn import Identity
 from torch.nn import Linear as Lin
-from torch.nn import ReLU
-from torch.nn import Sequential as Seq
 from torch_cluster import fps, knn_graph
 from torch_scatter import scatter_max
 
 import torch_geometric.transforms as T
 from torch_geometric.datasets import ModelNet
 from torch_geometric.loader import DataLoader
-from torch_geometric.nn import global_mean_pool
+from torch_geometric.nn import MLP, global_mean_pool
 from torch_geometric.nn.conv import PointTransformerConv
 from torch_geometric.nn.pool import knn
 
@@ -31,9 +31,10 @@ def __init__(self, in_channels, out_channels):
         self.lin_in = Lin(in_channels, in_channels)
         self.lin_out = Lin(out_channels, out_channels)
 
-        self.pos_nn = MLP([3, 64, out_channels], batch_norm=False)
+        self.pos_nn = MLP([3, 64, out_channels], norm=None, plain_last=False)
 
-        self.attn_nn = MLP([out_channels, 64, out_channels], batch_norm=False)
+        self.attn_nn = MLP([out_channels, 64, out_channels], norm=None,
+                           plain_last=False)
 
         self.transformer = PointTransformerConv(in_channels, out_channels,
                                                 pos_nn=self.pos_nn,
@@ -55,7 +52,7 @@ def __init__(self, in_channels, out_channels, ratio=0.25, k=16):
         super().__init__()
         self.k = k
         self.ratio = ratio
-        self.mlp = MLP([in_channels, out_channels])
+        self.mlp = MLP([in_channels, out_channels], plain_last=False)
 
     def forward(self, x, pos, batch):
         # FPS sampling
@@ -80,14 +77,6 @@ def forward(self, x, pos, batch):
         return out, sub_pos, sub_batch
 
 
-def MLP(channels, batch_norm=True):
-    return Seq(*[
-        Seq(Lin(channels[i - 1], channels[i]),
-            BN(channels[i]) if batch_norm else Identity(), ReLU())
-        for i in range(1, len(channels))
-    ])
-
-
 class Net(torch.nn.Module):
     def __init__(self, in_channels, out_channels, dim_model, k=16):
         super().__init__()
@@ -97,7 +86,7 @@ def __init__(self, in_channels, out_channels, dim_model, k=16):
         in_channels = max(in_channels, 1)
 
         # first block
-        self.mlp_input = MLP([in_channels, dim_model[0]])
+        self.mlp_input = MLP([in_channels, dim_model[0]], plain_last=False)
 
         self.transformer_input = TransformerBlock(in_channels=dim_model[0],
                                                   out_channels=dim_model[0])
@@ -116,8 +105,7 @@ def __init__(self, in_channels, out_channels, dim_model, k=16):
                                out_channels=dim_model[i + 1]))
 
         # class score computation
-        self.mlp_output = Seq(Lin(dim_model[-1], 64), ReLU(), Lin(64, 64),
-                              ReLU(), Lin(64, out_channels))
+        self.mlp_output = MLP([dim_model[-1], 64, out_channels], norm=None)
 
     def forward(self, x, pos, batch=None):
 
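For context on the `plain_last=False` arguments above: the deleted local `MLP` helper applied Linear, norm, ReLU to every block, including the last one, whereas the library `MLP` leaves its final layer plain by default. A rough equivalence sketch with hypothetical channel sizes (assuming the library defaults of ReLU activation and batch norm):

```python
from torch.nn import BatchNorm1d, Linear, ReLU, Sequential
from torch_geometric.nn import MLP

# What the removed helper built for channels = [3, 64, 32]:
hand_rolled = Sequential(
    Sequential(Linear(3, 64), BatchNorm1d(64), ReLU()),
    Sequential(Linear(64, 32), BatchNorm1d(32), ReLU()),
)

# Roughly the same stack via the library MLP: plain_last=False keeps
# norm + activation on the final layer too, while norm=None (as used for
# pos_nn/attn_nn above) drops the normalization instead.
library_mlp = MLP([3, 64, 32], plain_last=False)
```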
23 changes: 8 additions & 15 deletions examples/point_transformer_segmentation.py
@@ -2,22 +2,15 @@
 
 import torch
 import torch.nn.functional as F
-from point_transformer_classification import (
-    MLP,
-    TransformerBlock,
-    TransitionDown,
-)
-from torch.nn import Linear as Lin
-from torch.nn import ReLU
-from torch.nn import Sequential as Seq
+from point_transformer_classification import TransformerBlock, TransitionDown
 from torch_cluster import knn_graph
 from torch_scatter import scatter
 from torchmetrics.functional import jaccard_index
 
 import torch_geometric.transforms as T
 from torch_geometric.datasets import ShapeNet
 from torch_geometric.loader import DataLoader
-from torch_geometric.nn.unpool import knn_interpolate
+from torch_geometric.nn import MLP, knn_interpolate
 
 category = 'Airplane'  # Pass in `None` to train on all categories.
 path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'ShapeNet')
@@ -43,8 +36,8 @@ class TransitionUp(torch.nn.Module):
     '''
     def __init__(self, in_channels, out_channels):
         super().__init__()
-        self.mlp_sub = MLP([in_channels, out_channels])
-        self.mlp = MLP([out_channels, out_channels])
+        self.mlp_sub = MLP([in_channels, out_channels], plain_last=False)
+        self.mlp = MLP([out_channels, out_channels], plain_last=False)
 
     def forward(self, x, x_sub, pos, pos_sub, batch=None, batch_sub=None):
         # transform low-res features and reduce the number of features
@@ -68,7 +61,7 @@ def __init__(self, in_channels, out_channels, dim_model, k=16):
         in_channels = max(in_channels, 1)
 
         # first block
-        self.mlp_input = MLP([in_channels, dim_model[0]])
+        self.mlp_input = MLP([in_channels, dim_model[0]], plain_last=False)
 
         self.transformer_input = TransformerBlock(
             in_channels=dim_model[0],
@@ -102,16 +95,16 @@ def __init__(self, in_channels, out_channels, dim_model, k=16):
                     out_channels=dim_model[i]))
 
         # summit layers
-        self.mlp_summit = MLP([dim_model[-1], dim_model[-1]], batch_norm=False)
+        self.mlp_summit = MLP([dim_model[-1], dim_model[-1]], norm=None,
+                              plain_last=False)
 
         self.transformer_summit = TransformerBlock(
             in_channels=dim_model[-1],
             out_channels=dim_model[-1],
         )
 
         # class score computation
-        self.mlp_output = Seq(Lin(dim_model[0], 64), ReLU(), Lin(64, 64),
-                              ReLU(), Lin(64, out_channels))
+        self.mlp_output = MLP([dim_model[0], 64, out_channels], norm=None)
 
     def forward(self, x, pos, batch=None):
 
2 changes: 1 addition & 1 deletion examples/pointnet2_classification.py
@@ -49,7 +49,7 @@ def __init__(self):
         self.sa2_module = SAModule(0.25, 0.4, MLP([128 + 3, 128, 128, 256]))
         self.sa3_module = GlobalSAModule(MLP([256 + 3, 256, 512, 1024]))
 
-        self.mlp = MLP([1024, 512, 256, 10], dropout=0.5, batch_norm=False)
+        self.mlp = MLP([1024, 512, 256, 10], dropout=0.5, norm=None)
 
     def forward(self, data):
         sa0_out = (data.x, data.pos, data.batch)
3 changes: 1 addition & 2 deletions examples/pointnet2_segmentation.py
@@ -57,8 +57,7 @@ def __init__(self, num_classes):
         self.fp2_module = FPModule(3, MLP([256 + 128, 256, 128]))
         self.fp1_module = FPModule(3, MLP([128 + 3, 128, 128, 128]))
 
-        self.mlp = MLP([128, 128, 128, num_classes], dropout=0.5,
-                       batch_norm=False)
+        self.mlp = MLP([128, 128, 128, num_classes], dropout=0.5, norm=None)
 
         self.lin1 = torch.nn.Linear(128, 128)
         self.lin2 = torch.nn.Linear(128, 128)
2 changes: 1 addition & 1 deletion examples/pytorch_ignite/gin.py
@@ -23,7 +23,7 @@ def __init__(self, in_channels: int, out_channels: int,
                        dropout=dropout, jk='cat')
 
         self.classifier = MLP([hidden_channels, hidden_channels, out_channels],
-                              batch_norm=True, dropout=dropout)
+                              norm="batch_norm", dropout=dropout)
 
     def forward(self, data):
         x = self.gnn(data.x, data.edge_index)
2 changes: 1 addition & 1 deletion examples/pytorch_lightning/gin.py
@@ -21,7 +21,7 @@ def __init__(self, in_channels: int, out_channels: int,
                        dropout=dropout, jk='cat')
 
         self.classifier = MLP([hidden_channels, hidden_channels, out_channels],
-                              batch_norm=True, dropout=dropout)
+                              norm="batch_norm", dropout=dropout)
 
         self.train_acc = Accuracy()
         self.val_acc = Accuracy()
2 changes: 1 addition & 1 deletion examples/seal_link_pred.py
@@ -161,7 +161,7 @@ def __init__(self, hidden_channels, num_layers, GNN=GCNConv, k=0.6):
                             conv1d_kws[1], 1)
         dense_dim = int((self.k - 2) / 2 + 1)
         dense_dim = (dense_dim - conv1d_kws[1] + 1) * conv1d_channels[1]
-        self.mlp = MLP([dense_dim, 128, 1], dropout=0.5, batch_norm=False)
+        self.mlp = MLP([dense_dim, 128, 1], dropout=0.5, norm=None)
 
     def forward(self, x, edge_index, batch):
         xs = [x]
2 changes: 1 addition & 1 deletion torch_geometric/nn/dense/dmon_pool.py
@@ -65,7 +65,7 @@ def __init__(self, channels: Union[int, List[int]], k: int,
             channels = [channels]
 
         from torch_geometric.nn.models.mlp import MLP
-        self.mlp = MLP(channels + [k], act='selu', batch_norm=False)
+        self.mlp = MLP(channels + [k], act='selu', norm=None)
 
         self.dropout = dropout
 
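The `DMoNPooling` change is intended to be behaviour-preserving: its internal selection MLP keeps the SELU activation and stays un-normalized, only the keyword changes. A small usage sketch with made-up shapes (the return signature is an assumption based on the DMoN formulation: soft assignments, pooled features and adjacency, plus three auxiliary losses):

```python
import torch
from torch_geometric.nn import DMoNPooling

x = torch.randn(2, 20, 32)   # (batch_size, num_nodes, in_channels)
adj = torch.rand(2, 20, 20)  # dense adjacency matrices

pool = DMoNPooling(channels=[32, 32], k=4, dropout=0.2)
s, out, out_adj, spectral_loss, ortho_loss, cluster_loss = pool(x, adj)

# The auxiliary losses are typically added to the downstream task loss.
aux_loss = spectral_loss + ortho_loss + cluster_loss
```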
