Implementation of the E(n)-Equivariant Transformer, which extends the ideas from the E(n)-Equivariant Graph Neural Network of Satorras, Hoogeboom, and Welling with attention.
$ pip install En-transformer
import torch
from en_transformer import EnTransformer
model = EnTransformer(
dim = 512,
depth = 4, # number of layers
dim_head = 64, # dimension per head
heads = 8, # number of heads
edge_dim = 4, # dimension of edge features
neighbors = 64, # only attend over each coordinate's N nearest neighbors - set to 0 to turn off
talking_heads = True, # use Shazeer's talking heads https://arxiv.org/abs/2003.02436
checkpoint = True, # use checkpointing so one can increase depth at little memory cost (and increase neighbors attended to)
use_cross_product = True # use cross product vectors (idea by @MattMcPartlon)
)
feats = torch.randn(1, 1024, 512)
coors = torch.randn(1, 1024, 3)
edges = torch.randn(1, 1024, 1024, 4)
mask = torch.ones(1, 1024).bool()
feats, coors = model(feats, coors, edges = edges, mask = mask) # (1, 1024, 512), (1, 1024, 3)
Letting the network take care of both atomic and bond type embeddings
import torch
from en_transformer import EnTransformer
model = EnTransformer(
num_tokens = 10, # number of unique node types, say atoms
rel_pos_emb = True, # set this to True if your input is an ordered sequence rather than an unordered set; it will accelerate convergence
num_edge_tokens = 5, # number of unique edge types, say bond types
dim = 128,
edge_dim = 16,
depth = 3,
heads = 4,
dim_head = 32,
neighbors = 8
)
atoms = torch.randint(0, 10, (1, 16)) # 10 different types of atoms
bonds = torch.randint(0, 5, (1, 16, 16)) # 5 different types of bonds (n x n)
coors = torch.randn(1, 16, 3) # atomic spatial coordinates
feats_out, coors_out = model(atoms, coors, edges = bonds) # (1, 16, 128), (1, 16, 3)
If you would like to attend only to sparse neighbors, as defined by an adjacency matrix (say for atoms), you have to set one more flag and then pass in the N x N adjacency matrix.
import torch
from en_transformer import EnTransformer
model = EnTransformer(
num_tokens = 10,
dim = 512,
depth = 1,
heads = 4,
dim_head = 32,
neighbors = 0,
only_sparse_neighbors = True, # must be set to true
num_adj_degrees = 2, # the number of adjacency degrees to derive from the 1st-degree neighbors passed in
adj_dim = 8 # the embedding dimension for the adjacency degree information, passed in as extra edge features
)
atoms = torch.randint(0, 10, (1, 16))
coors = torch.randn(1, 16, 3)
# naively assume a single chain of atoms
i = torch.arange(atoms.shape[1])
adj_mat = (i[:, None] <= (i[None, :] + 1)) & (i[:, None] >= (i[None, :] - 1))
# adjacency matrix must be passed in
feats_out, coors_out = model(atoms, coors, adj_mat = adj_mat) # (1, 16, 512), (1, 16, 3)
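In practice you would usually derive adj_mat from the molecule's bond list rather than assume a chain. A minimal sketch, where bond_pairs is a hypothetical list of bonded atom index pairs (here it just reproduces the chain above):
bond_pairs = [(i, i + 1) for i in range(atoms.shape[1] - 1)] # hypothetical bond list
adj_mat = torch.zeros(atoms.shape[1], atoms.shape[1]).bool()
for i, j in bond_pairs:
    adj_mat[i, j] = adj_mat[j, i] = True # mark both directions of each bond
feats_out, coors_out = model(atoms, coors, adj_mat = adj_mat) # (1, 16, 512), (1, 16, 3)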
If you need to pass in continuous edges
import torch
from en_transformer import EnTransformer
from en_transformer.utils import rot
model = EnTransformer(
dim = 512,
depth = 1,
heads = 4,
dim_head = 32,
edge_dim = 4,
neighbors = 0,
only_sparse_neighbors = True
)
feats = torch.randn(1, 16, 512)
coors = torch.randn(1, 16, 3)
edges = torch.randn(1, 16, 16, 4)
i = torch.arange(feats.shape[1])
adj_mat = (i[:, None] <= (i[None, :] + 1)) & (i[:, None] >= (i[None, :] - 1))
feats1, coors1 = model(feats, coors, adj_mat = adj_mat, edges = edges)
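Since the features are meant to be E(n)-invariant and the coordinates E(n)-equivariant, you can sanity-check a model by rotating and translating its coordinate input. A minimal sketch using the rot helper imported above, assuming it takes three angles and returns a 3 x 3 rotation matrix:
R = rot(*torch.randn(3)) # random rotation (assumed signature: rot(alpha, beta, gamma) -> 3 x 3 matrix)
T = torch.randn(1, 1, 3) # random translation
feats2, coors2 = model(feats, coors @ R + T, adj_mat = adj_mat, edges = edges)
# features should be invariant, coordinates equivariant, up to numerical tolerance
assert torch.allclose(feats1, feats2, atol = 1e-4)
assert torch.allclose(coors1 @ R + T, coors2, atol = 1e-4)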
To run a protein backbone coordinate denoising toy task, first install sidechainnet
$ pip install sidechainnet
Then
$ python denoise.py
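denoise.py trains on sidechainnet protein data; purely as an illustration of that kind of objective (not the actual script), one can corrupt the backbone coordinates with Gaussian noise and regress the clean ones. All hyperparameters and tensors below are hypothetical stand-ins:
import torch
import torch.nn.functional as F
from en_transformer import EnTransformer
model = EnTransformer(
    num_tokens = 21, # assuming a vocabulary of 21 residue types
    dim = 64,
    depth = 4,
    heads = 4,
    dim_head = 32,
    neighbors = 16
)
optim = torch.optim.Adam(model.parameters(), lr = 1e-3)
seq = torch.randint(0, 21, (1, 128)) # stand-in residue tokens
coords = torch.randn(1, 128, 3) # stand-in backbone coordinates
mask = torch.ones(1, 128).bool()
noised = coords + torch.randn_like(coords) * 0.1 # corrupt the coordinates
_, denoised = model(seq, noised, mask = mask)
loss = F.mse_loss(denoised[mask], coords[mask]) # regress back to the clean coordinates
loss.backward()
optim.step()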
Todo:
- add memory-efficient attention (https://arxiv.org/abs/2112.05682) so researchers can stretch to even bigger molecules
@misc{satorras2021en,
title = {E(n) Equivariant Graph Neural Networks},
author = {Victor Garcia Satorras and Emiel Hoogeboom and Max Welling},
year = {2021},
eprint = {2102.09844},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
@misc{shazeer2020talkingheads,
title = {Talking-Heads Attention},
author = {Noam Shazeer and Zhenzhong Lan and Youlong Cheng and Nan Ding and Le Hou},
year = {2020},
eprint = {2003.02436},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
@misc{su2021roformer,
title = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
author = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
year = {2021},
eprint = {2104.09864},
archivePrefix = {arXiv},
primaryClass = {cs.CL}
}