Skip to content

Commit

Permalink
Fixed test failures
Browse files Browse the repository at this point in the history
  • Loading branch information
peastman committed Feb 22, 2024
1 parent 58f298f commit 13926ad
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 4 deletions.
2 changes: 1 addition & 1 deletion torchmdnet/models/tensornet.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def forward(
box: Optional[Tensor] = None,
q: Optional[Tensor] = None,
s: Optional[Tensor] = None,
extra_embedding_args: [Optional[Tuple[Tensor]]] = None
extra_embedding_args: Optional[Tuple[Tensor]] = None
) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]:
# Obtain graph, with distances and relative position vectors
edge_index, edge_weight, edge_vec = self.distance(pos, batch, box)
Expand Down
2 changes: 1 addition & 1 deletion torchmdnet/models/torchmd_et.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def forward(
box: Optional[Tensor] = None,
q: Optional[Tensor] = None,
s: Optional[Tensor] = None,
extra_embedding_args: [Optional[Tuple[Tensor]]] = None
extra_embedding_args: Optional[Tuple[Tensor]] = None
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
x = self.embedding(z)
if self.reshape_embedding is not None:
Expand Down
2 changes: 1 addition & 1 deletion torchmdnet/models/torchmd_gn.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def forward(
box: Optional[Tensor] = None,
s: Optional[Tensor] = None,
q: Optional[Tensor] = None,
extra_embedding_args: [Optional[Tuple[Tensor]]] = None
extra_embedding_args: Optional[Tuple[Tensor]] = None
) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]:
x = self.embedding(z)
if self.reshape_embedding is not None:
Expand Down
2 changes: 1 addition & 1 deletion torchmdnet/models/torchmd_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def forward(
box: Optional[Tensor] = None,
s: Optional[Tensor] = None,
q: Optional[Tensor] = None,
extra_embedding_args: [Optional[Tuple[Tensor]]] = None
extra_embedding_args: Optional[Tuple[Tensor]] = None
) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]:
x = self.embedding(z)
if self.reshape_embedding is not None:
Expand Down
5 changes: 5 additions & 0 deletions torchmdnet/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__(self, model):

super().__init__()
self.model = model
self.extra_embedding = model.extra_embedding

self.neighbors = CFConvNeighbors(self.model.cutoff_upper)

Expand All @@ -58,12 +59,16 @@ def forward(
box: Optional[pt.Tensor] = None,
q: Optional[pt.Tensor] = None,
s: Optional[pt.Tensor] = None,
extra_embedding_args: Optional[Tuple[pt.Tensor]] = None
) -> Tuple[pt.Tensor, Optional[pt.Tensor], pt.Tensor, pt.Tensor, pt.Tensor]:

assert pt.all(batch == 0)
assert box is None, "Box is not supported"

x = self.model.embedding(z)
if self.model.reshape_embedding is not None:
x = pt.cat((x,)+tuple(t.unsqueeze(1) for t in extra_embedding_args), dim=1)
x = self.model.reshape_embedding(x)

self.neighbors.build(pos)
for inter, conv in zip(self.model.interactions, self.convs):
Expand Down

0 comments on commit 13926ad

Please sign in to comment.