From 13926ad59446ab1c8bdf0354845028d3d30cff8f Mon Sep 17 00:00:00 2001 From: peastman Date: Wed, 21 Feb 2024 16:11:51 -0800 Subject: [PATCH] Fixed test failures --- torchmdnet/models/tensornet.py | 2 +- torchmdnet/models/torchmd_et.py | 2 +- torchmdnet/models/torchmd_gn.py | 2 +- torchmdnet/models/torchmd_t.py | 2 +- torchmdnet/optimize.py | 5 +++++ 5 files changed, 9 insertions(+), 4 deletions(-) diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index 3a40ec6f1..70476170c 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -234,7 +234,7 @@ def forward( box: Optional[Tensor] = None, q: Optional[Tensor] = None, s: Optional[Tensor] = None, - extra_embedding_args: [Optional[Tuple[Tensor]]] = None + extra_embedding_args: Optional[Tuple[Tensor]] = None ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]: # Obtain graph, with distances and relative position vectors edge_index, edge_weight, edge_vec = self.distance(pos, batch, box) diff --git a/torchmdnet/models/torchmd_et.py b/torchmdnet/models/torchmd_et.py index 6fe211b24..5545f84ba 100644 --- a/torchmdnet/models/torchmd_et.py +++ b/torchmdnet/models/torchmd_et.py @@ -206,7 +206,7 @@ def forward( box: Optional[Tensor] = None, q: Optional[Tensor] = None, s: Optional[Tensor] = None, - extra_embedding_args: [Optional[Tuple[Tensor]]] = None + extra_embedding_args: Optional[Tuple[Tensor]] = None ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: x = self.embedding(z) if self.reshape_embedding is not None: diff --git a/torchmdnet/models/torchmd_gn.py b/torchmdnet/models/torchmd_gn.py index 690e34e70..f7760bc2e 100644 --- a/torchmdnet/models/torchmd_gn.py +++ b/torchmdnet/models/torchmd_gn.py @@ -208,7 +208,7 @@ def forward( box: Optional[Tensor] = None, s: Optional[Tensor] = None, q: Optional[Tensor] = None, - extra_embedding_args: [Optional[Tuple[Tensor]]] = None + extra_embedding_args: Optional[Tuple[Tensor]] = None ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]: x = self.embedding(z) if self.reshape_embedding is not None: diff --git a/torchmdnet/models/torchmd_t.py b/torchmdnet/models/torchmd_t.py index 89655740f..9ab6f2bd2 100644 --- a/torchmdnet/models/torchmd_t.py +++ b/torchmdnet/models/torchmd_t.py @@ -202,7 +202,7 @@ def forward( box: Optional[Tensor] = None, s: Optional[Tensor] = None, q: Optional[Tensor] = None, - extra_embedding_args: [Optional[Tuple[Tensor]]] = None + extra_embedding_args: Optional[Tuple[Tensor]] = None ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]: x = self.embedding(z) if self.reshape_embedding is not None: diff --git a/torchmdnet/optimize.py b/torchmdnet/optimize.py index 0c7f56513..6cae39635 100644 --- a/torchmdnet/optimize.py +++ b/torchmdnet/optimize.py @@ -33,6 +33,7 @@ def __init__(self, model): super().__init__() self.model = model + self.extra_embedding = model.extra_embedding self.neighbors = CFConvNeighbors(self.model.cutoff_upper) @@ -58,12 +59,16 @@ def forward( box: Optional[pt.Tensor] = None, q: Optional[pt.Tensor] = None, s: Optional[pt.Tensor] = None, + extra_embedding_args: Optional[Tuple[pt.Tensor]] = None ) -> Tuple[pt.Tensor, Optional[pt.Tensor], pt.Tensor, pt.Tensor, pt.Tensor]: assert pt.all(batch == 0) assert box is None, "Box is not supported" x = self.model.embedding(z) + if self.model.reshape_embedding is not None: + x = pt.cat((x,)+tuple(t.unsqueeze(1) for t in extra_embedding_args), dim=1) + x = self.model.reshape_embedding(x) self.neighbors.build(pos) for inter, conv in zip(self.model.interactions, self.convs):