Skip to content

Commit

Permalink
readme
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed May 15, 2021
1 parent d91e453 commit 706f45f
Showing 1 changed file with 29 additions and 5 deletions.
34 changes: 29 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,15 @@ A full EGNN network

```python
import torch
from egnn_pytorch.egnn_pytorch import EGNN_Network
from egnn_pytorch import EGNN_Network

net = EGNN_Network(
num_tokens = 21,
num_positions = 1024, # unless what you are passing in is an unordered set, set this to the maximum sequence length
dim = 32,
depth = 3,
num_nearest_neighbors = 8,
coor_weights_clamp_value = 2. # absolute clampd value for the coordinate weights, needed if you increase the num neareest neighbors
coor_weights_clamp_value = 2. # absolute clamped value for the coordinate weights, needed if you increase the num neareest neighbors
)

feats = torch.randint(0, 21, (1, 1024)) # (1, 1024)
Expand All @@ -69,7 +69,7 @@ Only attend to sparse neighbors, given to the network as an adjacency matrix.

```python
import torch
from egnn_pytorch.egnn_pytorch import EGNN_Network
from egnn_pytorch import EGNN_Network

net = EGNN_Network(
num_tokens = 21,
Expand All @@ -94,7 +94,7 @@ You can also have the network automatically determine the Nth-order neighbors, a

```python
import torch
from egnn_pytorch.egnn_pytorch import EGNN_Network
from egnn_pytorch import EGNN_Network

net = EGNN_Network(
num_tokens = 21,
Expand Down Expand Up @@ -123,7 +123,7 @@ If you need to pass in continuous edges

```python
import torch
from egnn_pytorch.egnn_pytorch import EGNN_Network
from egnn_pytorch import EGNN_Network

net = EGNN_Network(
num_tokens = 21,
Expand All @@ -147,6 +147,30 @@ adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))
feats_out, coors_out = net(feats, coors, edges = continuous_edges, mask = mask, adj_mat = adj_mat) # (1, 1024, 32), (1, 1024, 3)
```

## Stability

The initial architecture for EGNN suffered from instability when there was high number of neighbors. Thankfully, there seems to be two solutions that largely mitigate this.

```python
import torch
from egnn_pytorch import EGNN_Network

net = EGNN_Network(
num_tokens = 21,
dim = 32,
depth = 3,
num_nearest_neighbors = 32,
norm_coors = True, # normalize the relative coordinates
coor_weights_clamp_value = 2. # absolute clamped value for the coordinate weights, needed if you increase the num neareest neighbors
)

feats = torch.randint(0, 21, (1, 1024)) # (1, 1024)
coors = torch.randn(1, 1024, 3) # (1, 1024, 3)
mask = torch.ones_like(feats).bool() # (1, 1024)

feats_out, coors_out = net(feats, coors, mask = mask) # (1, 1024, 32), (1, 1024, 3)
```

## Examples

To run the protein backbone denoising example, first install `sidechainnet`
Expand Down

0 comments on commit 706f45f

Please sign in to comment.