Commit

adding pyg
hypnopump committed Feb 27, 2021
1 parent 0fbf3cc commit 63d8997
Showing 2 changed files with 99 additions and 0 deletions.
Binary file added .DS_Store
99 changes: 99 additions & 0 deletions egnn_pytorch/egnn_pytorch.py
@@ -4,6 +4,11 @@
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

from torch_geometric.nn import MessagePassing
# types
from typing import Optional, List, Union, Tuple
from torch_geometric.typing import Adj, Size, OptTensor, Tensor

# helper functions

def exists(val):
@@ -82,6 +87,100 @@ def forward(self, feats, coors, edges = None):

return hidden_out, coors_out


class EGNN_sparse(MessagePassing):
    def __init__(
        self,
        feats_dim,
        pos_dim = 3,
        edge_attr_dim = 0,
        m_dim = 16,
        fourier_features = 0
    ):
        super().__init__()
        self.fourier_features = fourier_features
        self.pos_dim = pos_dim

        # edge MLP input: feats_i, feats_j, extra edge attrs, and the distance
        # (1 scalar, or 2 * fourier_features + 1 when fourier-encoded)
        edge_input_dim = (fourier_features * 2) + (feats_dim * 2) + edge_attr_dim + 1

        self.edge_mlp = nn.Sequential(
            nn.Linear(edge_input_dim, edge_input_dim * 2),
            nn.ReLU(),
            nn.Linear(edge_input_dim * 2, m_dim)
        )

        self.coors_mlp = nn.Sequential(
            nn.Linear(m_dim, m_dim * 4),
            nn.ReLU(),
            nn.Linear(m_dim * 4, 1)
        )

        self.hidden_mlp = nn.Sequential(
            nn.Linear(feats_dim + m_dim, feats_dim * 2),
            nn.ReLU(),
            nn.Linear(feats_dim * 2, feats_dim),
        )

    def forward(self, x: Tensor, edge_index: Adj,
                edge_attr: OptTensor = None, size: Size = None) -> Tensor:
        """ Inputs:
            * x: (n_points, d) where d is pos_dim + feats_dim
            * edge_attr: (n_edges, n_feats) tensor, excluding the basic
              distance features computed here.
        """
        coors, x = x[:, :self.pos_dim], x[:, self.pos_dim:]

        rel_coors = coors[edge_index[0]] - coors[edge_index[1]]
        rel_dist = torch.norm(rel_coors, dim=-1, keepdim=True)

        if self.fourier_features > 0:
            rel_dist = fourier_encode_dist(rel_dist, num_encodings = self.fourier_features)
            # flatten the (n_edges, 1, d) encoding back to (n_edges, d)
            rel_dist = rearrange(rel_dist, 'n () d -> n d')

        if edge_attr is not None:
            edge_attr = torch.cat([edge_attr, rel_dist], dim=-1)
        else:
            edge_attr = rel_dist

        hidden_out, coors_out = self.propagate(edge_index, x=x, edge_attr=edge_attr,
                                               coors=coors, rel_coors=rel_coors)
        return torch.cat([coors_out, hidden_out], dim=-1)


    def message(self, x_i, x_j, edge_attr) -> Tuple[Tensor, Tensor]:
        # compute the edge message m_ij and the scalar weight used for the
        # coordinate update in a single pass
        m_ij = self.edge_mlp( torch.cat([x_i, x_j, edge_attr], dim=-1) )
        coor_w = self.coors_mlp(m_ij)
        return m_ij, coor_w

    def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
        """The initial call to start propagating messages.
        Args:
            edge_index: holds the indices of a general (sparse)
                assignment matrix of shape :obj:`[N, M]`.
            size (tuple, optional): if None, the size will be inferred
                and assumed to be square.
            **kwargs: any additional data needed to construct and
                aggregate messages, and to update node embeddings.
        """
        size = self.__check_input__(edge_index, size)
        coll_dict = self.__collect__(self.__user_args__,
                                     edge_index, size, kwargs)
        msg_kwargs = self.inspector.distribute('message', coll_dict)
        m_ij, coor_wij = self.message(**msg_kwargs)
        # aggregate messages and the weighted relative coordinates
        aggr_kwargs = self.inspector.distribute('aggregate', coll_dict)
        m_i = self.aggregate(m_ij, **aggr_kwargs)
        # weight each relative coordinate by its learned scalar *before*
        # aggregating, as in the E(n)-GNN coordinate update
        coor_ri = self.aggregate(coor_wij * kwargs["rel_coors"], **aggr_kwargs)
        # update node features and coordinates, then return both
        update_kwargs = self.inspector.distribute('update', coll_dict)
        coors_out = kwargs["coors"] + coor_ri
        hidden_out = self.hidden_mlp( torch.cat([kwargs["x"], m_i], dim=-1) )

        return self.update((hidden_out, coors_out), **update_kwargs)

    def __repr__(self):
        return "E(n)-GNN Layer for Graphs: pos_dim={}, fourier_features={}".format(
            self.pos_dim, self.fourier_features)

# attention version

class EGAT(nn.Module):
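A minimal usage sketch of the new layer (illustrative only: the toy shapes and the hand-built edge_index below are assumptions, not part of this commit). Positions and features are packed into a single tensor, coordinates first, and the output keeps the same layout:

    import torch
    from egnn_pytorch.egnn_pytorch import EGNN_sparse

    layer = EGNN_sparse(feats_dim = 16, pos_dim = 3)

    coors = torch.randn(10, 3)                  # node coordinates
    feats = torch.randn(10, 16)                 # node features
    x = torch.cat([coors, feats], dim = -1)     # (n_points, pos_dim + feats_dim)

    edge_index = torch.tensor([[0, 1, 2],       # source nodes
                               [1, 2, 0]])      # target nodes, shape (2, n_edges)

    out = layer(x, edge_index)                  # (10, 3 + 16): updated coords | feats
    new_coors, new_feats = out[:, :3], out[:, 3:]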
