
Tensornet-q #353

Open · AntonioMirarchi wants to merge 22 commits into main from tn_add_extra_args

Commits (22)
205f711
add partial_charges arg
AntonioMirarchi Jan 23, 2025
4db794f
add function to process extra args to lightning module
AntonioMirarchi Jan 23, 2025
bce4b73
update model forward
AntonioMirarchi Jan 23, 2025
91955ed
add extra args to the forward of the architectures
AntonioMirarchi Jan 23, 2025
b84a74d
process extra_args in tensornet
AntonioMirarchi Jan 23, 2025
04f0c9a
Ensure only one of charge, partial_charges, and spin can be True
AntonioMirarchi Jan 23, 2025
7a0e848
process extra args in predict_step
AntonioMirarchi Jan 23, 2025
44ec9a3
add charge and spin
AntonioMirarchi Jan 23, 2025
cd31d1a
make ace and all memmap compatible with extra_args format (pq for now)
AntonioMirarchi Jan 23, 2025
55cb50d
update pytest for extra_args
AntonioMirarchi Jan 27, 2025
10f03da
Merge branch 'main' into tn_add_extra_args
AntonioMirarchi Jan 27, 2025
9dd605a
update fixed memmap to extra_args nomenclature
AntonioMirarchi Jan 27, 2025
0ebff4f
directly pass nn.module to external
AntonioMirarchi Jan 27, 2025
86b472d
force calculator to use static_shapes=true and check_errors=false if …
AntonioMirarchi Jan 27, 2025
48d792e
update to avoid error with static_shapes
AntonioMirarchi Jan 27, 2025
81732fa
update test_calculator to properly use cuda
AntonioMirarchi Jan 27, 2025
f23dd50
revert to origin
AntonioMirarchi Jan 27, 2025
7e220b1
fix test_calculator
AntonioMirarchi Jan 27, 2025
0862846
update example.ckpt to use tensornet ckpt instead of ET ckpt
AntonioMirarchi Jan 27, 2025
140af32
Merge branch 'main' of https://github.com/AntonioMirarchi/torchmd-net…
AntonioMirarchi Jan 27, 2025
a10863c
update ckpt example
AntonioMirarchi Jan 27, 2025
2f19830
fix typo
AntonioMirarchi Jan 27, 2025
Binary file modified tests/example.ckpt
Binary file not shown.
28 changes: 17 additions & 11 deletions tests/test_calculator.py
@@ -3,14 +3,20 @@
# (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT)

import torch
from torch.testing import assert_allclose
from torch.testing import assert_close
import pytest
from os.path import dirname, join
from torchmdnet.calculators import External
from torchmdnet.models.model import load_model, create_model

from utils import create_example_batch

# Relative and absolute tolerances for float32 precision.
# The original test used assert_allclose, which is now deprecated; assert_close
# is used instead. Its defaults for torch.float32 are rtol=1.3e-6 and atol=1e-5,
# so rtol and atol are set explicitly here to match the original test's tolerances.
rtol = 1e-4
atol = 1e-5

@pytest.mark.parametrize("box", [None, torch.eye(3)])
@pytest.mark.parametrize("use_cuda_graphs", [True, False])
@@ -39,24 +45,24 @@ def test_compare_forward(box, use_cuda_graphs):
"precision": 32,
}
device = "cpu" if not use_cuda_graphs else "cuda"
model = create_model(args).to(device=device)
c_model = load_model(checkpoint).to(device=device)
g_model = load_model(checkpoint, check_errors=not use_cuda_graphs, static_shapes=use_cuda_graphs).to(device=device)
z, pos, _ = create_example_batch(multiple_batches=False)
z = z.to(device)
pos = pos.to(device)
calc = External(checkpoint, z.unsqueeze(0), use_cuda_graph=False, device=device)
calc = External(c_model, z.unsqueeze(0), use_cuda_graph=False, device=device)
calc_graph = External(
checkpoint, z.unsqueeze(0), use_cuda_graph=use_cuda_graphs, device=device
g_model, z.unsqueeze(0), use_cuda_graph=use_cuda_graphs, device=device
)
calc.model = model
calc_graph.model = model

if box is not None:
box = (box * 2 * args["cutoff_upper"]).unsqueeze(0)

for _ in range(10):
e_calc, f_calc = calc.calculate(pos, box)
e_pred, f_pred = calc_graph.calculate(pos, box)
assert_allclose(e_calc, e_pred)
assert_allclose(f_calc, f_pred)

assert_close(e_calc, e_pred, rtol=rtol, atol=atol)
assert_close(f_calc, f_pred, rtol=rtol, atol=atol)

def test_compare_forward_multiple():
checkpoint = join(dirname(dirname(__file__)), "tests", "example.ckpt")
Expand All @@ -72,5 +78,5 @@ def test_compare_forward_multiple():
torch.cat([torch.zeros(len(z1)), torch.ones(len(z2))]).long(),
)

assert_allclose(e_calc, e_pred)
assert_allclose(f_calc, f_pred.view(-1, len(z1), 3))
assert_close(e_calc, e_pred, rtol=rtol, atol=atol)
assert_close(f_calc, f_pred.view(-1, len(z1), 3), rtol=rtol, atol=atol)
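
For reference, a minimal sketch (not part of this diff) of how torch.testing.assert_close applies these tolerances: two tensors compare equal when |actual - expected| <= atol + rtol * |expected| holds elementwise.

import torch
from torch.testing import assert_close

actual = torch.tensor([1.00000, 2.00000])
expected = torch.tensor([1.00009, 2.00018])

# Passes with the tolerances used above (rtol=1e-4, atol=1e-5):
assert_close(actual, expected, rtol=1e-4, atol=1e-5)

# The stricter float32 defaults (rtol=1.3e-6, atol=1e-5) would reject the same
# pair, since |1.00009 - 1.0| = 9e-5 > 1e-5 + 1.3e-6 * 1.00009.
# assert_close(actual, expected)  # raises AssertionError
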
2 changes: 1 addition & 1 deletion tests/test_examples.py
@@ -27,4 +27,4 @@ def test_example_yamls(fname):

z, pos, batch = create_example_batch()
model(z, pos, batch)
model(z, pos, batch, q=None, s=None)
model(z, pos, batch, q=None, s=None, extra_args=None)
9 changes: 7 additions & 2 deletions tests/test_model.py
@@ -19,16 +19,21 @@
@mark.parametrize("model_name", models.__all_models__)
@mark.parametrize("use_batch", [True, False])
@mark.parametrize("explicit_q_s", [True, False])
@mark.parametrize("explicit_extra_args", [True, False])
@mark.parametrize("precision", [32, 64])
def test_forward(model_name, use_batch, explicit_q_s, precision):
def test_forward(model_name, use_batch, explicit_q_s, explicit_extra_args, precision):
z, pos, batch = create_example_batch()
pos = pos.to(dtype=dtype_mapping[precision])
model = create_model(
load_example_args(model_name, prior_model=None, precision=precision)
)
batch = batch if use_batch else None
if explicit_q_s:
if explicit_q_s and explicit_extra_args:
model(z, pos, batch=batch, q=None, s=None, extra_args=None)
elif explicit_q_s:
model(z, pos, batch=batch, q=None, s=None)
elif explicit_extra_args:
model(z, pos, batch=batch, extra_args=None)
else:
model(z, pos, batch=batch)

4 changes: 4 additions & 0 deletions torchmdnet/datasets/hdf.py
@@ -49,6 +49,10 @@ def __init__(self, filename, dataset_preload_limit=1024, **kwargs):
("pos", "pos", torch.float32),
("z", "types", torch.long),
]
if "charge" in group:
self.fields.append(("q", "charge", torch.float32))
if "spin" in group:
self.fields.append(("s", "spin", torch.float32))
if "energy" in group:
self.fields.append(("y", "energy", torch.float32))
if "forces" in group:
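
A minimal sketch (not part of this PR) of an HDF5 file exposing the new optional charge and spin fields picked up above; the group name and array shapes are assumptions based on the fields the loader reads, using h5py:

import h5py
import numpy as np

with h5py.File("example.h5", "w") as f:
    mol = f.create_group("molecule_0")                    # hypothetical group name
    mol["types"] = np.array([8, 1, 1], dtype=np.int64)    # read as z
    mol["pos"] = np.zeros((1, 3, 3), dtype=np.float32)    # conformations x atoms x 3, read as pos
    mol["charge"] = np.zeros(1, dtype=np.float32)         # total charge, read as q
    mol["spin"] = np.zeros(1, dtype=np.float32)           # spin state, read as s
    mol["energy"] = np.zeros((1, 1), dtype=np.float32)    # read as y
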
2 changes: 1 addition & 1 deletion torchmdnet/datasets/memdataset.py
@@ -269,7 +269,7 @@ def get(self, idx):
if "q" in self.properties:
props["q"] = pt.tensor(self.mmaps["q"][idx], dtype=pt.long)
if "pq" in self.properties:
props["pq"] = pt.tensor(self.mmaps["pq"][atoms])
props["partial_charges"] = pt.tensor(self.mmaps["pq"][atoms])
if "dp" in self.properties:
props["dp"] = pt.tensor(self.mmaps["dp"][idx])
# if "mol_idx" in self.properties:
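
The contrast with the molecule-wise q above is the point of this change (a sketch with made-up values, not part of the diff): q is indexed by the conformation index, while pq holds one value per atom and is sliced with the atom range.

import numpy as np

# Hypothetical arrays standing in for the memmaps of 2 samples (3 and 4 atoms):
q = np.array([-1.0, 0.0], dtype=np.float32)                          # total charge per sample
pq = np.array([-0.8, 0.4, 0.4, 0.1, -0.1, 0.2, -0.2], np.float32)    # partial charge per atom

idx = 1
atoms = slice(3, 7)      # atom range of sample 1, from the index memmap
print(q[idx])            # molecule-wise scalar -> props["q"]
print(pq[atoms])         # four atom-wise values -> props["partial_charges"]
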
3 changes: 2 additions & 1 deletion torchmdnet/models/model.py
@@ -460,7 +460,8 @@ def forward(
pos.requires_grad_(True)
# run the potentially wrapped representation model
x, v, z, pos, batch = self.representation_model(
z, pos, batch, box=box, q=q, s=s
z, pos, batch, box=box, q=q, s=s, extra_args=extra_args

)
# apply the output network
x = self.output_model.pre_reduce(x, v, z, pos, batch)
46 changes: 35 additions & 11 deletions torchmdnet/models/tensornet.py
@@ -3,7 +3,7 @@
# (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT)

import torch
from typing import Optional, Tuple
from typing import Optional, Tuple, Dict
from torch import Tensor, nn
from torchmdnet.models.utils import (
CosineCutoff,
@@ -61,6 +61,32 @@ def tensor_norm(tensor):
"""Computes Frobenius norm."""
return (tensor**2).sum((-2, -1))

def process_additional_labels(z: torch.Tensor, q: Optional[torch.Tensor], s: Optional[torch.Tensor], extra_args: Optional[Dict[str, torch.Tensor]], batch: torch.Tensor) -> torch.Tensor:
"""
Process additional labels for the model. This function assigns atom-wise properties based on the provided
molecule-wise properties or extra arguments.
Total charge q and spin s are molecule-wise properties. We transform it into an atom-wise property, with all atoms
belonging to the same molecule being assigned the same charge q or spin s.

Args:
batch (Tensor): Batch tensor indicating the molecule each atom belongs to.
z (Tensor): Atomic numbers tensor.
q (Optional[Tensor]): Total charge tensor for each molecule.
s (Optional[Tensor]): Spin tensor for each molecule.
extra_args (Optional[Dict[str, Tensor]]): Dictionary containing additional properties.

Returns:
Tensor: Atom-wise property tensor already scaled by 0.1.
"""
if q is not None:
t = q[batch]
elif s is not None:
t = s[batch]
elif extra_args is not None and 'partial_charges' in extra_args:
t = extra_args['partial_charges']
else:
t = torch.zeros_like(z, device=z.device, dtype=z.dtype)
return t

class TensorNet(nn.Module):
r"""TensorNet's architecture. From
@@ -226,6 +252,7 @@ def forward(
box: Optional[Tensor] = None,
q: Optional[Tensor] = None,
s: Optional[Tensor] = None,
extra_args: Optional[Dict[str, Tensor]] = None,
) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]:
# Obtain graph, with distances and relative position vectors
edge_index, edge_weight, edge_vec = self.distance(pos, batch, box)
@@ -234,16 +261,13 @@
edge_vec is not None
), "Distance module did not return directional information"
# The Distance module returns -1 for non-existing edges. To avoid having to resize the tensors when enforcing static shapes (for CUDA graphs), we make all non-existing edges pertain to a ghost atom
# Total charge q is a molecule-wise property. We transform it into an atom-wise property, with all atoms belonging to the same molecule being assigned the same charge q
if q is None:
q = torch.zeros_like(z, device=z.device, dtype=z.dtype)
else:
q = q[batch]

t = process_additional_labels(z, q, s, extra_args, batch)
zp = z
if self.static_shapes:
mask = (edge_index[0] < 0).unsqueeze(0).expand_as(edge_index)
zp = torch.cat((z, torch.zeros(1, device=z.device, dtype=z.dtype)), dim=0)
q = torch.cat((q, torch.zeros(1, device=q.device, dtype=q.dtype)), dim=0)
t = torch.cat((t, torch.zeros(1, device=z.device, dtype=z.dtype)), dim=0)
# I trick the model into thinking that the masked edges pertain to the extra atom
# WARNING: This can hurt performance if max_num_pairs >> actual_num_pairs
edge_index = edge_index.masked_fill(mask, z.shape[0])
@@ -258,7 +282,7 @@ def forward(
edge_vec = edge_vec / edge_weight.masked_fill(mask, 1).unsqueeze(1)
X = self.tensor_embedding(zp, edge_index, edge_weight, edge_vec, edge_attr)
for layer in self.layers:
X = layer(X, edge_index, edge_weight, edge_attr, q)
X = layer(X, edge_index, edge_weight, edge_attr, t)
I, A, S = decompose_tensor(X)
x = torch.cat((tensor_norm(I), tensor_norm(A), tensor_norm(S)), dim=-1)
x = self.out_norm(x)
@@ -454,7 +478,7 @@ def forward(
edge_index: Tensor,
edge_weight: Tensor,
edge_attr: Tensor,
q: Tensor,
t: Tensor,
) -> Tensor:
C = self.cutoff(edge_weight)
for linear_scalar in self.linears_scalar:
@@ -481,7 +505,7 @@
if self.equivariance_invariance_group == "O(3)":
A = torch.matmul(msg, Y)
B = torch.matmul(Y, msg)
I, A, S = decompose_tensor((1 + 0.1 * q[..., None, None, None]) * (A + B))
I, A, S = decompose_tensor((1 + 0.1 * t[..., None, None, None]) * (A + B))
if self.equivariance_invariance_group == "SO(3)":
B = torch.matmul(Y, msg)
I, A, S = decompose_tensor(2 * B)
@@ -491,5 +515,5 @@
A = self.linears_tensor[4](A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
S = self.linears_tensor[5](S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
dX = I + A + S
X = X + dX + (1 + 0.1 * q[..., None, None, None]) * torch.matrix_power(dX, 2)
X = X + dX + (1 + 0.1 * t[..., None, None, None]) * torch.matrix_power(dX, 2)
return X
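
To illustrate how process_additional_labels routes the three cases (a sketch with made-up values, not part of the diff): a molecule-wise q or s is broadcast to atoms by indexing with batch, atom-wise partial_charges pass through unchanged, and the resulting t rescales each interaction layer's update through the (1 + 0.1 * t) factor above.

import torch

z = torch.tensor([8, 1, 1, 6, 8])        # five atoms
batch = torch.tensor([0, 0, 0, 1, 1])    # two molecules

# Molecule-wise total charge, broadcast so every atom carries its molecule's q:
q = torch.tensor([-1, 0])
print(q[batch])                          # tensor([-1, -1, -1,  0,  0])

# Atom-wise partial charges are used as-is:
extra_args = {"partial_charges": torch.tensor([-0.8, 0.4, 0.4, 0.1, -0.1])}
print(extra_args["partial_charges"])     # already one value per atom

# With neither provided, t falls back to zeros and (1 + 0.1 * t) is a no-op.
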
3 changes: 2 additions & 1 deletion torchmdnet/models/torchmd_et.py
@@ -2,7 +2,7 @@
# Distributed under the MIT License.
# (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT)

from typing import Optional, Tuple
from typing import Optional, Tuple, Dict
import torch
from torch import Tensor, nn
from torchmdnet.models.utils import (
@@ -196,6 +196,7 @@ def forward(
box: Optional[Tensor] = None,
q: Optional[Tensor] = None,
s: Optional[Tensor] = None,
extra_args: Optional[Dict[str, Tensor]] = None,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
x = self.embedding(z)

5 changes: 3 additions & 2 deletions torchmdnet/models/torchmd_gn.py
@@ -2,7 +2,7 @@
# Distributed under the MIT License.
# (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT)

from typing import Optional, Tuple
from typing import Optional, Tuple, Dict
import torch
from torch import Tensor, nn
from torchmdnet.models.utils import (
@@ -196,8 +196,9 @@ def forward(
pos: Tensor,
batch: Tensor,
box: Optional[Tensor] = None,
s: Optional[Tensor] = None,
q: Optional[Tensor] = None,
s: Optional[Tensor] = None,
extra_args: Optional[Dict[str, Tensor]] = None,
) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]:
x = self.embedding(z)

5 changes: 3 additions & 2 deletions torchmdnet/models/torchmd_t.py
@@ -2,7 +2,7 @@
# Distributed under the MIT License.
# (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT)

from typing import Optional, Tuple
from typing import Optional, Tuple, Dict
import torch
from torch import Tensor, nn
from torchmdnet.models.utils import (
@@ -190,8 +190,9 @@ def forward(
pos: Tensor,
batch: Tensor,
box: Optional[Tensor] = None,
s: Optional[Tensor] = None,
q: Optional[Tensor] = None,
s: Optional[Tensor] = None,
extra_args: Optional[Dict[str, Tensor]] = None,
) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]:
x = self.embedding(z)

5 changes: 3 additions & 2 deletions torchmdnet/models/wrappers.py
@@ -3,7 +3,7 @@
# (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT)

from abc import abstractmethod, ABCMeta
from typing import Optional, Tuple
from typing import Optional, Tuple, Dict
from torch import nn, Tensor


@@ -45,8 +45,9 @@ def forward(
batch: Tensor,
q: Optional[Tensor] = None,
s: Optional[Tensor] = None,
extra_args: Optional[Dict[str, Tensor]] = None,
) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]:
x, v, z, pos, batch = self.model(z, pos, batch=batch, q=q, s=s)
x, v, z, pos, batch = self.model(z, pos, batch=batch, q=q, s=s, extra_args=extra_args)

n_samples = len(batch.unique())

23 changes: 17 additions & 6 deletions torchmdnet/module.py
@@ -59,6 +59,14 @@ def __init__(self, loss_fn, extra_args=None):
def __call__(self, x, batch):
return self.loss_fn(x, batch, **self.extra_args)

def process_extra_args(extra_args, use_partial_charges):
'''Remove extra arguments that are not needed by the model before passing them to the forward function.'''
for a in ("y", "neg_dy", "z", "pos", "batch", "box", "q", "s"):
if a in extra_args:
del extra_args[a]
if not use_partial_charges and 'partial_charges' in extra_args:
del extra_args['partial_charges']
return extra_args

class LNNP(LightningModule):
"""
@@ -77,6 +85,13 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None):
hparams["charge"] = False
if "spin" not in hparams:
hparams["spin"] = False
if "partial_charges" not in hparams:
hparams["partial_charges"] = False
# Ensure only one of charge, partial_charges, and spin can be True, otherwise raise a ValueError
if sum([hparams["charge"], hparams["partial_charges"], hparams["spin"]]) > 1:
raise ValueError(
"Only one of 'charge', 'partial_charges', and 'spin' can be True."
)
if "train_loss" not in hparams:
hparams["train_loss"] = "mse_loss"
if "train_loss_arg" not in hparams:
@@ -184,9 +199,7 @@ def predict_step(self, batch, batch_idx):

with torch.set_grad_enabled(self.hparams.derivative):
extra_args = batch.to_dict()
for a in ("y", "neg_dy", "z", "pos", "batch", "box", "q", "s"):
if a in extra_args:
del extra_args[a]
extra_args = process_extra_args(extra_args, self.hparams.partial_charges)
return self(
batch.z,
batch.pos,
@@ -253,9 +266,7 @@ def step(self, batch, loss_fn_list, stage):
batch = self.data_transform(batch)
with torch.set_grad_enabled(stage == "train" or self.hparams.derivative):
extra_args = batch.to_dict()
for a in ("y", "neg_dy", "z", "pos", "batch", "box", "q", "s"):
if a in extra_args:
del extra_args[a]
extra_args = process_extra_args(extra_args, self.hparams.partial_charges)
# TODO: the model doesn't necessarily need to return a derivative once
# Union typing works under TorchScript (https://github.com/pytorch/pytorch/pull/53180)
y, neg_dy = self(
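
A quick sketch (not part of the diff) of what process_extra_args leaves in the batch dictionary; my_custom_label is a hypothetical stand-in for any dataset-specific field:

import torch
from torchmdnet.module import process_extra_args

batch_dict = {
    "z": torch.tensor([8, 1, 1]),
    "pos": torch.rand(3, 3),
    "y": torch.rand(1, 1),
    "partial_charges": torch.tensor([-0.8, 0.4, 0.4]),
    "my_custom_label": torch.rand(1),    # hypothetical extra field
}

# Standard keys are always stripped; partial_charges survives only if enabled:
print(process_extra_args(dict(batch_dict), True).keys())
# dict_keys(['partial_charges', 'my_custom_label'])
print(process_extra_args(dict(batch_dict), False).keys())
# dict_keys(['my_custom_label'])
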
4 changes: 3 additions & 1 deletion torchmdnet/optimize.py
@@ -2,7 +2,8 @@
# Distributed under the MIT License.
# (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT)

from typing import Optional, Tuple
from typing import Optional, Tuple, Dict
from torch import Tensor
import torch as pt
from NNPOps.CFConv import CFConv
from NNPOps.CFConvNeighbors import CFConvNeighbors
@@ -58,6 +59,7 @@ def forward(
box: Optional[pt.Tensor] = None,
q: Optional[pt.Tensor] = None,
s: Optional[pt.Tensor] = None,
extra_args: Optional[Dict[str, Tensor]] = None,
) -> Tuple[pt.Tensor, Optional[pt.Tensor], pt.Tensor, pt.Tensor, pt.Tensor]:

assert pt.all(batch == 0)
1 change: 1 addition & 0 deletions torchmdnet/scripts/train.py
@@ -85,6 +85,7 @@ def get_argparse():
# architectural args
parser.add_argument('--charge', type=bool, default=False, help='Model needs a total charge. Set this to True if your dataset contains charges and you want them passed down to the model.')
parser.add_argument('--spin', type=bool, default=False, help='Model needs a spin state. Set this to True if your dataset contains spin states and you want them passed down to the model.')
parser.add_argument('--partial-charges', type=bool, default=False, help='Model needs partial charges. Set this to True if your dataset contains partial charges and you want them passed down to the model.')
parser.add_argument('--embedding-dimension', type=int, default=256, help='Embedding dimension')
parser.add_argument('--num-layers', type=int, default=6, help='Number of interaction layers in the model')
parser.add_argument('--num-rbf', type=int, default=64, help='Number of radial basis functions in model')
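
One parsing caveat for the new flag (an illustration of standard argparse behavior, not part of the diff): type=bool converts the string with bool(), so any non-empty value, including 'False', parses as True; the existing --charge and --spin flags behave the same way.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--partial-charges', type=bool, default=False)

print(parser.parse_args([]).partial_charges)                              # False (default)
print(parser.parse_args(['--partial-charges', 'yes']).partial_charges)    # True
print(parser.parse_args(['--partial-charges', 'False']).partial_charges)  # True, since bool('False') is truthy
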