Implementation of E(n)-Equivariant Graph Neural Networks, in Pytorch. May be eventually used for Alphafold2 replication.
$ pip install egnn-pytorch
import torch
from egnn_pytorch import EGNN
layer1 = EGNN(dim = 512)
layer2 = EGNN(dim = 512)
feats = torch.randn(1, 16, 512)
coors = torch.randn(1, 16, 3)
feats, coors = layer1(feats, coors)
feats, coors = layer2(feats, coors) # (1, 16, 512), (1, 16, 3)
With edges
import torch
from egnn_pytorch import EGNN
layer1 = EGNN(dim = 512, edge_dim = 4)
layer2 = EGNN(dim = 512, edge_dim = 4)
feats = torch.randn(1, 16, 512)
coors = torch.randn(1, 16, 3)
edges = torch.randn(1, 16, 16, 4)
feats, coors = layer1(feats, coors, edges)
feats, coors = layer2(feats, coors, edges) # (1, 16, 512), (1, 16, 3)
- add integration with pytorch geometric
- add tests for se3 equivariance
- add an EGAT (attention flavored variant)
@misc{satorras2021en,
title = {E(n) Equivariant Graph Neural Networks},
author = {Victor Garcia Satorras and Emiel Hoogeboom and Max Welling},
year = {2021},
eprint = {2102.09844},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}