Skip to content

Commit

Permalink
make sure continuous edges work
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Apr 5, 2021
1 parent 6d423c9 commit 4c216c6
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 2 deletions.
30 changes: 30 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,36 @@ adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))
feats_out, coors_out = net(feats, coors, mask = mask, adj_mat = adj_mat) # (1, 1024, 32), (1, 1024, 3)
```

## Edges

If you need to pass in continuous edges

```python
import torch
from egnn_pytorch.egnn_pytorch import EGNN_Network

net = EGNN_Network(
num_tokens = 21,
dim = 32,
depth = 3,
edge_dim = 4,
num_nearest_neighbors = 3
)

feats = torch.randint(0, 21, (1, 1024))
coors = torch.randn(1, 1024, 3)
mask = torch.ones_like(feats).bool()

continuous_edges = torch.randn(1, 1024, 1024, 4)

# naive adjacency matrix
# assuming the sequence is connected as a chain, with at most 2 neighbors - (1024, 1024)
i = torch.arange(1024)
adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))

feats_out, coors_out = net(feats, coors, edges = continuous_edges, mask = mask, adj_mat = adj_mat) # (1, 1024, 32), (1, 1024, 3)
```

## Examples

To run the protein backbone denoising example, first install `sidechainnet`
Expand Down
3 changes: 2 additions & 1 deletion egnn_pytorch/egnn_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,11 +281,12 @@ def __init__(

self.token_emb = nn.Embedding(num_tokens, dim) if exists(num_tokens) else None
self.edge_emb = nn.Embedding(num_edge_tokens, edge_dim) if exists(num_edge_tokens) else None
self.has_edges = edge_dim > 0

self.num_adj_degrees = num_adj_degrees
self.adj_emb = nn.Embedding(num_adj_degrees + 1, adj_dim) if exists(num_adj_degrees) and adj_dim > 0 else None

edge_dim = edge_dim if exists(num_edge_tokens) else 0
edge_dim = edge_dim if self.has_edges else 0
adj_dim = adj_dim if exists(num_adj_degrees) else 0

self.layers = nn.ModuleList([])
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'egnn-pytorch',
packages = find_packages(),
version = '0.1.1',
version = '0.1.2',
license='MIT',
description = 'E(n)-Equivariant Graph Neural Network - Pytorch',
author = 'Phil Wang, Eric Alcaide',
Expand Down

0 comments on commit 4c216c6

Please sign in to comment.