diff --git a/README.md b/README.md index 9372f57..1d4662f 100644 --- a/README.md +++ b/README.md @@ -115,6 +115,36 @@ adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1)) feats_out, coors_out = net(feats, coors, mask = mask, adj_mat = adj_mat) # (1, 1024, 32), (1, 1024, 3) ``` +## Edges + +If you need to pass in continuous edges + +```python +import torch +from egnn_pytorch.egnn_pytorch import EGNN_Network + +net = EGNN_Network( + num_tokens = 21, + dim = 32, + depth = 3, + edge_dim = 4, + num_nearest_neighbors = 3 +) + +feats = torch.randint(0, 21, (1, 1024)) +coors = torch.randn(1, 1024, 3) +mask = torch.ones_like(feats).bool() + +continuous_edges = torch.randn(1, 1024, 1024, 4) + +# naive adjacency matrix +# assuming the sequence is connected as a chain, with at most 2 neighbors - (1024, 1024) +i = torch.arange(1024) +adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1)) + +feats_out, coors_out = net(feats, coors, edges = continuous_edges, mask = mask, adj_mat = adj_mat) # (1, 1024, 32), (1, 1024, 3) +``` + ## Examples To run the protein backbone denoising example, first install `sidechainnet` diff --git a/egnn_pytorch/egnn_pytorch.py b/egnn_pytorch/egnn_pytorch.py index 5cad21d..85ac833 100644 --- a/egnn_pytorch/egnn_pytorch.py +++ b/egnn_pytorch/egnn_pytorch.py @@ -281,11 +281,12 @@ def __init__( self.token_emb = nn.Embedding(num_tokens, dim) if exists(num_tokens) else None self.edge_emb = nn.Embedding(num_edge_tokens, edge_dim) if exists(num_edge_tokens) else None + self.has_edges = edge_dim > 0 self.num_adj_degrees = num_adj_degrees self.adj_emb = nn.Embedding(num_adj_degrees + 1, adj_dim) if exists(num_adj_degrees) and adj_dim > 0 else None - edge_dim = edge_dim if exists(num_edge_tokens) else 0 + edge_dim = edge_dim if self.has_edges else 0 adj_dim = adj_dim if exists(num_adj_degrees) else 0 self.layers = nn.ModuleList([]) diff --git a/setup.py b/setup.py index 3120ba5..2009c62 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'egnn-pytorch', packages = find_packages(), - version = '0.1.1', + version = '0.1.2', license='MIT', description = 'E(n)-Equivariant Graph Neural Network - Pytorch', author = 'Phil Wang, Eric Alcaide',