Skip to content

Commit

Permalink
Add an example PMLP on Cora dataset (#7543)
Browse files Browse the repository at this point in the history
Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
n-patricia and rusty1s authored Jun 12, 2023
1 parent 09cb440 commit 0549077
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 4 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added the `IGMCDataset` ([#7441](https://github.com/pyg-team/pytorch_geometric/pull/7441))
- Added a sparse `cross_entropy` implementation ([#7447](https://github.com/pyg-team/pytorch_geometric/pull/7447), [#7466](https://github.com/pyg-team/pytorch_geometric/pull/7466))
- Added the `MovieLens-100K` heterogeneous dataset ([#7398](https://github.com/pyg-team/pytorch_geometric/pull/7398))
- Added the `PMLP` model ([#7370](https://github.com/pyg-team/pytorch_geometric/pull/7370))
- Added the `PMLP` model and an example ([#7370](https://github.com/pyg-team/pytorch_geometric/pull/7370), [#7543](https://github.com/pyg-team/pytorch_geometric/pull/7543))
- Added padding capabilities to `HeteroData.to_homogeneous()` in case feature dimensionalities do not match ([#7374](https://github.com/pyg-team/pytorch_geometric/pull/7374))
- Added an optional `batch_size` argument to `fps`, `knn`, `knn_graph`, `radius` and `radius_graph` ([#7368](https://github.com/pyg-team/pytorch_geometric/pull/7368))
- Added `PrefetchLoader` capabilities ([#7376](https://github.com/pyg-team/pytorch_geometric/pull/7376), [#7378](https://github.com/pyg-team/pytorch_geometric/pull/7378), [#7383](https://github.com/pyg-team/pytorch_geometric/pull/7383))
Expand Down
58 changes: 58 additions & 0 deletions examples/pmlp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import os.path as osp

import torch
import torch.nn.functional as F

import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import PMLP

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid')
dataset = Planetoid(path, name='Cora', transform=T.NormalizeFeatures())
data = dataset[0].to(device)

model = PMLP(
in_channels=dataset.num_features,
hidden_channels=16,
out_channels=dataset.num_classes,
num_layers=2,
dropout=0.5,
norm=False,
).to(device)

optimizer = torch.optim.Adam(model.parameters(), weight_decay=5e-4, lr=0.01)


def train():
model.train()
optimizer.zero_grad()
out = model(data.x) # MLP during training.
loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
return float(loss)


@torch.no_grad()
def test():
model.eval()
out = model(data.x, data.edge_index)
pred = out.argmax(dim=-1)

accs = []
for mask in [data.train_mask, data.val_mask, data.test_mask]:
accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum()))
return accs


best_val_acc = final_test_acc = 0
for epoch in range(1, 201):
loss = train()
train_acc, val_acc, tmp_test_acc = test()
if val_acc > best_val_acc:
best_val_acc = val_acc
test_acc = tmp_test_acc
print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, Val: {val_acc:.4f}, '
f'Test: {test_acc:.4f}')
15 changes: 12 additions & 3 deletions torch_geometric/nn/models/pmlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ class PMLP(torch.nn.Module):
num_layers (int): The number of layers.
dropout (float, optional): Dropout probability of each hidden
embedding. (default: :obj:`0.`)
norm (bool, optional): If set to :obj:`False`, will not apply batch
normalization. (default: :obj:`True`)
bias (bool, optional): If set to :obj:`False`, the module
will not learn additive biases. (default: :obj:`True`)
"""
Expand All @@ -32,6 +34,7 @@ def __init__(
out_channels: int,
num_layers: int,
dropout: float = 0.,
norm: bool = True,
bias: bool = True,
):
super().__init__()
Expand All @@ -50,8 +53,13 @@ def __init__(
self.lins.append(lin)
self.lins.append(Linear(hidden_channels, out_channels, self.bias))

self.norm = torch.nn.BatchNorm1d(hidden_channels, affine=False,
track_running_stats=False)
self.norm = None
if norm:
self.norm = torch.nn.BatchNorm1d(
hidden_channels,
affine=False,
track_running_stats=False,
)

self.conv = SimpleConv(aggr='mean', combine_root='self_loop')

Expand Down Expand Up @@ -81,7 +89,8 @@ def forward(
if self.bias:
x = x + self.lins[i].bias
if i != self.num_layers - 1:
x = self.norm(x)
if self.norm is not None:
x = self.norm(x)
x = x.relu()
x = F.dropout(x, p=self.dropout, training=self.training)

Expand Down

0 comments on commit 0549077

Please sign in to comment.