diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index 3a40ec6f..70476170 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -234,7 +234,7 @@ def forward( box: Optional[Tensor] = None, q: Optional[Tensor] = None, s: Optional[Tensor] = None, - extra_embedding_args: [Optional[Tuple[Tensor]]] = None + extra_embedding_args: Optional[Tuple[Tensor]] = None ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]: # Obtain graph, with distances and relative position vectors edge_index, edge_weight, edge_vec = self.distance(pos, batch, box) diff --git a/torchmdnet/models/torchmd_et.py b/torchmdnet/models/torchmd_et.py index 6fe211b2..5545f84b 100644 --- a/torchmdnet/models/torchmd_et.py +++ b/torchmdnet/models/torchmd_et.py @@ -206,7 +206,7 @@ def forward( box: Optional[Tensor] = None, q: Optional[Tensor] = None, s: Optional[Tensor] = None, - extra_embedding_args: [Optional[Tuple[Tensor]]] = None + extra_embedding_args: Optional[Tuple[Tensor]] = None ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: x = self.embedding(z) if self.reshape_embedding is not None: diff --git a/torchmdnet/models/torchmd_gn.py b/torchmdnet/models/torchmd_gn.py index 690e34e7..f7760bc2 100644 --- a/torchmdnet/models/torchmd_gn.py +++ b/torchmdnet/models/torchmd_gn.py @@ -208,7 +208,7 @@ def forward( box: Optional[Tensor] = None, s: Optional[Tensor] = None, q: Optional[Tensor] = None, - extra_embedding_args: [Optional[Tuple[Tensor]]] = None + extra_embedding_args: Optional[Tuple[Tensor]] = None ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]: x = self.embedding(z) if self.reshape_embedding is not None: diff --git a/torchmdnet/models/torchmd_t.py b/torchmdnet/models/torchmd_t.py index 89655740..9ab6f2bd 100644 --- a/torchmdnet/models/torchmd_t.py +++ b/torchmdnet/models/torchmd_t.py @@ -202,7 +202,7 @@ def forward( box: Optional[Tensor] = None, s: Optional[Tensor] = None, q: Optional[Tensor] = None, - extra_embedding_args: [Optional[Tuple[Tensor]]] = None + extra_embedding_args: Optional[Tuple[Tensor]] = None ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]: x = self.embedding(z) if self.reshape_embedding is not None: diff --git a/torchmdnet/optimize.py b/torchmdnet/optimize.py index 0c7f5651..6cae3963 100644 --- a/torchmdnet/optimize.py +++ b/torchmdnet/optimize.py @@ -33,6 +33,7 @@ def __init__(self, model): super().__init__() self.model = model + self.extra_embedding = model.extra_embedding self.neighbors = CFConvNeighbors(self.model.cutoff_upper) @@ -58,12 +59,16 @@ def forward( box: Optional[pt.Tensor] = None, q: Optional[pt.Tensor] = None, s: Optional[pt.Tensor] = None, + extra_embedding_args: Optional[Tuple[pt.Tensor]] = None ) -> Tuple[pt.Tensor, Optional[pt.Tensor], pt.Tensor, pt.Tensor, pt.Tensor]: assert pt.all(batch == 0) assert box is None, "Box is not supported" x = self.model.embedding(z) + if self.model.reshape_embedding is not None: + x = pt.cat((x,)+tuple(t.unsqueeze(1) for t in extra_embedding_args), dim=1) + x = self.model.reshape_embedding(x) self.neighbors.build(pos) for inter, conv in zip(self.model.interactions, self.convs):