Training
Proper data preprocessing and normalization are crucial steps in training neural network potentials (NNPs). These steps ensure faster convergence, improved numerical stability, and better generalization in predicting molecular energies and forces.
Energies derived from quantum mechanical (QM) methods often include atomic self-energies—intrinsic energy contributions from individual atoms, independent of interatomic interactions. These self-energies correspond to the system’s energy when all interactions between atoms are effectively eliminated, such as when interatomic distances are increased beyond the interaction cutoff (often considered to be infinity).
Removing these self-energy offsets from the total energy has been shown to improve learning rates during training. By shifting the energy distribution closer to zero without changing its width (i.e., E_max − E_min remains constant before and after offset removal), the network can focus on learning the energy differences driven by interatomic interactions. This normalization also helps with numerical precision: since training is typically performed in single precision, large energy offsets consume valuable bits of floating-point precision, limiting the model's ability to resolve finer energy differences.
The atomic self-energies depend on the QM method used to compute the energies in the training dataset. In most datasets, self-energies are provided for each atomic species. However, if this information is unavailable, tools like modelforge can estimate the self-energies using linear regression on the total energies, fitting to the composition of the system. After removing the atomic self-energy offset, the energies are converted to PyTorch tensors and cast to single precision for training. Care should be taken to retain double precision in the preprocessing steps to avoid precision loss during offset removal.
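If per-element self-energies are not provided, they can be estimated by a linear least-squares fit of the dataset's total energies against each molecule's element composition. The following is a minimal sketch of that idea, not modelforge's internal API; the element counts and energies are toy values chosen for illustration.

import numpy as np

# Each row holds the per-element atom counts of one molecule (here: H and O only)
element_counts = np.array(
    [
        [2, 1],  # H2O
        [2, 2],  # H2O2
        [4, 2],  # (H2O)2
        [2, 0],  # H2
    ],
    dtype=np.float64,  # keep double precision during preprocessing
)
total_energies = np.array([-76.4, -151.5, -152.9, -1.2], dtype=np.float64)  # toy QM energies

# Least-squares fit: total_energy ≈ n_H * E_self(H) + n_O * E_self(O)
self_energies, *_ = np.linalg.lstsq(element_counts, total_energies, rcond=None)
print(dict(zip(["H", "O"], self_energies)))

# Remove the offset (still in double precision) before casting to float32 for training
interaction_energies = total_energies - element_counts @ self_energies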
The choice of loss function is critical in optimizing the NNP. When only energies are used to drive the training, the mean squared error (MSE) between the predicted and reference energies is the natural choice:

$$\mathcal{L}_E = \frac{1}{N}\sum_{i=1}^{N}\left(E_i^{\mathrm{predict}} - E_i^{\mathrm{true}}\right)^2$$

with $E_i^{\mathrm{predict}}$ the total energy predicted by the network for molecule $i$, $E_i^{\mathrm{true}}$ the corresponding reference energy, and $N$ the number of molecules in the batch. When forces are included in the training, an analogous per-atom term is added,

$$\mathcal{L}_F = \frac{1}{N}\sum_{i=1}^{N}\frac{1}{3\,n_i}\sum_{j=1}^{n_i}\left\lVert \mathbf{F}_{ij}^{\mathrm{predict}} - \mathbf{F}_{ij}^{\mathrm{true}}\right\rVert^2$$

with $\mathbf{F}_{ij}^{\mathrm{predict}} = -\nabla_{\mathbf{r}_{ij}} E_i^{\mathrm{predict}}$ the predicted force on atom $j$ of molecule $i$ and $n_i$ the number of atoms in that molecule; the individual loss terms are then combined as a weighted sum.
In our codebase, forces are computed using PyTorch's automatic differentiation. Within the CalculateProperties class, the predicted forces are obtained by differentiating the predicted energy with respect to atomic positions:
# Ensure gradients are enabled
per_molecule_energy_predict.requires_grad_(True)
nnp_input.positions.requires_grad_(True)

# Compute the gradient (forces) from the predicted energies
grad = torch.autograd.grad(
    outputs=per_molecule_energy_predict,
    inputs=nnp_input.positions,
    grad_outputs=torch.ones_like(per_molecule_energy_predict),
    create_graph=train_mode,
    retain_graph=train_mode,
    allow_unused=True,
)[0]

# Forces are the negative gradient of energy
per_atom_force_predict = -grad
This approach ensures that the model learns force predictions consistent with energy gradients, adhering to physical laws.
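As a toy illustration of this consistency (not modelforge code), the snippet below compares the autograd-derived forces of a simple harmonic pair energy with the analytic result:

import torch

# Two atoms connected by a harmonic "bond": E = 0.5 * k * (r - r0)^2
k, r0 = 2.0, 1.0
positions = torch.tensor([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]], requires_grad=True)

r = torch.linalg.vector_norm(positions[1] - positions[0])
energy = 0.5 * k * (r - r0) ** 2

# Forces are the negative gradient of the energy with respect to the positions
(grad,) = torch.autograd.grad(energy, positions)
forces = -grad

# Analytic result: |F| = k * (r - r0) = 1.0, directed along the bond axis
print(forces)  # tensor([[ 1., 0., 0.], [-1., 0., 0.]])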
To balance the contributions of different loss components (energies, forces, charges, etc.), we scale them appropriately:
- Per-atom scaling: for properties defined per atom (e.g., forces), errors are computed per atom and then aggregated to a per-molecule error by summing over the atoms in each molecule.
- (Optional) normalization by atom count: the aggregated errors are divided by the number of atoms, so that molecules of different sizes contribute comparably to the loss. If the per-atom property has more than one dimension (as for the three force components), the error is additionally scaled by the number of components. A sketch of this aggregation follows the list.
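The sketch below illustrates this aggregation on a toy batch; the tensor names (per_atom_error, molecule_index) are illustrative and not modelforge's internal variable names.

import torch

# Squared force errors for a batch of 5 atoms belonging to 2 molecules (3 + 2 atoms),
# already summed over the three force components of each atom
per_atom_error = torch.tensor([0.3, 0.1, 0.2, 0.4, 0.6])
molecule_index = torch.tensor([0, 0, 0, 1, 1])  # molecule each atom belongs to
n_molecules = 2

# Aggregate to a per-molecule error by summing over the atoms of each molecule
per_molecule_error = torch.zeros(n_molecules).scatter_add_(0, molecule_index, per_atom_error)

# Optional normalization by atom count (and the 3 force components)
atoms_per_molecule = torch.bincount(molecule_index, minlength=n_molecules).float()
per_molecule_error_normalized = per_molecule_error / (3 * atoms_per_molecule)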
The Loss class orchestrates this process, utilizing specific error calculation classes for different properties. For instance:
- Energy Loss: uses PerMoleculeSquaredError, with optional scaling by the number of atoms
- Force Loss: uses FromPerAtomToPerMoleculeSquaredError to compute per-atom errors and aggregate them
- Charge and Dipole Moment Losses: utilize the TotalChargeError and DipoleMomentError classes, respectively
class Loss(nn.Module):
    # ...
    def forward(self, predict_target: Dict[str, torch.Tensor], batch: NNPInput) -> Dict[str, torch.Tensor]:
        loss_dict = {}
        total_loss = torch.zeros_like(batch.metadata.E)
        for prop in self.loss_property:
            loss_fn = self.loss_functions[prop]
            prop_loss = loss_fn(
                predict_target[f"{prop}_predict"],
                predict_target[f"{prop}_true"],
                batch,
            )
            weighted_loss = self.weights[prop] * prop_loss
            total_loss += weighted_loss
            loss_dict[prop] = prop_loss
        loss_dict["total_loss"] = total_loss
        return loss_dict
This flexible structure allows users to specify which properties to include in the loss function and assign custom weights to them.
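For example, a combined energy and force loss might be set up along the following lines; the property names and the constructor arguments shown here are assumptions for illustration, so consult the modelforge training configuration for the exact options.

# Illustrative only: the property names and constructor signature are assumptions,
# not necessarily modelforge's exact API.
loss_weights = {
    "per_molecule_energy": 1.0,  # main training target
    "per_atom_force": 0.1,       # smaller weight for the force term
}

loss_module = Loss(
    loss_property=list(loss_weights.keys()),
    weights=loss_weights,
)

# loss_dict = loss_module(predict_target, batch)
# loss_dict["total_loss"].mean().backward()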
modelforge supports training on additional properties like total charge and dipole moment, which are essential for accurate electrostatic interactions:
- Total Charge Loss: ensures that the sum of the predicted partial atomic charges matches the reference total charge
- Dipole Moment Loss: aligns the predicted molecular dipole moment (computed from the predicted charges and positions, as sketched below) with the reference dipole moment

These properties are integrated into the loss function similarly, with their specific error computation classes and weights.
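The dipole moment loss relies on the classical point-charge dipole, $\boldsymbol{\mu} = \sum_i q_i \mathbf{r}_i$, built from the predicted partial charges and the atomic positions. A minimal sketch of that computation (toy values, illustrative variable names):

import torch

# Predicted partial charges and positions for a single water-like molecule (toy values)
partial_charges = torch.tensor([-0.8, 0.4, 0.4])
positions = torch.tensor(
    [
        [0.00, 0.00, 0.00],
        [0.96, 0.00, 0.00],
        [-0.24, 0.93, 0.00],
    ]
)

# Classical point-charge dipole moment: mu = sum_i q_i * r_i
predicted_dipole = (partial_charges.unsqueeze(-1) * positions).sum(dim=0)

reference_dipole = torch.tensor([0.5, 0.3, 0.0])  # toy reference value
dipole_loss = torch.sum((predicted_dipole - reference_dipole) ** 2)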
During training and evaluation, we track various metrics using a torchmetrics MetricCollection (part of the PyTorch Lightning ecosystem):
from typing import List

from torch.nn import ModuleDict
from torchmetrics import MeanAbsoluteError, MeanMetric, MeanSquaredError, MetricCollection


def create_error_metrics(
    loss_properties: List[str],
    is_loss: bool = False,
) -> ModuleDict:
    if is_loss:
        # Track the running mean of each loss component plus the total loss
        metric_dict = ModuleDict(
            {prop: MetricCollection([MeanMetric()]) for prop in loss_properties}
        )
        metric_dict["total_loss"] = MetricCollection([MeanMetric()])
    else:
        # Track MAE and RMSE for each property during evaluation
        metric_dict = ModuleDict(
            {
                prop: MetricCollection(
                    [MeanAbsoluteError(), MeanSquaredError(squared=False)]
                )
                for prop in loss_properties
            }
        )
    return metric_dict
This setup facilitates comprehensive monitoring of model performance across different properties and datasets.
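As a usage sketch (standard torchmetrics update/compute pattern; the tensors below are placeholders and the property names are examples):

import torch

# Build MAE/RMSE collections for two example properties
val_metrics = create_error_metrics(["per_molecule_energy", "per_atom_force"])

# Inside the validation loop: accumulate batch statistics
predicted = torch.tensor([1.0, 2.0, 3.0])
reference = torch.tensor([1.1, 1.9, 3.2])
val_metrics["per_molecule_energy"].update(predicted, reference)

# At the end of an epoch: compute aggregated metrics, then reset for the next epoch
print(val_metrics["per_molecule_energy"].compute())
val_metrics["per_molecule_energy"].reset()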