diff --git a/evaluation/optimizer.py b/evaluation/optimizer.py new file mode 100644 index 0000000..1ca7dd5 --- /dev/null +++ b/evaluation/optimizer.py @@ -0,0 +1,37 @@ +import torch +from torch import nn +from easydict import EasyDict +import yaml +config = './configs/train_res.yml' +config = load_config(config) + +def load_config(path): + with open(path, 'r') as f: + return EasyDict(yaml.safe_load(f)) + +def get_optimizer(cfg, model): + if cfg.type == 'adam': + return torch.optim.Adam( + model.parameters(), + lr=cfg.lr, + weight_decay=cfg.weight_decay, + betas=(cfg.beta1, cfg.beta2, ) + ) + else: + raise NotImplementedError('Optimizer not supported: %s' % cfg.type) + + +def get_scheduler(cfg, optimizer): + if cfg.type == 'plateau': + return torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, + factor=cfg.factor, + patience=cfg.patience, + min_lr=cfg.min_lr + ) + else: + raise NotImplementedError('Scheduler not supported: %s' % cfg.type) + +model = nn.Linear(12,24) +optimizer = get_optimizer(config.train.optimizer, model) +scheduler = get_scheduler(config.train.scheduler, optimizer) \ No newline at end of file diff --git a/models/interaction/geoattn.py b/models/interaction/geoattn.py index e2f1135..eb24c47 100644 --- a/models/interaction/geoattn.py +++ b/models/interaction/geoattn.py @@ -1,84 +1,157 @@ -# suppose: +# supnode_pose: # node_sca dim=4, edge_sca_dim=6 # node_vec_dim = 3, edge_vec_dim = 3 from torch import nn from torch_scatter import scatter_sum -from models.invariant import VNLinear, GVPerceptronVN +from ..invariant import VNLinear, GVPerceptronVN import torch from torch_scatter import scatter_softmax from torch.nn import Sigmoid +from .. model_utils import GaussianSmearing + +class EdgeMapping(nn.Module): + def __init__(self, edge_channels): + super().__init__() + self.nn = nn.Linear(in_features=1, out_features=edge_channels, bias=False) + + def forward(self, edge_vector): + edge_vector = edge_vector / (torch.norm(edge_vector, p=2, dim=1, keepdim=True)+1e-7) + expansion = self.nn(edge_vector.unsqueeze(-1)).transpose(1, -1) + return expansion class Geoattn_GNN(nn.Module): - def __init__(self, input_node_vec_dim=2, node_vec_dim=3,input_node_sca_dim=13, \ - input_edge_vec_dim = 1, input_edge_sca_dim=4, out_dim=16, normalize=20.): + def __init__(self, node_sca_dim=256,node_vec_dim=64, num_edge_types=4, edge_dim=64, hid_dim=128,\ + out_sca_dim=256, out_vec_dim=64, cutoff=10): super().__init__() - # To simplify the model, the out_feats_dim of edges and nodes are the same - - ### vector feature mapping - self.node_vec_net = VNLinear(input_node_vec_dim,out_dim) - self.node_sca_net = nn.Linear(input_node_sca_dim, out_dim) - self.edge_vec_net = VNLinear(input_edge_vec_dim, out_dim) - self.edge_sca_net = nn.Linear(input_edge_sca_dim, out_dim) - ### scalar feature mapping - self.node_net = nn.Linear(input_node_sca_dim, out_dim) - self.edge_net = nn.Linear(input_edge_sca_dim, out_dim) + self.cutoff = cutoff + self.edge_expansion = EdgeMapping(edge_dim) + self.distance_expansion = GaussianSmearing(stop=cutoff, num_gaussians=edge_dim - num_edge_types) + self.node_mapper = GVLinear(node_sca_dim,node_vec_dim,node_sca_dim,node_vec_dim) + self.edge_mapper = GVLinear(edge_dim,edge_dim,node_sca_dim,node_vec_dim) + + self.edge_net = nn.Linear(node_sca_dim, hid_dim) + self.node_net = nn.Linear(node_sca_dim, hid_dim) + + self.edge_sca_net = nn.Linear(node_sca_dim, hid_dim) + self.node_sca_net = nn.Linear(node_sca_dim, hid_dim) + self.edge_vec_net = VNLinear(node_vec_dim, hid_dim) + self.node_vec_net = VNLinear(node_vec_dim, hid_dim) + - ### the input dimension of sca_attn is (node_j||node_i||dist_ij) - self.sca_attn_net = nn.Linear(input_node_sca_dim*2+1, out_dim) - self.vec_attn_net = nn.Linear(input_node_vec_dim, out_dim) + self.sca_attn_net = nn.Linear(node_sca_dim*2+1, hid_dim) + self.vec_attn_net = VNLinear(node_vec_dim, hid_dim) self.softmax = scatter_softmax - self.sigmoid = Sigmoid() + self.sigmoid = nn.Sigmoid() - self.mapper = GVPerceptronVN(out_dim,out_dim,out_dim,out_dim) + self.msg_out = GVLinear(hid_dim, hid_dim, out_sca_dim, out_vec_dim) + + self.resi_connecter = GVLinear(node_sca_dim,node_vec_dim,node_sca_dim,node_vec_dim) + self.aggr_out = GVPerceptronVN(node_sca_dim,node_vec_dim,node_sca_dim,node_vec_dim) - def forward(self, node_feats, edge_feats, edge_index, node_pos): - - edge_dist = torch.norm(pos[edge_index[0]]-pos[edge_index[1]], dim=-1) - node_sca, node_vec = node_feats - edge_sca, edge_vec = edge_feats - edge_index_raw = edge_index[0] + def forward(self, node_feats, edge_feature, edge_vector, edge_index, node_pos): + num_nodes = node_pos.shape[0] + edge_dist = torch.norm(node_pos[edge_index[0]]-node_pos[edge_index[1]], dim=-1) + edge_index_row = edge_index[0] + + ## map edge_features: original space -> interation space + edge_dist = torch.norm(edge_vector, dim=-1, p=2) + edge_sca_feat = torch.cat([self.distance_expansion(edge_dist), edge_feature], dim=-1) + edge_vec_feat = self.edge_expansion(edge_vector) + + # message passing framework + ## extract edge and node features in interaction space + node_sca_feats, node_vec_feats = self.node_mapper(node_feats) + edge_sca_feat, edge_vec_feat = self.edge_mapper([edge_sca_feat, edge_vec_feat]) + node_sca_feats, node_vec_feats = node_sca_feats[edge_index_row], node_vec_feats[edge_index_row] - # compute the attention score \alpha_ij and A_ij - alpha_sca = torch.cat([node_sca[edge_index[0]], node_sca[edge_index[1]], edge_dist.unsqueeze(-1)], dim=-1) + ## compute the attention score \alpha_ij and A_ij + alpha_sca = torch.cat([node_sca_feats[edge_index[0]], node_sca_feats[edge_index[1]], edge_dist.unsqueeze(-1)], dim=-1) alpha_sca = self.sca_attn_net(alpha_sca) - alpha_sca = self.softmax(alpha_sca,edge_index_raw,dim=0) + alpha_sca = self.softmax(alpha_sca,edge_index_row,dim=0) - alpha_vec_hid = self.vec_attn_net(node_vec) + alpha_vec_hid = self.vec_attn_net(node_vec_feats) alpha_vec = (alpha_vec_hid[edge_index[0]] * alpha_vec_hid[edge_index[1]]).sum(-1).sum(-1) alpha_vec = self.sigmoid(alpha_vec) + + ## message: the scalar feats + node_sca_feat = self.node_net(node_sca_feats)[edge_index_row] * self.edge_net(edge_sca_feat) + ## message: the equivariant interaction between node feature and edge feature + node_sca_hid = self.node_sca_net(node_sca_feats)[edge_index_row].unsqueeze(-1) + edge_vec_hid = self.edge_vec_net(edge_vec_feat) + node_vec_hid = self.node_vec_net(node_vec_feats)[edge_index_row] + edge_sca_hid = self.edge_sca_net(edge_sca_feat).unsqueeze(-1) + msg_sca = node_sca_feat * alpha_sca + msg_vec = (node_sca_hid * edge_vec_hid + node_vec_hid*edge_sca_hid)*alpha_vec.unsqueeze(-1).unsqueeze(-1) + msg_sca,msg_vec = self.msg_out([msg_sca,msg_vec]) - # the scalar feats - node_sca_feat = node_net(node_sca)[edge_index_raw] * edge_net(edge_sca) * alpha_sca - # the equivariant interaction between node feature and edge feature - node_sca_hid = self.node_sca_net(node_sca)[edge_index_raw].unsqueeze(-1) - edge_vec_hid = self.edge_vec_net(edge_vec) - node_vec_hid = self.node_vec_net(node_vec)[edge_index_raw] - edge_sca_hid = self.edge_sca_net(edge_sca).unsqueeze(-1) + ## aggregate the message + aggr_msg_sca = scatter_sum(msg_sca, edge_index_row, dim=0, dim_size=num_nodes) + aggr_msg_vec = scatter_sum(msg_vec, edge_index_row, dim=0, dim_size=num_nodes) + + ## residue connection + resi_sca, resi_vec = self.resi_connecter(node_feats) + out_sca = resi_sca + aggr_msg_sca + out_vec = resi_vec + aggr_msg_vec + + ## map the aggregated feature + out_sca, out_vec = self.aggr_out([out_sca, out_vec]) + + return [out_sca, out_vec] + + +# class Geoattn_GNN(nn.Module): +# def __init__(self, input_node_vec_dim=2, node_vec_dim=3,input_node_sca_dim=13, \ +# input_edge_vec_dim = 1, input_edge_sca_dim=4, out_dim=16, normalize=20.): +# super().__init__() +# # To simplify the model, the out_feats_dim of edges and nodes are the same + +# ### vector feature mapping +# self.node_vec_net = VNLinear(input_node_vec_dim,out_dim) +# self.node_sca_net = nn.Linear(input_node_sca_dim, out_dim) +# self.edge_vec_net = VNLinear(input_edge_vec_dim, out_dim) +# self.edge_sca_net = nn.Linear(input_edge_sca_dim, out_dim) - emb_sca = scatter_sum(node_sca_feat,edge_index_raw, dim=0) - emb_vec = scatter_sum((node_sca_hid * edge_vec_hid + node_vec_hid*edge_sca_hid)*alpha_vec.unsqueeze(-1).unsqueeze(-1), edge_index_raw, dim=0) +# ### scalar feature mapping +# self.node_net = nn.Linear(input_node_sca_dim, out_dim) +# self.edge_net = nn.Linear(input_edge_sca_dim, out_dim) + +# ### the input dimension of sca_attn is (node_j||node_i||dist_ij) +# self.sca_attn_net = nn.Linear(input_node_sca_dim*2+1, out_dim) +# self.vec_attn_net = nn.Linear(input_node_vec_dim, out_dim) +# self.softmax = scatter_softmax +# self.sigmoid = Sigmoid() + +# self.mapper = GVPerceptronVN(out_dim,out_dim,out_dim,out_dim) + +# def forward(self, node_feats, edge_feats, edge_index, node_node_pos): - ### perform the non-linear transformation between scalar feature and vector feature - out = self.mapper([emb_sca,emb_vec]) - - return out - - - # class AtomEmbedding(Module): - # def __init__(self, in_scalar, in_vector, - # out_scalar, out_vector, vector_normalizer=20.): - # super().__init__() - # assert in_vector == 1 - # self.in_scalar = in_scalar - # self.vector_normalizer = vector_normalizer - # self.emb_sca = Linear(in_scalar, out_scalar) - # self.emb_vec = Linear(in_vector, out_vector) - - # def forward(self, scalar_input, vector_input): - # vector_input = vector_input / self.vector_normalizer - # assert vector_input.shape[1:] == (3, ), 'Not support. Only one vector can be input' - # sca_emb = self.emb_sca(scalar_input[:, :self.in_scalar]) # b, f -> b, f' - # vec_emb = vector_input.unsqueeze(-1) # b, 3 -> b, 3, 1 - # vec_emb = self.emb_vec(vec_emb).transpose(1, -1) # b, 1, 3 -> b, f', 3 - # return sca_emb, vec_emb \ No newline at end of file +# edge_dist = torch.norm(node_pos[edge_index[0]]-node_pos[edge_index[1]], dim=-1) +# node_sca, node_vec = node_feats +# edge_sca, edge_vec = edge_feats +# edge_index_raw = edge_index[0] + +# # compute the attention score \alpha_ij and A_ij +# alpha_sca = torch.cat([node_sca[edge_index[0]], node_sca[edge_index[1]], edge_dist.unsqueeze(-1)], dim=-1) +# alpha_sca = self.sca_attn_net(alpha_sca) +# alpha_sca = self.softmax(alpha_sca,edge_index_raw,dim=0) + +# alpha_vec_hid = self.vec_attn_net(node_vec) +# alpha_vec = (alpha_vec_hid[edge_index[0]] * alpha_vec_hid[edge_index[1]]).sum(-1).sum(-1) +# alpha_vec = self.sigmoid(alpha_vec) + +# # the scalar feats +# node_sca_feat = self/node_net(node_sca)[edge_index_raw] * edge_net(edge_sca) * alpha_sca +# # the equivariant interaction between node feature and edge feature +# node_sca_hid = self.node_sca_net(node_sca)[edge_index_raw].unsqueeze(-1) +# edge_vec_hid = self.edge_vec_net(edge_vec) +# node_vec_hid = self.node_vec_net(node_vec)[edge_index_raw] +# edge_sca_hid = self.edge_sca_net(edge_sca).unsqueeze(-1) + +# emb_sca = scatter_sum(node_sca_feat,edge_index_raw, dim=0) +# emb_vec = scatter_sum((node_sca_hid * edge_vec_hid + node_vec_hid*edge_sca_hid)*alpha_vec.unsqueeze(-1).unsqueeze(-1), edge_index_raw, dim=0) + +# ### perform the non-linear transformation between scalar feature and vector feature +# out = self.mapper([emb_sca,emb_vec]) + +# return out diff --git a/models/interaction/geodesic.py b/models/interaction/geodesic.py index adf2ebf..9a24d1d 100644 --- a/models/interaction/geodesic.py +++ b/models/interaction/geodesic.py @@ -4,46 +4,113 @@ from torch import nn from torch_scatter import scatter_sum from math import pi -from models.invariant import VNLinear, GVPerceptronVN +from ..invariant import VNLinear, GVPerceptronVN import torch class Geodesic_GNN(nn.Module): - def __init__(self, node_sca_dim=4, node_vec_dim=3, edge_sca_dim=6, edge_vec_dim=3, out_sca=16, out_vec=16, cutoff=10.): + def __init__(self, node_sca_dim=256, node_vec_dim=64, hid_dim=128, edge_dim=64, num_edge_types=2, \ + out_sca_dim=256, out_vec_dim=64, cutoff = 10.): super().__init__() - # To simplify the model, the out_feats_dim of edges and nodes are the same - - self.edge_sca_sca = nn.Linear(edge_sca_dim, out_sca) - self.node_sca_sca = nn.Linear(node_sca_dim, out_sca) - self.edge_sca_vec = nn.Linear(edge_sca_dim, out_sca) - self.node_sca_vec = nn.Linear(node_sca_dim, out_sca) - self.edge_vec_vec = VNLinear(edge_vec_dim, out_vec) - self.node_vec_vec = VNLinear(node_vec_dim, out_vec) - - self.encoder = GVPerceptronVN(out_sca,out_vec,out_sca,out_sca) + self.edge_expansion = EdgeMapping(edge_dim) + self.distance_expansion = GaussianSmearing(stop=cutoff, num_gaussians=edge_dim - num_edge_types) + + self.node_mapper = GVLinear(node_sca_dim,node_vec_dim,node_sca_dim,node_vec_dim) + self.edge_mapper = GVLinear(edge_dim,edge_dim,node_sca_dim,node_vec_dim) + + self.edge_sca_sca = nn.Linear(node_sca_dim, hid_dim) + self.node_sca_sca = nn.Linear(node_sca_dim, hid_dim) + + self.edge_sca_vec = nn.Linear(node_sca_dim, hid_dim) + self.node_sca_vec = nn.Linear(node_sca_dim, hid_dim) + self.edge_vec_vec = VNLinear(node_vec_dim, hid_dim) + self.node_vec_vec = VNLinear(node_vec_dim, hid_dim) + + self.msg_out = GVLinear(hid_dim, hid_dim, out_sca_dim, out_vec_dim) + + self.resi_connecter = GVLinear(node_sca_dim,node_vec_dim,node_sca_dim,node_vec_dim) + self.aggr_out = GVLinear(node_sca_dim,node_vec_dim,node_sca_dim,node_vec_dim) - def forward(self, node_feats, edge_feats, edge_index, gds_dist): - edge_index_raw = edge_index[0] - num_nodes = node_feats[0].shape[0] - coeff = 0.5 * (torch.cos(gds_dist * pi / self.cutoff) + 1.0) - coeff = coeff * (gds_dist <= self.cutoff) * (gds_dist >= 0.0) - - # compute the scalar message - msg_sca_emb = self.node_sca_sca(node_feats[0])[edge_index_raw] * self.edge_sca_sca(edge_feats[0]) + def forward(self, node_feats, edge_feature, edge_vector edge_index, gds_dist): + + ## map edge_fetures: original space -> interaction space + edge_sca_feat = torch.cat([self.distance_expansion(gds_dist), edge_feature], dim=-1) + edge_vec_feat = self.edge_expansion(edge_vector) + + # Geodesic Message Passing + ## mapping the node and edge features to the same space + edge_index_row = edge_index[0] + node_sca_feats, node_vec_feats = self.node_mapper(node_feats) + edge_sca_feat, edge_vec_feat = self.edge_mapper([edge_sca_feat, edge_vec_feat]) + node_sca_feats, node_vec_feats = node_sca_feats[edge_index_row], node_vec_feats[edge_index_row] + ## geodesic coefficient + coeff = 0.5 * (torch.cos(gds_dist * pi / cutoff) + 1.0) + coeff = coeff * (gds_dist <= cutoff) * (gds_dist >= 0.0) + ## compute the scalar message + msg_sca_emb = self.node_sca_sca(node_sca_feats) * self.edge_sca_sca(edge_sca_feat) msg_sca_emb = msg_sca_emb * coeff.view(-1,1) - - # compute the vector message - msg_vec_emb1 = self.node_vec_vec(node_feats[1])[edge_index_raw] * self.edge_sca_vec(edge_feats[0]).unsqueeze(-1) - msg_vec_emb2 = self.node_sca_vec(node_feats[0])[edge_index_raw].unsqueeze(-1) * self.edge_vec_vec(edge_feats[1]) + + ## compute the vector message + msg_vec_emb1 = self.node_vec_vec(node_vec_feats) * self.edge_sca_vec(edge_sca_feat).unsqueeze(-1) + msg_vec_emb2 = self.node_sca_vec(node_sca_feats).unsqueeze(-1) * self.edge_vec_vec(edge_vec_feat) msg_vec_emb = msg_vec_emb1 + msg_vec_emb2 msg_vec_emb = msg_vec_emb * coeff.view(-1,1,1) + ## message pssing mapping + msg_sca_emb, msg_vec_emb = self.msg_out([msg_sca_emb, msg_vec_emb]) + + ## aggregate the message + aggr_msg_sca = scatter_sum(msg_sca_emb, edge_index_row, dim=0, dim_size=num_nodes) + aggr_msg_vec = scatter_sum(msg_vec_emb, edge_index_row, dim=0, dim_size=num_nodes) + + ## residue connection + resi_sca, resi_vec = self.resi_connecter(node_feats) + out_sca = resi_sca + aggr_msg_sca + out_vec = resi_vec + aggr_msg_vec + + ## aggregation mapper + out_sca, out_vec = self.aggr_out([out_sca, out_vec]) + + return [out_sca, out_vec] + + + +# class Geodesic_GNN(nn.Module): +# def __init__(self, node_sca_dim=4, node_vec_dim=3, edge_sca_dim=6, edge_vec_dim=3, out_sca=16, out_vec=16, cutoff=10.): +# super().__init__() +# # To simplify the model, the out_feats_dim of edges and nodes are the same + +# self.self.edge_sca_sca = nn.Linear(edge_sca_dim, out_sca) +# self.self.node_sca_sca = nn.Linear(node_sca_dim, out_sca) + +# self.self.edge_sca_vec = nn.Linear(edge_sca_dim, out_sca) +# self.self.node_sca_vec = nn.Linear(node_sca_dim, out_sca) +# self.self.edge_vec_vec = VNLinear(edge_vec_dim, out_vec) +# self.self.node_vec_vec = VNLinear(node_vec_dim, out_vec) + +# self.encoder = GVPerceptronVN(out_sca,out_vec,out_sca,out_sca) + +# def forward(self, node_feats, edge_feats, edge_index, gds_dist): +# edge_index_raw = edge_index[0] +# num_nodes = node_feats[0].shape[0] +# coeff = 0.5 * (torch.cos(gds_dist * pi / self.cutoff) + 1.0) +# coeff = coeff * (gds_dist <= self.cutoff) * (gds_dist >= 0.0) + +# # compute the scalar message +# msg_sca_emb = self.self.node_sca_sca(node_feats[0])[edge_index_raw] * self.self.edge_sca_sca(edge_feats[0]) +# msg_sca_emb = msg_sca_emb * coeff.view(-1,1) + +# # compute the vector message +# msg_vec_emb1 = self.self.node_vec_vec(node_feats[1])[edge_index_raw] * self.self.edge_sca_vec(edge_feats[0]).unsqueeze(-1) +# msg_vec_emb2 = self.self.node_sca_vec(node_feats[0])[edge_index_raw].unsqueeze(-1) * self.self.edge_vec_vec(edge_feats[1]) +# msg_vec_emb = msg_vec_emb1 + msg_vec_emb2 +# msg_vec_emb = msg_vec_emb * coeff.view(-1,1,1) - # arrgrate the message - aggr_msg_sca = scatter_sum(msg_sca_emb, edge_index_raw, dim=0, dim_size=num_nodes) - aggr_msg_vec = scatter_sum(msg_vec_emb, edge_index_raw, dim=0, dim_size=num_nodes) +# # arrgrate the message +# aggr_msg_sca = scatter_sum(msg_sca_emb, edge_index_raw, dim=0, dim_size=num_nodes) +# aggr_msg_vec = scatter_sum(msg_vec_emb, edge_index_raw, dim=0, dim_size=num_nodes) - # then encode the geodesic feates - node_aggr_sca, node_aggr_vec = self.encoder((aggr_msg_sca,aggr_msg_vec)) +# # then encode the geodesic feates +# node_aggr_sca, node_aggr_vec = self.encoder((aggr_msg_sca,aggr_msg_vec)) - return node_aggr_sca, node_aggr_vec \ No newline at end of file +# return node_aggr_sca, node_aggr_vec \ No newline at end of file