diff --git a/README.md b/README.md index a282b20..6871207 100644 --- a/README.md +++ b/README.md @@ -47,7 +47,7 @@ A full EGNN network ```python import torch -from egnn_pytorch.egnn_pytorch import EGNN_Network +from egnn_pytorch import EGNN_Network net = EGNN_Network( num_tokens = 21, @@ -55,7 +55,7 @@ net = EGNN_Network( dim = 32, depth = 3, num_nearest_neighbors = 8, - coor_weights_clamp_value = 2. # absolute clampd value for the coordinate weights, needed if you increase the num neareest neighbors + coor_weights_clamp_value = 2. # absolute clamped value for the coordinate weights, needed if you increase the num neareest neighbors ) feats = torch.randint(0, 21, (1, 1024)) # (1, 1024) @@ -69,7 +69,7 @@ Only attend to sparse neighbors, given to the network as an adjacency matrix. ```python import torch -from egnn_pytorch.egnn_pytorch import EGNN_Network +from egnn_pytorch import EGNN_Network net = EGNN_Network( num_tokens = 21, @@ -94,7 +94,7 @@ You can also have the network automatically determine the Nth-order neighbors, a ```python import torch -from egnn_pytorch.egnn_pytorch import EGNN_Network +from egnn_pytorch import EGNN_Network net = EGNN_Network( num_tokens = 21, @@ -123,7 +123,7 @@ If you need to pass in continuous edges ```python import torch -from egnn_pytorch.egnn_pytorch import EGNN_Network +from egnn_pytorch import EGNN_Network net = EGNN_Network( num_tokens = 21, @@ -147,6 +147,30 @@ adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1)) feats_out, coors_out = net(feats, coors, edges = continuous_edges, mask = mask, adj_mat = adj_mat) # (1, 1024, 32), (1, 1024, 3) ``` +## Stability + +The initial architecture for EGNN suffered from instability when there was high number of neighbors. Thankfully, there seems to be two solutions that largely mitigate this. + +```python +import torch +from egnn_pytorch import EGNN_Network + +net = EGNN_Network( + num_tokens = 21, + dim = 32, + depth = 3, + num_nearest_neighbors = 32, + norm_coors = True, # normalize the relative coordinates + coor_weights_clamp_value = 2. # absolute clamped value for the coordinate weights, needed if you increase the num neareest neighbors +) + +feats = torch.randint(0, 21, (1, 1024)) # (1, 1024) +coors = torch.randn(1, 1024, 3) # (1, 1024, 3) +mask = torch.ones_like(feats).bool() # (1, 1024) + +feats_out, coors_out = net(feats, coors, mask = mask) # (1, 1024, 32), (1, 1024, 3) +``` + ## Examples To run the protein backbone denoising example, first install `sidechainnet`