Training
Proper data preprocessing and normalization are crucial steps in training neural network potentials (NNPs). These steps ensure faster convergence, improved numerical stability, and better generalization in predicting molecular energies and forces.
Energies derived from quantum mechanical (QM) methods often include atomic self-energies: intrinsic energy contributions of individual atoms, independent of interatomic interactions. The sum of the self-energies corresponds to the system's energy when all interactions between atoms are eliminated, for example when the interatomic distances are increased beyond the interaction cutoff (in the limit, to infinite separation).
Removing these self-energy offsets from the total energy has been shown to improve convergence during training. Subtracting the offsets shifts the energy distribution closer to zero (for molecules of the same composition, the spread $E_{\max} - E_{\min}$ is unchanged), so the network can focus on learning the energy differences driven by interatomic interactions. This normalization also helps numerical precision: training is typically performed in single precision, whose 24-bit significand resolves only about seven significant decimal digits, so large energy offsets consume precision that is then unavailable for the finer energy differences the model has to learn.
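The precision argument can be checked directly. The following sketch uses made-up energies and a hypothetical offset (not values from any dataset) to show that two energies differing by 0.01 become identical when cast to float32 at a magnitude of about $10^6$, but remain distinguishable once the offset has been removed in double precision first.

```python
import numpy as np

# Illustrative numbers only: two total energies (e.g. in kJ/mol) differing by 0.01
e_a = -1_000_000.00
e_b = -1_000_000.01
offset = -999_900.0  # hypothetical sum of atomic self-energies

# float32 resolution near 1e6 is 2**-4 = 0.0625, coarser than the 0.01 difference
step_near_1e6 = np.spacing(np.float32(1e6))

# Casting the raw energies to float32 collapses them onto the same value
raw_a, raw_b = np.float32(e_a), np.float32(e_b)

# Removing the offset in float64 first brings the magnitude down to ~100,
# where float32 resolves differences of roughly 1e-5
shifted_a, shifted_b = np.float32(e_a - offset), np.float32(e_b - offset)

print(step_near_1e6)           # prints 0.0625
print(raw_a == raw_b)          # prints True
print(shifted_a == shifted_b)  # prints False
```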
The atomic self-energies depend on the QM method used to compute the energies in the training dataset. In most datasets, self-energies are provided for each atomic species. However, if this information is unavailable, tools like modelforge can estimate the self-energies using linear regression on the total energies, fitting to the composition of the system. After removing the atomic self-energy offset, the energies are converted to PyTorch tensors and cast to single precision for training. Care should be taken to retain double precision in the preprocessing steps to avoid precision loss during offset removal.
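A minimal sketch of the regression-based estimate and the precision handling described above, assuming the composition of each system is available as a matrix of per-element atom counts. The function name `fit_self_energies`, the element order (H, C, O), and all numbers are hypothetical and for illustration only; this is a generic least-squares fit, not modelforge's implementation.

```python
import numpy as np
import torch


def fit_self_energies(counts: np.ndarray, energies: np.ndarray) -> np.ndarray:
    """Least-squares estimate of one self-energy per element (hypothetical helper).

    counts:   (n_molecules, n_elements) number of atoms of each element
    energies: (n_molecules,) total QM energies
    """
    # Keep everything in float64 during preprocessing
    coef, *_ = np.linalg.lstsq(
        counts.astype(np.float64), energies.astype(np.float64), rcond=None
    )
    return coef


# Synthetic example with elements (H, C, O); all numbers are made up
true_self = np.array([-1_300.0, -99_000.0, -197_000.0])
counts = np.array([[4, 1, 0], [2, 0, 1], [6, 2, 1], [4, 2, 0], [2, 2, 2]])
interaction = np.array([-5.0, -3.0, -8.0, -6.0, -7.0])  # small "interaction" part
energies = counts @ true_self + interaction

self_energies = fit_self_energies(counts, energies)

# Subtract the offsets in float64, then cast to float32 for training
shifted = energies - counts.astype(np.float64) @ self_energies
energies_train = torch.from_numpy(shifted).to(torch.float32)
```

Note that fitted self-energies generally absorb part of the average interaction energy, so they need not equal the true isolated-atom energies; for the purpose of offset removal this is usually harmless.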
The choice of loss function is critical in optimizing the NNP. When only energies are used to drive the training, the mean squared error (MSE) of the predicted and actual energies is the natural choice.
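For a dataset containing $N$ data points, the energy loss is

$$
L_E = \frac{1}{N} \sum_{k=1}^{N} \left( \hat{U}_k - U_k \right)^2
$$

with $\hat{U}_k$ the predicted and $U_k$ the reference potential energy of data point $k$.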
For more accuracy, forces can be included in the training objective. This inclusion improves the model's ability to predict not just total energies but also the forces acting on each atom, which is crucial for molecular dynamics (MD) simulations and geometry optimizations.
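When including forces, the loss function is extended to:

$$
L_{E,F} = \frac{1}{N} \sum_{k=1}^{N} \left[ \frac{1}{M_k} \left( \hat{U}_k - U_k \right)^2 + \frac{1}{3 M_k} \sum_{i=1}^{M_k} \sum_{\alpha=1}^{3} \left( -\frac{\partial \hat{U}_k}{\partial r_{i,\alpha}^{k}} - F_{i,\alpha}^{k} \right)^2 \right]
$$

with:

- $N$ as the number of data points and $M_k$ as the number of atoms in data point $k$
- $U_k$ and $\hat{U}_k$ as the reference and predicted potential energy for data point $k$
- $F_{i,\alpha}^{k}$ as the reference force component $\alpha$ of atom $i$ at data point $k$
- $r_{i,\alpha}^{k}$ as the coordinate component $\alpha$ of atom $i$ at data point $k$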
This loss function comprises two terms:
- Energy loss: the squared energy error, divided by the number of atoms in each data point to account for varying system sizes.
- Force loss: the squared error between predicted and reference force components, summed over all atoms and Cartesian components of each data point and normalized by three times the number of atoms, i.e., a per-data-point mean squared error.
By averaging over the number of atoms and data points, we ensure that each molecule contributes equally to the total loss, preventing larger molecules from dominating the training process.
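The energy and force terms described above can be sketched in PyTorch for a batch of molecules stored as flat per-atom arrays. The `molecule_index` tensor (the molecule each atom belongs to) and the function name are assumptions made for this illustration, not modelforge's API.

```python
import torch


def energy_force_loss(
    e_pred: torch.Tensor,          # (n_mol,) predicted energies
    e_ref: torch.Tensor,           # (n_mol,) reference energies
    f_pred: torch.Tensor,          # (n_atoms, 3) predicted forces
    f_ref: torch.Tensor,           # (n_atoms, 3) reference forces
    molecule_index: torch.Tensor,  # (n_atoms,) hypothetical atom -> molecule map
) -> torch.Tensor:
    n_mol = e_pred.shape[0]
    # number of atoms M_k in each molecule
    n_atoms = torch.bincount(molecule_index, minlength=n_mol).to(e_pred.dtype)

    # energy term: squared error divided by the number of atoms
    energy_term = (e_pred - e_ref) ** 2 / n_atoms

    # force term: sum squared errors over components, then over atoms per molecule
    per_atom_sq = ((f_pred - f_ref) ** 2).sum(dim=1)
    per_mol_sq = torch.zeros(n_mol, dtype=f_pred.dtype).index_add_(
        0, molecule_index, per_atom_sq
    )
    force_term = per_mol_sq / (3.0 * n_atoms)

    # average over data points so that each molecule contributes equally
    return (energy_term + force_term).mean()


# Two molecules: molecule 0 has one atom, molecule 1 has two atoms
loss = energy_force_loss(
    e_pred=torch.tensor([1.0, 2.0]),
    e_ref=torch.tensor([0.0, 0.0]),
    f_pred=torch.ones(3, 3),
    f_ref=torch.zeros(3, 3),
    molecule_index=torch.tensor([0, 1, 1]),
)
```

Here the per-molecule contributions are $[1 + 1, 2 + 1]$, so the loss evaluates to 2.5; without the division by the atom count, the larger molecule's force error would count twice as much.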
In addition to energies and forces, including dipole moments as training targets can further enhance the predictive capabilities of NNPs, particularly for properties related to molecular polarity and electrostatics.
The dipole moment loss operates on the predicted partial atomic charges and the molecular geometry. It ensures that the predicted partial charges not only sum up to the correct total charge but also reproduce the reference dipole moment.
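The loss has the following form:

$$
L_{Q,p} = \frac{1}{N} \sum_{k=1}^{N} \left[ \left( \sum_{i=1}^{M_k} q_i^{k} - Q_{\text{ref}}^{k} \right)^2 + \sum_{\alpha=1}^{3} \left( \sum_{i=1}^{M_k} q_i^{k} \, r_{i,\alpha}^{k} - p_{\text{ref}}^{\alpha,k} \right)^2 \right]
$$

with

- $q_i^{k}$ as the predicted partial charge of atom $i$ in data point $k$
- $Q_{\text{ref}}^{k}$ as the reference total charge of data point $k$
- $r_{i,\alpha}^{k}$ as the $\alpha$ component of the coordinate of atom $i$ in data point $k$
- $p^{\alpha,k}_{\text{ref}}$ as the reference dipole moment component $\alpha$ for data point $k$
- $M_k$ as the number of atoms in data point $k$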
This loss function comprises two terms:
- Total charge error: ensures that the sum of predicted partial charges equals the reference total charge.
- Dipole moment error: ensures that the dipole moment computed from the predicted charges and atomic positions matches the reference dipole moment.

By incorporating this loss, the model learns to predict partial charges that are consistent with both the total charge and the dipole moment of the molecule, improving its ability to model electrostatic interactions accurately.
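A minimal sketch of the two terms, assuming per-atom charges and positions in flat arrays together with a hypothetical `molecule_index` that maps each atom to its molecule; the function name is illustrative and not modelforge's TotalChargeError or DipoleMomentError implementation. The dipole is computed as $\sum_i q_i \mathbf{r}_i$.

```python
import torch


def charge_dipole_loss(
    q: torch.Tensor,               # (n_atoms,) predicted partial charges
    positions: torch.Tensor,       # (n_atoms, 3) atomic positions
    molecule_index: torch.Tensor,  # (n_atoms,) hypothetical atom -> molecule map
    q_ref: torch.Tensor,           # (n_mol,) reference total charges
    p_ref: torch.Tensor,           # (n_mol, 3) reference dipole moments
) -> torch.Tensor:
    n_mol = q_ref.shape[0]
    # total predicted charge per molecule: sum_i q_i
    total_q = torch.zeros(n_mol, dtype=q.dtype).index_add_(0, molecule_index, q)
    # predicted dipole per molecule: sum_i q_i * r_i, shape (n_mol, 3)
    dipole = torch.zeros(n_mol, 3, dtype=q.dtype).index_add_(
        0, molecule_index, q.unsqueeze(1) * positions
    )
    charge_err = (total_q - q_ref) ** 2
    dipole_err = ((dipole - p_ref) ** 2).sum(dim=1)  # summed over components
    return (charge_err + dipole_err).mean()


# A neutral diatomic with charges +0.5 and -0.5, 1 length unit apart
q = torch.tensor([0.5, -0.5])
positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
index = torch.tensor([0, 0])

loss_match = charge_dipole_loss(q, positions, index, torch.tensor([0.0]),
                                torch.tensor([[-0.5, 0.0, 0.0]]))
loss_zero_ref = charge_dipole_loss(q, positions, index, torch.tensor([0.0]),
                                   torch.zeros(1, 3))
```

For charged molecules, $\sum_i q_i \mathbf{r}_i$ depends on the choice of origin, so predicted and reference dipoles must be computed with respect to the same origin (for example, the center of mass).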
In our codebase, forces are computed using PyTorch's automatic differentiation. Within the CalculateProperties class, the predicted forces are obtained by differentiating the predicted energy with respect to atomic positions:
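```python
import torch

# Ensure gradients are enabled. Note: this only has an effect if the positions
# already required gradients when per_molecule_energy_predict was computed; in
# practice, enable requires_grad on the positions before the forward pass.
per_molecule_energy_predict.requires_grad_(True)
nnp_input.positions.requires_grad_(True)

# Compute the gradient dE/dr. Passing ones as grad_outputs differentiates the
# sum of all per-molecule energies; since each molecule's energy depends only
# on its own atoms, this yields the per-atom gradient for every molecule.
grad = torch.autograd.grad(
    outputs=per_molecule_energy_predict,
    inputs=nnp_input.positions,
    grad_outputs=torch.ones_like(per_molecule_energy_predict),
    create_graph=train_mode,  # keep the graph so a force loss can be backpropagated
    retain_graph=train_mode,
    allow_unused=True,
)[0]

# Forces are the negative gradient of energy
per_atom_force_predict = -grad
```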
This approach ensures that the predicted forces are, by construction, the negative gradient of the predicted energy, so energy and force predictions are always consistent and the learned force field is conservative.
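The same mechanism can be tried on a toy model. In the sketch below, a single harmonic bond (a hypothetical stand-in for an NNP) provides the energy; the forces come from `torch.autograd.grad`, and `create_graph=True` keeps them differentiable, so a force loss can be backpropagated to the model parameter `k`. This mirrors why the snippet above sets `create_graph=train_mode`.

```python
import torch


def toy_energy(positions: torch.Tensor, k: torch.Tensor, d0: float = 1.0) -> torch.Tensor:
    # Toy stand-in for an NNP: one harmonic bond between atom 0 and atom 1
    d = torch.linalg.norm(positions[1] - positions[0])
    return 0.5 * k * (d - d0) ** 2


k = torch.tensor(2.0, requires_grad=True)  # plays the role of a model parameter
# Positions require gradients *before* the energy is computed
positions = torch.tensor([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]], requires_grad=True)

energy = toy_energy(positions, k)
# create_graph=True makes the forces themselves differentiable w.r.t. k
(grad,) = torch.autograd.grad(energy, positions, create_graph=True)
forces = -grad

# A force loss against (made-up) zero reference forces, backpropagated to k
force_loss = (forces ** 2).mean()
force_loss.backward()
```

The stretched bond ($d = 1.5 > d_0 = 1$) yields forces of $+1$ and $-1$ along $x$ that pull the atoms together and sum to zero, and `k.grad` is populated only because the force computation was kept in the graph.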
To balance the contributions of different loss components (energies, forces, charges, etc.), we scale them as follows:
- Per-atom aggregation: for properties defined per atom (e.g., forces), errors are computed per atom and then aggregated to a per-molecule error by summing over the atoms in each molecule.
- Normalization by atom count (optional): the aggregated errors are divided by the number of atoms to obtain mean errors, so that molecules of different sizes contribute comparably to the loss. If the per-atom property has more than one component, as forces do, the error is additionally divided by the number of components (three for forces).
The Loss class orchestrates this process, using specific error-calculation classes for different properties. For instance:

- Energy loss: uses PerMoleculeSquaredError, with optional scaling by the number of atoms
- Force loss: uses FromPerAtomToPerMoleculeSquaredError to compute per-atom errors and aggregate them
- Charge and dipole moment losses: use the TotalChargeError and DipoleMomentError classes, respectively
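```python
from typing import Dict

import torch
from torch import nn

# NNPInput: the batch type used in this snippet (import omitted)


class Loss(nn.Module):
    # ...

    def forward(
        self, predict_target: Dict[str, torch.Tensor], batch: NNPInput
    ) -> Dict[str, torch.Tensor]:
        loss_dict = {}
        # zeros shaped like the reference energies, i.e. one entry per molecule
        total_loss = torch.zeros_like(batch.metadata.E)

        for prop in self.loss_property:
            loss_fn = self.loss_functions[prop]
            # property-specific error, e.g. a per-molecule squared error
            prop_loss = loss_fn(
                predict_target[f"{prop}_predict"],
                predict_target[f"{prop}_true"],
                batch,
            )
            # weight each property's contribution before accumulating it
            weighted_loss = self.weights[prop] * prop_loss
            total_loss += weighted_loss
            loss_dict[prop] = prop_loss

        loss_dict["total_loss"] = total_loss
        return loss_dict
```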
This flexible structure allows users to specify which properties to include in the loss function and assign custom weights to them.
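The weighting pattern can be explored in isolation with the following self-contained analogue. It is a hypothetical simplification, not modelforge's Loss class: both the class names and the property keys (`per_molecule_energy`, `dipole_moment`) are made up, and every property uses the same per-molecule squared error.

```python
from typing import Dict

import torch
from torch import nn


class PerMoleculeSquaredErrorSketch(nn.Module):
    """Hypothetical stand-in: squared error per molecule, summed over extra dims."""

    def forward(self, predict: torch.Tensor, true: torch.Tensor) -> torch.Tensor:
        return ((predict - true) ** 2).reshape(predict.shape[0], -1).sum(dim=1)


class WeightedLossSketch(nn.Module):
    """Hypothetical minimal analogue of a weighted multi-property loss."""

    def __init__(self, weights: Dict[str, float]):
        super().__init__()
        self.weights = weights
        self.loss_functions = nn.ModuleDict(
            {prop: PerMoleculeSquaredErrorSketch() for prop in weights}
        )

    def forward(self, predict_target: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        loss_dict: Dict[str, torch.Tensor] = {}
        total = None
        for prop, weight in self.weights.items():
            prop_loss = self.loss_functions[prop](
                predict_target[f"{prop}_predict"], predict_target[f"{prop}_true"]
            )
            loss_dict[prop] = prop_loss
            # accumulate the weighted per-molecule losses
            total = weight * prop_loss if total is None else total + weight * prop_loss
        loss_dict["total_loss"] = total
        return loss_dict


loss = WeightedLossSketch({"per_molecule_energy": 1.0, "dipole_moment": 0.5})
out = loss({
    "per_molecule_energy_predict": torch.tensor([1.0, 2.0]),
    "per_molecule_energy_true": torch.tensor([0.0, 0.0]),
    "dipole_moment_predict": torch.tensor([[1.0, 0.0, 0.0], [0.0, 0.0, 0.0]]),
    "dipole_moment_true": torch.zeros(2, 3),
})
```

Changing an entry in the weights dictionary rescales that property's contribution to `total_loss` without touching the individual error functions.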
modelforge supports training on additional properties like total charge and dipole moment, which are essential for accurate electrostatic interactions:
- Total charge loss: ensures that the sum of predicted partial atomic charges matches the reference total charge.
- Dipole moment loss: aligns the predicted molecular dipole moment (computed from the predicted charges and positions) with the reference dipole moment.

These properties are integrated into the loss function in the same way, each with its own error-computation class and weight.
During training and evaluation, we track various metrics using TorchMetrics' MetricCollection:
```python
from typing import List

from torch.nn import ModuleDict
from torchmetrics import MeanAbsoluteError, MeanMetric, MeanSquaredError, MetricCollection


def create_error_metrics(
    loss_properties: List[str],
    is_loss: bool = False,
) -> ModuleDict:
    if is_loss:
        # track the running mean of each loss component and of the total loss
        metric_dict = ModuleDict(
            {prop: MetricCollection([MeanMetric()]) for prop in loss_properties}
        )
        metric_dict["total_loss"] = MetricCollection([MeanMetric()])
    else:
        # track MAE and RMSE (MeanSquaredError with squared=False) per property
        metric_dict = ModuleDict(
            {
                prop: MetricCollection(
                    [MeanAbsoluteError(), MeanSquaredError(squared=False)]
                )
                for prop in loss_properties
            }
        )
    return metric_dict
```
This setup facilitates comprehensive monitoring of model performance across different properties and datasets.