232 changes: 230 additions & 2 deletions Chapter04/04_Graph_Neural_Networks.ipynb
@@ -898,6 +898,234 @@
},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the rest of the notebook, we will be performing a similar example as above using other two popular graph-dl frameworks: PyTorch Geometric (PyG) and Deep Graph Library (DGL)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Graph Classification using PyG"
]
},
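{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before training, it helps to inspect a single `Data` object. The cell below is a minimal sketch of that check: `x` holds the node features, `edge_index` the connectivity, and `y` the graph label. It assumes the dataset downloads into `data/PROTEINS`, the same root used by the training cell that follows."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from torch_geometric.datasets import TUDataset\n",
"\n",
"# Minimal sketch: inspect one PROTEINS graph before training\n",
"dataset = TUDataset(root='data/PROTEINS', name='PROTEINS')\n",
"sample = dataset[0]\n",
"print(sample)\n",
"print(f'Nodes: {sample.num_nodes}, Edges: {sample.num_edges}, Label: {int(sample.y)}')"
]
},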
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#!pip install fsspec==2024.3.1 # needed for PROTEINS download torch geometric\n",
"#!pip install torch_geometric\n",
"\n",
"import torch\n",
"from torch_geometric.datasets import TUDataset\n",
"from torch_geometric.data import DataLoader\n",
"from torch_geometric.nn import GCNConv, global_mean_pool\n",
"from torch.nn import Linear\n",
"import torch.nn.functional as F\n",
"\n",
"# Load the PROTEINS dataset\n",
"dataset = TUDataset(root='data/PROTEINS', name='PROTEINS')\n",
"\n",
"# Set random seed for reproducibility\n",
"torch.manual_seed(42)\n",
"\n",
"# Shuffle and split the dataset into training and test sets\n",
"dataset = dataset.shuffle()\n",
"split_idx = int(0.8 * len(dataset)) # 80/20 train/test split\n",
"train_dataset = dataset[:split_idx]\n",
"test_dataset = dataset[split_idx:]\n",
"\n",
"# Print dataset statistics\n",
"print(f'Training graphs: {len(train_dataset)}, Test graphs: {len(test_dataset)}')\n",
"\n",
"# Create DataLoader for batching\n",
"train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)\n",
"test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)\n",
"\n",
"# Define the GCN model\n",
"class GCN(torch.nn.Module):\n",
" def __init__(self, input_dim, hidden_dim, output_dim):\n",
" super(GCN, self).__init__()\n",
" self.conv1 = GCNConv(input_dim, hidden_dim)\n",
" self.conv2 = GCNConv(hidden_dim, hidden_dim)\n",
" self.conv3 = GCNConv(hidden_dim, hidden_dim)\n",
" self.lin = Linear(hidden_dim, output_dim)\n",
" \n",
" def forward(self, x, edge_index, batch):\n",
" # Graph convolution layers with ReLU activations\n",
" x = F.relu(self.conv1(x, edge_index))\n",
" x = F.relu(self.conv2(x, edge_index))\n",
" x = self.conv3(x, edge_index)\n",
" \n",
" # Global pooling to obtain graph-level representation\n",
" x = global_mean_pool(x, batch)\n",
" \n",
" # Apply dropout and final linear layer\n",
" x = F.dropout(x, p=0.5, training=self.training)\n",
" x = self.lin(x)\n",
" return x\n",
"\n",
"# Instantiate the model\n",
"print(dataset.num_node_features)\n",
"model = GCN(input_dim=dataset.num_node_features, hidden_dim=64, output_dim=dataset.num_classes)\n",
"print(model)\n",
"\n",
"# Define optimizer and loss function\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n",
"criterion = torch.nn.CrossEntropyLoss()\n",
"scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5) # Learning rate decay\n",
"\n",
"# Training function\n",
"def train():\n",
" model.train()\n",
" total_loss = 0\n",
" for data in train_loader:\n",
" optimizer.zero_grad()\n",
" out = model(data.x, data.edge_index, data.batch)\n",
" loss = criterion(out, data.y)\n",
" loss.backward()\n",
" optimizer.step()\n",
" total_loss += loss.item()\n",
" return total_loss / len(train_loader)\n",
"\n",
"# Evaluation function\n",
"def evaluate(loader):\n",
" model.eval()\n",
" correct = 0\n",
" for data in loader:\n",
" with torch.no_grad():\n",
" out = model(data.x, data.edge_index, data.batch)\n",
" pred = out.argmax(dim=1)\n",
" correct += int((pred == data.y).sum())\n",
" return correct / len(loader.dataset)\n",
"\n",
"# Training loop\n",
"num_epochs = 200\n",
"for epoch in range(1, num_epochs + 1):\n",
" loss = train()\n",
" train_acc = evaluate(train_loader)\n",
" test_acc = evaluate(test_loader)\n",
" scheduler.step() # Adjust learning rate\n",
"\n",
" print(f'Epoch {epoch:03d}, Loss: {loss:.4f}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')"
]
},
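{
"cell_type": "markdown",
"metadata": {},
"source": [
"PyG's `DataLoader` merges the graphs in each mini-batch into one large, disconnected graph and exposes a `batch` vector mapping every node to its source graph; `global_mean_pool` uses that vector to average node embeddings per graph. For a single graph the vector is simply all zeros, as in the minimal inference sketch below (it assumes the training cell above has been run)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal sketch: classify a single test graph with the trained model\n",
"model.eval()\n",
"sample = test_dataset[0]\n",
"with torch.no_grad():\n",
"    # A lone graph is 'graph 0', so every node maps to batch index 0\n",
"    batch = torch.zeros(sample.num_nodes, dtype=torch.long)\n",
"    out = model(sample.x, sample.edge_index, batch)\n",
"print(f'Predicted: {int(out.argmax(dim=1))}, True: {int(sample.y)}')"
]
},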
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Graph Classification using DGL"
]
},
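{
"cell_type": "markdown",
"metadata": {},
"source": [
"For orientation, the sketch below loads the same PROTEINS data through DGL's `GINDataset`, where each item is a `(graph, label)` pair and the node features live in `g.ndata['attr']`. Later, `GraphDataLoader` batches such graphs with `dgl.batch`, and `dgl.mean_nodes` averages node features per graph within the batch."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from dgl.data import GINDataset\n",
"\n",
"# Minimal sketch: inspect one (graph, label) pair from the DGL dataset\n",
"dataset = GINDataset('PROTEINS', self_loop=True)\n",
"g, label = dataset[0]\n",
"print(g)\n",
"feat = g.ndata['attr']\n",
"print(f'Node feature shape: {tuple(feat.shape)}, Label: {int(label)}')"
]
},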
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#!pip install torch==2.1.1 # needed for dgl\n",
"#!pip install dgl -f https://data.dgl.ai/wheels/torch-2.1/repo.html\n",
"\n",
"import dgl\n",
"import torch\n",
"import torch.nn.functional as F\n",
"from torch.nn import Linear\n",
"from dgl.data import GINDataset\n",
"from dgl.dataloading import GraphDataLoader\n",
"from dgl.nn.pytorch import GraphConv\n",
"from dgl.data.utils import split_dataset\n",
"\n",
"dataset = dgl.data.GINDataset('PROTEINS', self_loop=True)\n",
"\n",
"# Set random seed for reproducibility\n",
"torch.manual_seed(42)\n",
"\n",
"# 2. Split dataset into training and test sets\n",
"train_dataset, val_dataset, test_dataset = split_dataset(dataset, frac_list=[0.8, 0.1, 0.1], shuffle=False, random_state=42)\n",
"\n",
"# Print dataset statistics\n",
"print(f'Training graphs: {len(train_dataset)}, Test graphs: {len(test_dataset)}')\n",
"\n",
"# 3. Create DGL DataLoader for batching\n",
"train_loader = GraphDataLoader(train_dataset, batch_size=64, shuffle=True)\n",
"test_loader = GraphDataLoader(test_dataset, batch_size=64, shuffle=False)\n",
"\n",
"# 4. Define the GCN model using DGL's GraphConv layers\n",
"class GCN(torch.nn.Module):\n",
" def __init__(self, input_dim, hidden_dim, output_dim):\n",
" super(GCN, self).__init__()\n",
" self.conv1 = GraphConv(input_dim, hidden_dim)\n",
" self.conv2 = GraphConv(hidden_dim, hidden_dim)\n",
" self.conv3 = GraphConv(hidden_dim, hidden_dim)\n",
" self.fc = Linear(hidden_dim, output_dim)\n",
"\n",
" def forward(self, g, features):\n",
" # Apply GraphConv layers with ReLU activations\n",
" h = F.relu(self.conv1(g, features))\n",
" h = F.relu(self.conv2(g, h))\n",
" h = self.conv3(g, h)\n",
" \n",
" # Global mean pooling to obtain graph-level representation\n",
" with g.local_scope():\n",
" g.ndata['h'] = h\n",
" hg = dgl.mean_nodes(g, 'h')\n",
" \n",
" # Apply dropout and final linear layer for classification\n",
" hg = F.dropout(hg, p=0.5, training=self.training)\n",
" return self.fc(hg)\n",
"\n",
"# 5. Initialize the model, optimizer, and loss function\n",
"input_dim = dataset.dim_nfeats\n",
"output_dim = dataset.num_classes\n",
"hidden_dim = 64\n",
"\n",
"print(\"Input dim:\", input_dim)\n",
"print(\"Output dim:\", output_dim)\n",
"\n",
"model = GCN(input_dim, hidden_dim, output_dim)\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n",
"criterion = torch.nn.CrossEntropyLoss()\n",
"\n",
"# 6. Training function\n",
"def train():\n",
" model.train()\n",
" total_loss = 0\n",
" for batched_graph, labels in train_loader:\n",
" optimizer.zero_grad()\n",
" features = batched_graph.ndata['attr']\n",
" out = model(batched_graph, features)\n",
" loss = criterion(out, labels)\n",
" loss.backward()\n",
" optimizer.step()\n",
" total_loss += loss.item()\n",
" return total_loss / len(train_loader)\n",
"\n",
"# 7. Evaluation function\n",
"def evaluate(loader):\n",
" model.eval()\n",
" correct = 0\n",
" for batched_graph, labels in loader:\n",
" features = batched_graph.ndata['attr']\n",
" with torch.no_grad():\n",
" out = model(batched_graph, features)\n",
" pred = out.argmax(dim=1)\n",
" correct += (pred == labels).sum().item()\n",
" return correct / len(loader.dataset)\n",
"\n",
"# 8. Training loop\n",
"num_epochs = 200\n",
"for epoch in range(1, num_epochs + 1):\n",
" loss = train()\n",
" train_acc = evaluate(train_loader)\n",
" test_acc = evaluate(test_loader)\n",
" print(f'Epoch {epoch:03d}, Loss: {loss:.4f}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')"
]
},
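{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a closing check, the minimal sketch below classifies a single test graph with the trained DGL model (it assumes the training cell above has been run). `dgl.mean_nodes` treats an unbatched graph as a batch of one, so the forward pass works unchanged."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal sketch: single-graph inference with the trained model\n",
"model.eval()\n",
"g, label = test_dataset[0]\n",
"with torch.no_grad():\n",
"    out = model(g, g.ndata['attr'])\n",
"print(f'Predicted: {int(out.argmax(dim=1))}, True: {int(label)}')"
]
}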
],
"metadata": {
@@ -907,7 +1135,7 @@
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -921,7 +1149,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.0"
"version": "3.9.12"
}
},
"nbformat": 4,