
Merge pull request #25 from hypnopump/main
Update notebook and fix errors
hypnopump authored Mar 29, 2021
2 parents bba5ef4 + 97bf550 commit 876f995
Showing 2 changed files with 637 additions and 251 deletions.
7 changes: 2 additions & 5 deletions egnn_pytorch/egnn_pytorch.py
@@ -50,7 +50,7 @@ def fourier_encode_dist(x, num_encodings = 4, include_self = True):
 def embedd_token(x, dims, layers):
     stop_concat = -len(dims)
     to_embedd = x[:, stop_concat:].long()
-    for i,emb_layer in enumerate(layer):
+    for i,emb_layer in enumerate(layers):
         # the portion corresponding to `to_embedd` part gets dropped
         x = torch.cat([ x[:, :stop_concat],
                         emb_layer( to_embedd[:, i] )
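For context, a hedged sketch of how the fixed helper behaves: the trailing len(dims) columns of x carry integer token ids, and each is swapped for its embedding by the matching layer, which is why the loop must enumerate `layers` (the typo `layer` referenced an undefined name and raised a NameError). The hunk above is truncated mid-call, so the closing `], dim=-1)` and the return of x are assumed here; the (num_tokens, emb_dim) pairs and shapes below are illustrative, not taken from the repository.

import torch
from torch import nn

# assumed layout: one (num_tokens, emb_dim) pair per trailing categorical column
dims = [(10, 4), (5, 2)]
layers = nn.ModuleList([nn.Embedding(n, d) for n, d in dims])

x = torch.randn(8, 16)                         # 14 continuous dims + 2 token-id columns
x[:, -2] = torch.randint(0, 10, (8,)).float()  # ids for layers[0]
x[:, -1] = torch.randint(0, 5, (8,)).float()   # ids for layers[1]

out = embedd_token(x, dims, layers)
print(out.shape)  # expected torch.Size([8, 20]): 14 continuous + 4 + 2 embedded dims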
@@ -92,7 +92,6 @@ def __init__(
         edge_dim = 0,
         m_dim = 16,
         fourier_features = 0,
-        norm_rel_coors = False,
         num_nearest_neighbors = 0,
         dropout = 0.0,
         init_eps = 1e-3,
@@ -541,8 +540,6 @@ def __init__(self, n_layers, feats_dim, pos_dim = 3,
         self.norm_feats = norm_feats
         self.update_feats = update_feats
         self.update_coors = update_coors
-        self.norm_rel_coors = norm_rel_coors
-        self.norm_coor_weights= norm_coor_weights
         self.recalc = recalc

         # instantiate layers
@@ -574,7 +571,7 @@ def forward(self, x, edge_index, batch, edge_attr,

         # EDGES - Embedd each dim to its target dimensions:
         if edges_need_embedding:
-            edge_attr = embedd_token(x, self.edge_embedding_dims, self.edge_emb_layers)
+            edge_attr = embedd_token(edge_attr, self.edge_embedding_dims, self.edge_emb_layers)
             edges_need_embedding = False

         # pass layers
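The second fix is substantive rather than cosmetic: the old call embedded the node features x instead of the edge attributes, so edge_attr came out with node-wise rather than edge-wise shape and the categorical edge columns were never embedded. A minimal sketch of the corrected path; edge_embedding_dims, edge_emb_layers, and all shapes are illustrative assumptions, not values from the repository.

import torch
from torch import nn

# assumed: one categorical edge-type column with 3 classes, embedded to 8 dims
edge_embedding_dims = [(3, 8)]
edge_emb_layers = nn.ModuleList([nn.Embedding(n, d) for n, d in edge_embedding_dims])

num_edges = 32
edge_attr = torch.randn(num_edges, 5)                         # 4 continuous dims + 1 token id
edge_attr[:, -1] = torch.randint(0, 3, (num_edges,)).float()

# fixed call: embed the edge attributes themselves, not the node features x
edge_attr = embedd_token(edge_attr, edge_embedding_dims, edge_emb_layers)
print(edge_attr.shape)  # expected torch.Size([32, 12]): 4 continuous + 8 embedded dims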
[diff of the second changed file (the updated notebook) not shown]
