Add files via upload
Graph Convolutional U-Net architecture based on https://arxiv.org/abs/1905.05178, implemented in PyTorch.

Updated to include optional pooling / non-pooling layers, multi-channel input features, and multiple hidden dimension sizes.
rxw497 authored Sep 4, 2024
1 parent 30d8dc6 commit 151635a
Showing 1 changed file with 161 additions and 0 deletions.
161 changes: 161 additions & 0 deletions graphfoundationmodels/rwb-graph_UNET.py
@@ -0,0 +1,161 @@
# %%
from typing import Callable, List, Union

import torch
from torch import Tensor

from torch_geometric.nn import GCNConv, TopKPooling, summary
from torch_geometric.nn.resolver import activation_resolver
from torch_geometric.typing import OptTensor, PairTensor
from torch_geometric.utils import (
add_self_loops,
remove_self_loops,
to_torch_csr_tensor,
)
from torch_geometric.utils.repeat import repeat


class GraphUNet(torch.nn.Module):
r"""The Graph U-Net model from the `"Graph U-Nets"
<https://arxiv.org/abs/1905.05178>`_ paper which implements a U-Net like
    architecture with graph pooling and unpooling operations.

    Args:
in_channels (int): Size of each input sample.
hidden_channels (int): Size of each hidden sample.
out_channels (int): Size of each output sample.
depth (int): The depth of the U-Net architecture.
pool_ratios (float or [float], optional): Graph pooling ratio for each
depth. (default: :obj:`0.5`)
        sum_res (bool, optional): If set to :obj:`False`, will use
            concatenation for integration of skip connections instead of
            summation. (default: :obj:`True`)
act (torch.nn.functional, optional): The nonlinearity to use.
(default: :obj:`torch.nn.functional.relu`)
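
    Example (a minimal sketch with random inputs)::

        model = GraphUNet(in_channels=16, hidden_channels=32, out_channels=8,
                          depth=3)
        x = torch.randn(100, 16)
        edge_index = torch.randint(0, 100, (2, 400))
        out = model(x, edge_index)  # shape: [100, 8]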
"""
def __init__(
self,
in_channels: int,
hidden_channels: int,
out_channels: int,
depth: int,
pool_ratios: Union[float, List[float]] = 0.5,
sum_res: bool = True,
act: Union[str, Callable] = 'relu',
):
super().__init__()
assert depth >= 1
self.in_channels = in_channels
self.hidden_channels = hidden_channels
self.out_channels = out_channels
self.depth = depth
self.pool_ratios = repeat(pool_ratios, depth)
self.act = activation_resolver(act)
self.sum_res = sum_res

channels = hidden_channels

        # Encoder stack: an initial GCN embedding, then `depth` rounds of
        # TopK pooling followed by a GCN at the coarsened resolution.
        self.down_convs = torch.nn.ModuleList()
        self.pools = torch.nn.ModuleList()
        self.down_convs.append(GCNConv(in_channels, channels, improved=True))
        for i in range(depth):
            self.pools.append(TopKPooling(channels, self.pool_ratios[i]))
            self.down_convs.append(GCNConv(channels, channels, improved=True))

        # Concatenating skip connections doubles the up-convolution input width.
        in_channels = channels if sum_res else 2 * channels

self.up_convs = torch.nn.ModuleList()
for i in range(depth - 1):
self.up_convs.append(GCNConv(in_channels, channels, improved=True))
self.up_convs.append(GCNConv(in_channels, out_channels, improved=True))

self.reset_parameters()

def reset_parameters(self):
r"""Resets all learnable parameters of the module."""
for conv in self.down_convs:
conv.reset_parameters()
for pool in self.pools:
pool.reset_parameters()
for conv in self.up_convs:
conv.reset_parameters()

def forward(self, x: Tensor, edge_index: Tensor,
batch: OptTensor = None) -> Tensor:
"""""" # noqa: D419
        if batch is None:  # default: treat all nodes as a single graph
            batch = edge_index.new_zeros(x.size(0))
        edge_weight = x.new_ones(edge_index.size(1))  # unit edge weights

x = self.down_convs[0](x, edge_index, edge_weight)
x = self.act(x)

xs = [x]
edge_indices = [edge_index]
edge_weights = [edge_weight]
perms = []
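
        # Encoder: at each depth, densify connectivity with augment_adj,
        # coarsen the graph with TopK pooling, then convolve.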

for i in range(1, self.depth + 1):
edge_index, edge_weight = self.augment_adj(edge_index, edge_weight,
x.size(0))
x, edge_index, edge_weight, batch, perm, _ = self.pools[i - 1](
x, edge_index, edge_weight, batch)

x = self.down_convs[i](x, edge_index, edge_weight)
x = self.act(x)

if i < self.depth:
xs += [x]
edge_indices += [edge_index]
edge_weights += [edge_weight]
perms += [perm]

for i in range(self.depth):
j = self.depth - 1 - i

res = xs[j]
edge_index = edge_indices[j]
edge_weight = edge_weights[j]
perm = perms[j]

up = torch.zeros_like(res)
up[perm] = x

x = res + up if self.sum_res else torch.cat((res, up), dim=-1)

x = self.up_convs[i](x, edge_index, edge_weight)
x = self.act(x) if i < self.depth - 1 else x

return x


def augment_adj(self, edge_index: Tensor, edge_weight: Tensor,
num_nodes: int) -> PairTensor:
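        r"""Augments connectivity with self-loops and the second graph power
        (:math:`A^2`), as in the Graph U-Nets paper, so that pooled graphs
        stay connected."""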
edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
edge_index, edge_weight = add_self_loops(edge_index, edge_weight,
num_nodes=num_nodes)
adj = to_torch_csr_tensor(edge_index, edge_weight,
size=(num_nodes, num_nodes))
adj = (adj @ adj).to_sparse_coo()
edge_index, edge_weight = adj.indices(), adj.values()
edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
return edge_index, edge_weight

def __repr__(self) -> str:
return (f'{self.__class__.__name__}({self.in_channels}, '
f'{self.hidden_channels}, {self.out_channels}, '
f'depth={self.depth}, pool_ratios={self.pool_ratios})')

# %%
pool_ratios = [0.99, 0.99, 0.5, 0.99, 0.5]  # only the first `depth` entries are used
pooling_layer = [True, True, True]  # placeholder: per-layer pooling flags (not yet consumed)
hidden_channel_encoder = [2000, 1500, 1000, 500]  # placeholder: per-depth widths (not yet consumed)
model = GraphUNet(in_channels=2688, hidden_channels=2688, out_channels=2688,
                  depth=2, pool_ratios=pool_ratios, sum_res=True, act='relu')

# %%
# `full_edge_index` was undefined in this script; a random placeholder graph
# on 90 nodes stands in so the summary call below runs.
full_edge_index = torch.randint(0, 90, (2, 512))
print(summary(model, torch.randn(90, 2688), edge_index=full_edge_index))
# %%
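# A minimal forward-pass sketch on the placeholder graph above: the model
# maps (num_nodes, in_channels) -> (num_nodes, out_channels).
out = model(torch.randn(90, 2688), full_edge_index)
print(out.shape)  # expected: torch.Size([90, 2688])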
