From 05490776e576addd4727e0a4bcd18e7cc0a16f3c Mon Sep 17 00:00:00 2001 From: Novi Patricia Date: Mon, 12 Jun 2023 22:13:13 +0700 Subject: [PATCH] Add an example `PMLP` on Cora dataset (#7543) Co-authored-by: rusty1s --- CHANGELOG.md | 2 +- examples/pmlp.py | 58 +++++++++++++++++++++++++++++++ torch_geometric/nn/models/pmlp.py | 15 ++++++-- 3 files changed, 71 insertions(+), 4 deletions(-) create mode 100644 examples/pmlp.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 4c2fbf105bfa..19326015f17a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,7 +24,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added the `IGMCDataset` ([#7441](https://github.com/pyg-team/pytorch_geometric/pull/7441)) - Added a sparse `cross_entropy` implementation ([#7447](https://github.com/pyg-team/pytorch_geometric/pull/7447), [#7466](https://github.com/pyg-team/pytorch_geometric/pull/7466)) - Added the `MovieLens-100K` heterogeneous dataset ([#7398](https://github.com/pyg-team/pytorch_geometric/pull/7398)) -- Added the `PMLP` model ([#7370](https://github.com/pyg-team/pytorch_geometric/pull/7370)) +- Added the `PMLP` model and an example ([#7370](https://github.com/pyg-team/pytorch_geometric/pull/7370), [#7543](https://github.com/pyg-team/pytorch_geometric/pull/7543)) - Added padding capabilities to `HeteroData.to_homogeneous()` in case feature dimensionalities do not match ([#7374](https://github.com/pyg-team/pytorch_geometric/pull/7374)) - Added an optional `batch_size` argument to `fps`, `knn`, `knn_graph`, `radius` and `radius_graph` ([#7368](https://github.com/pyg-team/pytorch_geometric/pull/7368)) - Added `PrefetchLoader` capabilities ([#7376](https://github.com/pyg-team/pytorch_geometric/pull/7376), [#7378](https://github.com/pyg-team/pytorch_geometric/pull/7378), [#7383](https://github.com/pyg-team/pytorch_geometric/pull/7383)) diff --git a/examples/pmlp.py b/examples/pmlp.py new file mode 100644 index 000000000000..29088aa8da17 --- /dev/null 
+++ b/examples/pmlp.py
@@ -0,0 +1,58 @@
+"""PMLP on Cora: trained without edges, evaluated with `edge_index`."""
+
+import os.path as osp
+
+import torch
+import torch.nn.functional as F
+
+import torch_geometric.transforms as T
+from torch_geometric.datasets import Planetoid
+from torch_geometric.nn import PMLP
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid')
+dataset = Planetoid(path, name='Cora', transform=T.NormalizeFeatures())
+data = dataset[0].to(device)
+
+model = PMLP(
+    in_channels=dataset.num_features,
+    hidden_channels=16,
+    out_channels=dataset.num_classes,
+    num_layers=2,
+    dropout=0.5,
+    norm=False,
+).to(device)
+
+optimizer = torch.optim.Adam(model.parameters(), weight_decay=5e-4, lr=0.01)
+
+
+def train():
+    """Run one full-batch optimization step and return the loss."""
+    model.train()
+    optimizer.zero_grad()
+    out = model(data.x)  # MLP during training.
+    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
+    loss.backward()
+    optimizer.step()
+    return float(loss)
+
+
+@torch.no_grad()
+def test():
+    """Return [train, val, test] accuracy, predicting with `edge_index`."""
+    model.eval()
+    pred = model(data.x, data.edge_index).argmax(dim=-1)
+    masks = [data.train_mask, data.val_mask, data.test_mask]
+    return [int((pred[m] == data.y[m]).sum()) / int(m.sum()) for m in masks]
+
+
+best_val_acc = test_acc = 0  # `test_acc` is bound before its first print.
+for epoch in range(1, 201):
+    loss = train()
+    train_acc, val_acc, tmp_test_acc = test()
+    if val_acc > best_val_acc:  # Report test acc. at the best val epoch.
+        best_val_acc = val_acc
+        test_acc = tmp_test_acc
+    print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, Val: {val_acc:.4f}, '
+          f'Test: {test_acc:.4f}')
diff --git a/torch_geometric/nn/models/pmlp.py b/torch_geometric/nn/models/pmlp.py index 7677e0b3462f..9034a175c106 100644 --- a/torch_geometric/nn/models/pmlp.py +++ b/torch_geometric/nn/models/pmlp.py @@ -22,6 +22,8 @@ class PMLP(torch.nn.Module): num_layers (int): The number of layers. dropout (float, optional): Dropout probability of each hidden embedding.
(default: :obj:`0.`) + norm (bool, optional): If set to :obj:`False`, will not apply batch + normalization. (default: :obj:`True`) bias (bool, optional): If set to :obj:`False`, the module will not learn additive biases. (default: :obj:`True`) """ @@ -32,6 +34,7 @@ def __init__( out_channels: int, num_layers: int, dropout: float = 0., + norm: bool = True, bias: bool = True, ): super().__init__() @@ -50,8 +53,13 @@ def __init__( self.lins.append(lin) self.lins.append(Linear(hidden_channels, out_channels, self.bias)) - self.norm = torch.nn.BatchNorm1d(hidden_channels, affine=False, - track_running_stats=False) + self.norm = None + if norm: + self.norm = torch.nn.BatchNorm1d( + hidden_channels, + affine=False, + track_running_stats=False, + ) self.conv = SimpleConv(aggr='mean', combine_root='self_loop') @@ -81,7 +89,8 @@ def forward( if self.bias: x = x + self.lins[i].bias if i != self.num_layers - 1: - x = self.norm(x) + if self.norm is not None: + x = self.norm(x) x = x.relu() x = F.dropout(x, p=self.dropout, training=self.training)