[WIP] Support arbitrary outputs in TorchMD_Net #239

Open · wants to merge 5 commits into main
183 changes: 160 additions & 23 deletions torchmdnet/models/model.py
@@ -11,18 +11,15 @@
import warnings


def create_model(args, prior_model=None, mean=None, std=None):
"""Create a model from the given arguments.
def create_representation_model(args):
"""Create a representation model from the given arguments.
See :func:`get_args` in scripts/train.py for a description of the arguments.
Parameters
----------
args (dict): Arguments for the model.
prior_model (nn.Module, optional): Prior model to use. Defaults to None.
mean (torch.Tensor, optional): Mean of the training data. Defaults to None.
std (torch.Tensor, optional): Standard deviation of the training data. Defaults to None.
Returns
-------
nn.Module: An instance of the TorchMD_Net model.
nn.Module: An instance of the TorchMD_Net representation model.
"""
dtype = dtype_mapping[args["precision"]]
shared_args = dict(
@@ -38,12 +35,8 @@ def create_model(args, prior_model=None, mean=None, std=None):
max_num_neighbors=args["max_num_neighbors"],
dtype=dtype
)

# representation network
if args["model"] == "graph-network":
from torchmdnet.models.torchmd_gn import TorchMD_GN

is_equivariant = False
representation_model = TorchMD_GN(
num_filters=args["embedding_dimension"],
aggr=args["aggr"],
@@ -52,8 +45,6 @@ def create_model(args, prior_model=None, mean=None, std=None):
)
elif args["model"] == "transformer":
from torchmdnet.models.torchmd_t import TorchMD_T

is_equivariant = False
representation_model = TorchMD_T(
attn_activation=args["attn_activation"],
num_heads=args["num_heads"],
@@ -63,8 +54,6 @@ def create_model(args, prior_model=None, mean=None, std=None):
)
elif args["model"] == "equivariant-transformer":
from torchmdnet.models.torchmd_et import TorchMD_ET

is_equivariant = True
representation_model = TorchMD_ET(
attn_activation=args["attn_activation"],
num_heads=args["num_heads"],
@@ -74,41 +63,53 @@ def create_model(args, prior_model=None, mean=None, std=None):
)
elif args["model"] == "tensornet":
from torchmdnet.models.tensornet import TensorNet

# Setting is_equivariant to False to enforce the use of Scalar output module instead of EquivariantScalar
is_equivariant = False
representation_model = TensorNet(
equivariance_invariance_group=args["equivariance_invariance_group"],
**shared_args,
)
else:
raise ValueError(f'Unknown architecture: {args["model"]}')
return representation_model
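# Usage note (editor's sketch, not part of this diff): create_representation_model takes the
# same ``args`` dict that get_args() in scripts/train.py produces, and returns only the bare
# representation network (TorchMD_GN, TorchMD_T, TorchMD_ET or TensorNet), with no output
# heads, priors, atom filter or standardization attached.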

def create_model(args, prior_model=None, mean=None, std=None):
"""Create a model from the given arguments.
See :func:`get_args` in scripts/train.py for a description of the arguments.
Parameters
----------
args (dict): Arguments for the model.
prior_model (nn.Module, optional): Prior model to use. Defaults to None.
mean (torch.Tensor, optional): Mean of the training data. Defaults to None.
std (torch.Tensor, optional): Standard deviation of the training data. Defaults to None.
Returns
-------
nn.Module: An instance of the TorchMD_Net model.
"""
dtype = dtype_mapping[args["precision"]]
# representation network
representation_model = create_representation_model(args)
# atom filter
if not args["derivative"] and args["atom_filter"] > -1:
representation_model = AtomFilter(representation_model, args["atom_filter"])
elif args["atom_filter"] > -1:
raise ValueError("Derivative and atom filter can't be used together")

# prior model
if args["prior_model"] and prior_model is None:
# instantiate prior model if it was not passed to create_model (i.e. when loading a model)
prior_model = create_prior_models(args)

# create output network
is_equivariant = args["model"] == "equivariant-transformer"
output_prefix = "Equivariant" if is_equivariant else ""
output_model = getattr(output_modules, output_prefix + args["output_model"])(
args["embedding_dimension"],
activation=args["activation"],
reduce_op=args["reduce_op"],
dtype=dtype,
)

# combine representation and output network
model = TorchMD_Net(
model = MultiHeadTorchMD_Net(
representation_model,
output_model,
prior_model=prior_model,
# output_model,
# prior_model=prior_model,
mean=mean,
std=std,
derivative=args["derivative"],
@@ -118,6 +119,7 @@ def create_model(args, prior_model=None, mean=None, std=None):


def load_model(filepath, args=None, device="cpu", **kwargs):
raise NotImplementedError("load_model is not implemented yet")
ckpt = torch.load(filepath, map_location="cpu")
if args is None:
args = ckpt["hyper_parameters"]
@@ -297,3 +299,138 @@ def forward(
return y, -dy
# TODO: return only `out` once Union typing works with TorchScript (https://github.com/pytorch/pytorch/pull/53180)
return y, None

from torchmdnet.models.utils import scatter, act_class_mapping

class BaseHead(nn.Module):
def __init__(self, dtype=torch.float32):
super(BaseHead, self).__init__()
self.dtype = dtype

def reset_parameters(self):
pass

def per_point(self, point_features, results, z, pos, batch, extra_args):
return point_features, results

def per_sample(self, point_features, results, z, pos, batch, extra_args):
return point_features, results
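
# Head contract (editor's sketch, not part of this diff): a head computes per-atom quantities
# in ``per_point`` and reduces them per molecule in ``per_sample``, writing into the shared
# ``results`` dict. A hypothetical atom-count head following that contract would look like:
#
#   class NumAtomsHead(BaseHead):
#       def per_point(self, point_features, results, z, pos, batch, extra_args):
#           results["num_atoms"] = torch.ones_like(z, dtype=self.dtype).unsqueeze(-1)
#           return point_features, results
#
#       def per_sample(self, point_features, results, z, pos, batch, extra_args):
#           results["num_atoms"] = scatter(results["num_atoms"], batch, dim=0)
#           return point_features, results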

class EnergyHead(BaseHead):
def __init__(self,
hidden_channels,
activation="silu",
dtype=torch.float32):
super(EnergyHead, self).__init__(dtype=dtype)
act_class = act_class_mapping[activation]
self.output_network = nn.Sequential(
nn.Linear(hidden_channels, hidden_channels // 2, dtype=dtype),
act_class(),
nn.Linear(hidden_channels // 2, 1, dtype=dtype),
)
self.reset_parameters()

def reset_parameters(self):
nn.init.xavier_uniform_(self.output_network[0].weight)
self.output_network[0].bias.data.fill_(0)
nn.init.xavier_uniform_(self.output_network[2].weight)
self.output_network[2].bias.data.fill_(0)

def per_point(self, point_features, results, z, pos, batch, extra_args):
results["energy"] = self.output_network(point_features)
return point_features, results

def per_sample(self, point_features, results, z, pos, batch, extra_args):
results["energy"] = scatter(results["energy"], batch, dim=0)
return point_features, results

class PointChargeHead(BaseHead):
def __init__(self,
hidden_channels,
activation="silu",
dtype=torch.float32):
super(PointChargeHead, self).__init__(dtype=dtype)
act_class = act_class_mapping[activation]
self.output_network = nn.Sequential(
nn.Linear(hidden_channels, hidden_channels // 2, dtype=dtype),
act_class(),
nn.Linear(hidden_channels // 2, 1, dtype=dtype),
)
self.reset_parameters()

def reset_parameters(self):
nn.init.xavier_uniform_(self.output_network[0].weight)
self.output_network[0].bias.data.fill_(0)
nn.init.xavier_uniform_(self.output_network[2].weight)
self.output_network[2].bias.data.fill_(0)

def per_point(self, point_features, results, z, pos, batch, extra_args):
results["charge"] = self.output_network(point_features)
return point_features, results

def per_sample(self, point_features, results, z, pos, batch, extra_args):
return point_features, results
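# Note (editor's comment): unlike EnergyHead, PointChargeHead keeps its prediction per atom;
# per_sample is a no-op, so results["charge"] stays of shape (num_atoms, 1).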

class ForceHead(BaseHead):
def __init__(self,
dtype=torch.float32):
super(ForceHead, self).__init__(dtype=dtype)
pass

def per_sample(self, point_features, results, z, pos, batch, extra_args):
grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(results["energy"])]
results["force"] = -grad([results["energy"]],
[pos],
grad_outputs=grad_outputs,
create_graph=self.training,
retain_graph=self.training)[0]
return point_features, results
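
# Note (editor's comment): ForceHead assumes results["energy"] has already been filled in,
# i.e. an EnergyHead runs earlier in the head list, and that pos had requires_grad set
# (done in MultiHeadTorchMD_Net.forward when derivative=True); the force is then -dE/dpos
# obtained via autograd.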

class MultiHeadTorchMD_Net(nn.Module):
def __init__(
self,
representation_model,
head_list = None,
mean=None,
std=None,
derivative=False,
dtype=torch.float32,
):
super(MultiHeadTorchMD_Net, self).__init__()
self.representation_model = representation_model.to(dtype=dtype)
self.derivative = derivative
self.head_list = nn.ModuleList([EnergyHead(representation_model.hidden_channels, dtype=dtype)])
if derivative:
self.head_list.append(ForceHead(dtype=dtype))
mean = torch.scalar_tensor(0) if mean is None else mean
self.register_buffer("mean", mean.to(dtype=dtype))
std = torch.scalar_tensor(1) if std is None else std
self.register_buffer("std", std.to(dtype=dtype))
self.reset_parameters()

def reset_parameters(self):
self.representation_model.reset_parameters()
for head in self.head_list:
head.reset_parameters()

def forward(
self,
z: Tensor,
pos: Tensor,
batch: Optional[Tensor] = None,
q: Optional[Tensor] = None,
s: Optional[Tensor] = None,
extra_args: Optional[Dict[str, Tensor]] = None
) -> Dict[str, Tensor]:
assert z.dim() == 1 and z.dtype == torch.long
batch = torch.zeros_like(z) if batch is None else batch
if self.derivative:
pos.requires_grad_(True)
results = {}
# run the potentially wrapped representation model
point_features = self.representation_model(z, pos, batch, q=q, s=s)
for head in self.head_list:
point_features, results = head.per_point(point_features, results, z=z, pos=pos, batch=batch, extra_args=extra_args)
for head in self.head_list:
point_features, results = head.per_sample(point_features, results, z=z, pos=pos, batch=batch, extra_args=extra_args)
return results
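
A minimal end-to-end sketch (editor's example, not part of this PR) of how the new multi-head model is meant to be called. DummyRepresentation is hypothetical: it stands in for TorchMD_ET or TensorNet and simply returns per-atom features of shape (num_atoms, hidden_channels), which is what the heads above consume. Everything else follows the code in this diff: forward returns a dictionary keyed by head name ("energy", plus "force" when derivative=True) instead of the old (y, neg_dy) tuple.

import torch
from torch import nn
from torchmdnet.models.model import MultiHeadTorchMD_Net

class DummyRepresentation(nn.Module):
    """Hypothetical stand-in returning per-atom features of shape (num_atoms, hidden_channels)."""
    def __init__(self, hidden_channels=64):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.embedding = nn.Embedding(100, hidden_channels)

    def reset_parameters(self):
        self.embedding.reset_parameters()

    def forward(self, z, pos, batch, q=None, s=None):
        # make the features depend on pos so the force head has a nonzero gradient
        return self.embedding(z) + pos.sum(dim=-1, keepdim=True)

model = MultiHeadTorchMD_Net(DummyRepresentation(), derivative=True)

z = torch.tensor([8, 1, 1], dtype=torch.long)   # atomic numbers of a single molecule
pos = torch.randn(3, 3)                          # positions, shape (num_atoms, 3)
batch = torch.zeros(3, dtype=torch.long)         # all atoms belong to sample 0

results = model(z, pos, batch)
print(results["energy"].shape)  # (1, 1), summed per molecule by EnergyHead.per_sample
print(results["force"].shape)   # (3, 3), -dE/dpos from ForceHead

Note that the head_list argument is currently ignored by the constructor, which always builds an EnergyHead (plus a ForceHead when derivative=True); wiring arbitrary heads through create_model is presumably the remaining WIP in this PR.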