Skip to content

Commit 2c0acd6

Browse files
committed
feat(pytorch-geometric): add a GNN train demo for geometric
1 parent ab16007 commit 2c0acd6

File tree

2 files changed

+40
-0
lines changed

2 files changed

+40
-0
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
2+
测试脚本,下载数据集默认会从github上下载,可能耗时较长。
3+
```bash
4+
python test_train_GNN_demo.py
5+
```
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import torch
2+
import torch_musa
3+
from torch import Tensor
4+
from torch_geometric.nn import GCNConv
5+
from torch_geometric.datasets import Planetoid
6+
import torch.nn.functional as F
7+
8+
dataset = Planetoid(root='.', name='Cora')
9+
10+
class GCN(torch.nn.Module):
11+
def __init__(self, in_channels, hidden_channels, out_channels):
12+
super().__init__()
13+
self.conv1 = GCNConv(in_channels, hidden_channels)
14+
self.conv2 = GCNConv(hidden_channels, out_channels)
15+
16+
def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
17+
# x: Node feature matrix of shape [num_nodes, in_channels]
18+
# edge_index: Graph connectivity matrix of shape [2, num_edges]
19+
x = self.conv1(x, edge_index).relu()
20+
x = self.conv2(x, edge_index)
21+
return x
22+
23+
model = GCN(dataset.num_features, 16, dataset.num_classes).to('musa')
24+
25+
data = dataset[0]
26+
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
27+
28+
for epoch in range(200):
29+
pred = model(data.x.to('musa'), data.edge_index.to('musa'))
30+
loss = F.cross_entropy(pred[data.train_mask], data.y[data.train_mask].to('musa'))
31+
32+
# Backpropagation
33+
optimizer.zero_grad()
34+
loss.backward()
35+
optimizer.step()

0 commit comments

Comments
 (0)