Skip to content

Commit

Permalink
add automatic conversion of adjacency matrix to n-th degree adjacency…
Browse files Browse the repository at this point in the history
… matrix, and being able to pass the adjacency degree embedding, with two extra keyword arguments
  • Loading branch information
lucidrains committed Mar 27, 2021
1 parent 1bb8841 commit e7a8b8c
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 10 deletions.
28 changes: 27 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,33 @@ net = EGNN_Network(
num_tokens = 21,
dim = 32,
depth = 3,
num_nearest_neighbors = 3,
only_sparse_neighbors = True
)

feats = torch.randint(0, 21, (1, 1024))
coors = torch.randn(1, 1024, 3)
mask = torch.ones_like(feats).bool()

# naive adjacency matrix
# assuming the sequence is connected as a chain, with at most 2 neighbors - (1024, 1024)
i = torch.arange(1024)
adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))

feats_out, coors_out = net(feats, coors, mask = mask, adj_mat = adj_mat) # (1, 1024, 32), (1, 1024, 3)
```

You can also have the network automatically determine the Nth-order neighbors, and pass in an adjacency embedding (depending on the order) to be used as an edge, with two extra keyword arguments

```python
import torch
from egnn_pytorch.egnn_pytorch import EGNN_Network

net = EGNN_Network(
num_tokens = 21,
dim = 32,
depth = 3,
num_adj_degrees = 3, # fetch up to 3rd degree neighbors
adj_dim = 8, # pass an adjacency degree embedding to the EGNN layer, to be used in the edge MLP
only_sparse_neighbors = True
)

Expand Down
9 changes: 5 additions & 4 deletions denoise_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@ def cycle(loader, len_thres = 1000):
num_tokens = 21,
depth = 5,
dim = 8,
num_nearest_neighbors = 16,
num_nearest_neighbors = 0,
fourier_features = 2,
only_sparse_neighbors = True,
norm_coors = True
norm_coors = True,
adj_dim = 8,
num_adj_degrees = 4,
only_sparse_neighbors = True
).cuda()

data = scn.load(
Expand Down Expand Up @@ -63,7 +65,6 @@ def cycle(loader, len_thres = 1000):

i = torch.arange(seq.shape[-1], device = seq.device)
adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))
adj_mat = (adj_mat.float() @ adj_mat.float()) > 0 # get second degree neighbors too

noised_coords = coords + torch.randn_like(coords)

Expand Down
39 changes: 35 additions & 4 deletions egnn_pytorch/egnn_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def forward(self, feats, coors, edges = None, mask = None, adj_mat = None):
if exists(mask):
num_nodes = mask.sum(dim = -1)

use_nearest = num_nearest > 0
use_nearest = num_nearest > 0 or only_sparse_neighbors

rel_coors = rearrange(coors, 'b i d -> b i () d') - rearrange(coors, 'b j d -> b () j d')
rel_dist = (rel_coors ** 2).sum(dim = -1, keepdim = True)
Expand All @@ -174,16 +174,16 @@ def forward(self, feats, coors, edges = None, mask = None, adj_mat = None):

if exists(adj_mat):
if len(adj_mat.shape) == 2:
adj_mat = repeat(adj_mat, 'i j -> b i j', b = b)
adj_mat = repeat(adj_mat.clone(), 'i j -> b i j', b = b)

if only_sparse_neighbors:
num_nearest = int(adj_mat.float().sum(dim = -1).max().item())
valid_radius = 0

self_mask = rearrange(torch.eye(n, device = device, dtype = torch.bool), 'i j -> () i j')

adj_mat = adj_mat.masked_fill(self_mask, False)
ranking.masked_fill_(self_mask, -1.)
adj_mat.masked_fill_(self_mask, False)
ranking.masked_fill_(adj_mat, 0.)

nbhd_ranking, nbhd_indices = ranking.topk(num_nearest, dim = -1, largest = False)
Expand Down Expand Up @@ -272,23 +272,54 @@ def __init__(
num_tokens = None,
num_edge_tokens = None,
edge_dim = 0,
num_adj_degrees = None,
adj_dim = 0,
**kwargs
):
super().__init__()
assert not (exists(num_adj_degrees) and num_adj_degrees < 1), 'make sure adjacent degrees is greater than 1'

self.token_emb = nn.Embedding(num_tokens, dim) if exists(num_tokens) else None
self.edge_emb = nn.Embedding(num_edge_tokens, edge_dim) if exists(num_edge_tokens) else None

self.num_adj_degrees = num_adj_degrees
self.adj_emb = nn.Embedding(num_adj_degrees + 1, adj_dim) if exists(num_adj_degrees) and adj_dim > 0 else None

edge_dim = edge_dim if exists(num_edge_tokens) else 0
adj_dim = adj_dim if exists(num_adj_degrees) else 0

self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(EGNN(dim = dim, edge_dim = edge_dim, norm_feats = True, **kwargs))
self.layers.append(EGNN(dim = dim, edge_dim = (edge_dim + adj_dim), norm_feats = True, **kwargs))

def forward(self, feats, coors, adj_mat = None, edges = None, mask = None):
b = feats.shape[0]

if exists(self.token_emb):
feats = self.token_emb(feats)

if exists(edges) and exists(self.edge_emb):
edges = self.edge_emb(edges)

# create N-degrees adjacent matrix from 1st degree connections
if exists(self.num_adj_degrees):
if len(adj_mat.shape) == 2:
adj_mat = repeat(adj_mat.clone(), 'i j -> b i j', b = b)

adj_indices = adj_mat.clone().long()

for ind in range(self.num_adj_degrees - 1):
degree = ind + 2

next_degree_adj_mat = (adj_mat.float() @ adj_mat.float()) > 0
next_degree_mask = (next_degree_adj_mat.float() - adj_mat.float()).bool()
adj_indices.masked_fill_(next_degree_mask, degree)
adj_mat = next_degree_adj_mat.clone()

if exists(self.adj_emb):
adj_emb = self.adj_emb(adj_indices)
edges = torch.cat((edges, adj_emb), dim = -1) if exists(edges) else adj_emb

for layer in self.layers:
feats, coors = layer(feats, coors, adj_mat = adj_mat, edges = edges, mask = mask)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'egnn-pytorch',
packages = find_packages(),
version = '0.0.46',
version = '0.1.0',
license='MIT',
description = 'E(n)-Equivariant Graph Neural Network - Pytorch',
author = 'Phil Wang, Eric Alcaide',
Expand Down

0 comments on commit e7a8b8c

Please sign in to comment.