Skip to content

Commit

Permalink
work without edges
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Feb 26, 2021
1 parent 962a91c commit 369fce0
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 3 deletions.
16 changes: 16 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,22 @@ $ pip install egnn-pytorch
import torch
from egnn_pytorch import EGNN

layer1 = EGNN(dim = 512)
layer2 = EGNN(dim = 512)

feats = torch.randn(1, 16, 512)
coors = torch.randn(1, 16, 3)

feats, coors = layer1(feats, coors)
feats, coors = layer2(feats, coors) # (1, 16, 512), (1, 16, 3)
```

With edges

```python
import torch
from egnn_pytorch import EGNN

layer1 = EGNN(dim = 512, edge_dim = 4)
layer2 = EGNN(dim = 512, edge_dim = 4)

Expand Down
10 changes: 8 additions & 2 deletions egnn_pytorch/egnn_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@
from torch import nn, einsum
from einops import rearrange, repeat

def exists(val):
return val is not None

class EGNN(nn.Module):
def __init__(
self,
dim,
edge_dim,
edge_dim = 0,
m_dim = 16
):
super().__init__()
Expand Down Expand Up @@ -38,7 +41,10 @@ def forward(self, feats, coors, edges = None):

feats_i = repeat(feats, 'b i d -> b i n d', n = n)
feats_j = repeat(feats, 'b j d -> b n j d', n = n)
edge_input = torch.cat((feats_i, feats_j, rel_dist, edges), dim = -1)
edge_input = torch.cat((feats_i, feats_j, rel_dist), dim = -1)

if exists(edges):
edge_input = torch.cat((edge_input, edges), dim = -1)

m_ij = self.edge_mlp(edge_input)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'egnn-pytorch',
packages = find_packages(),
version = '0.0.2',
version = '0.0.3',
license='MIT',
description = 'E(n)-Equivariant Graph Neural Network - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 369fce0

Please sign in to comment.