
Training

Marcus Wieder edited this page Oct 1, 2024 · 18 revisions

Data processing and normalization

Proper data preprocessing and normalization are crucial steps in training neural network potentials (NNPs). These steps promote faster convergence, improved numerical stability, and better generalization in predicting molecular energies and forces.

Atomic self-energies

Energies derived from quantum mechanical (QM) methods include atomic self-energies: intrinsic energy contributions of individual atoms that do not depend on interatomic interactions. Their sum corresponds to the system's energy when all interactions between atoms are eliminated, for example when the atoms are separated beyond the interaction cutoff (formally, to infinite separation).

Removing these self-energy offsets from the total energy has been shown to speed up convergence during training. Offset removal shifts the energy distribution toward zero; for configurations that share the same composition it is a constant shift, so the spread $E_{\max} - E_{\min}$ is unchanged. The network can then focus on learning the energy differences driven by interatomic interactions. This normalization also helps numerically: training is typically performed in single precision, and large energy offsets consume bits of floating-point precision that are then unavailable for resolving small energy differences.
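To make the precision argument concrete, the sketch below compares the spacing between adjacent single-precision values (the smallest energy change a float32 number can represent) for a total energy and for a residual energy after offset removal. The magnitudes (-400 Hartree and -0.05 Hartree) are illustrative assumptions, not values taken from a particular dataset.

```python
import numpy as np

# Illustrative (assumed) energies in Hartree
total_energy = -400.0      # total QM energy, including atomic self-energies
residual_energy = -0.05    # energy after removing the self-energy offset

# np.spacing returns the distance to the next representable float32 value,
# i.e. the smallest energy difference the stored number can resolve
ulp_total = np.spacing(np.float32(total_energy))
ulp_residual = np.spacing(np.float32(residual_energy))

print(f"float32 resolution at {total_energy} Ha: {abs(ulp_total):.2e} Ha")
print(f"float32 resolution at {residual_energy} Ha: {abs(ulp_residual):.2e} Ha")
```

At -400 Hartree the resolution is 2^-15, about 3.05e-05 Hartree (roughly 0.08 kJ/mol), whereas at -0.05 Hartree it is 2^-28, about 3.7e-09 Hartree.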

Practical Considerations

The atomic self-energies depend on the QM method used to compute the energies in the training dataset. In most datasets, self-energies are provided for each atomic species. However, if this information is unavailable, tools like modelforge can estimate the self-energies using linear regression on the total energies, fitting to the composition of the system. After removing the atomic self-energy offset, the energies are converted to PyTorch tensors and cast to single precision for training. Care should be taken to retain double precision in the preprocessing steps to avoid precision loss during offset removal.
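A minimal sketch of estimating self-energies by linear regression, assuming a small, made-up dataset in which each configuration is described by its element counts. The least-squares fit and the offset removal are done in double precision with NumPy, and only the residual energies are cast to `torch.float32`. This illustrates the technique and is not modelforge's implementation; the element set, counts, and energies are hypothetical.

```python
import numpy as np
import torch

# Hypothetical toy data: element counts per configuration (columns: H, C, O)
counts = np.array([
    [4, 1, 0],   # CH4
    [2, 0, 1],   # H2O
    [6, 2, 0],   # C2H6
    [4, 1, 1],   # CH3OH
], dtype=np.float64)
# Made-up total energies in Hartree, kept in float64 during preprocessing
energies = np.array([-40.5, -76.4, -79.8, -115.7], dtype=np.float64)

# Solve counts @ e_self ~= energies for one self-energy per element
self_energies, *_ = np.linalg.lstsq(counts, energies, rcond=None)

# Remove the composition-dependent offset while still in double precision
offsets = counts @ self_energies
residual = energies - offsets

# Only now convert to a PyTorch tensor and cast to single precision
residual_f32 = torch.from_numpy(residual).to(torch.float32)
```

Casting only the small residuals to float32 avoids the precision loss that would occur if the large totals were rounded first and the offset subtracted afterwards.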

Loss

The choice of loss function is critical in optimizing the NNP. When only energies are used to drive the training, the mean squared error (MSE) between predicted and reference energies is the natural choice.

For a dataset containing $M$ data points (e.g., molecules or configurations), the energy loss is defined as:

$$ L_E = \frac{1}{M} \sum_{k=1}^{M} \frac{1}{N_k}(\hat{E}_k - E_k)^{2} $$

where $N_k$ is the number of atoms in data point $k$, $\hat{E}_k$ is the predicted energy for data point $k$, and $E_k$ is the reference (QM) energy. Because total energies, and hence their errors, tend to grow with system size, scaling the squared error by $1/N_k$ keeps the contribution of each data point comparable regardless of its size. This is particularly important when the dataset includes molecules or configurations with varying numbers of atoms.
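The energy loss maps directly onto tensor operations. A minimal PyTorch sketch, assuming predicted energies, reference energies, and atom counts are stored as one-dimensional tensors of length $M$; the function name is hypothetical.

```python
import torch


def per_atom_scaled_energy_loss(
    energy_predict: torch.Tensor,  # shape (M,), predicted energies E_hat_k
    energy_true: torch.Tensor,     # shape (M,), reference energies E_k
    n_atoms: torch.Tensor,         # shape (M,), number of atoms N_k
) -> torch.Tensor:
    """L_E = 1/M * sum_k (E_hat_k - E_k)^2 / N_k"""
    squared_error = (energy_predict - energy_true) ** 2
    # Scale each data point by its atom count, then average over the M points
    return (squared_error / n_atoms).mean()


# Usage: errors of 1 and 2 energy units on systems with 1 and 4 atoms
loss = per_atom_scaled_energy_loss(
    torch.tensor([1.0, 2.0]), torch.tensor([0.0, 0.0]), torch.tensor([1.0, 4.0])
)
```

Here both data points contribute 1.0 after scaling (1²/1 and 2²/4), so the loss is 1.0.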

Incorporating Forces

To further improve accuracy, forces can be included in the training objective. This improves the model's ability to predict not only total energies but also the forces acting on each atom, which are crucial for molecular dynamics (MD) simulations and geometry optimizations.

When including forces, the loss function is extended to:

$$ L = \frac{1}{M} \sum_{k=1}^{M} \left[ \frac{1}{N_k} (\hat{E}_{k} - E_{k})^{2} + \frac{1}{3N_{k}} \sum_{i=1}^{N_{k}} \sum_{\alpha=1}^{3} \left( -\frac{\partial \hat{U}_{k}}{\partial r_{i, \alpha}^{k}} - F_{i, \alpha}^{k} \right)^2 \right] $$

with:

  • $\hat{U}_k$ as the predicted potential energy for data point $k$
  • $F_{i, \alpha}^{k}$ as the reference force component $\alpha$ of atom $i$ at data point $k$
  • $r_{i, \alpha}^{k}$ as the Cartesian coordinate $\alpha$ of atom $i$ at data point $k$

This loss function comprises two terms:

  • Energy Loss: the squared energy error, scaled by the number of atoms in each data point to account for varying sizes.
  • Force Loss: the squared error between the predicted forces (the negative energy gradient) and the reference forces, averaged over all atoms and the three Cartesian components of each data point.

By normalizing by the number of atoms (and force components) and averaging over data points, we ensure that each molecule contributes comparably to the total loss, preventing larger molecules from dominating the training process.
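A direct, loop-based transcription of the combined energy and force loss, assuming each data point's predicted and reference forces are available as `(N_k, 3)` tensors and that the predicted forces have already been obtained as the negative energy gradient. The function name and input layout are illustrative and not modelforge's API; a production implementation would operate on batched tensors rather than a Python loop.

```python
from typing import List

import torch


def energy_force_loss(
    energy_predict: torch.Tensor,        # shape (M,)
    energy_true: torch.Tensor,           # shape (M,)
    forces_predict: List[torch.Tensor],  # M tensors of shape (N_k, 3), = -dU_k/dr
    forces_true: List[torch.Tensor],     # M tensors of shape (N_k, 3)
) -> torch.Tensor:
    M = energy_predict.shape[0]
    total = energy_predict.new_zeros(())
    for k in range(M):
        n_atoms = forces_true[k].shape[0]
        # Energy term: squared error scaled by 1/N_k
        energy_term = (energy_predict[k] - energy_true[k]) ** 2 / n_atoms
        # Force term: mean over N_k atoms and 3 components = 1/(3 N_k) * sum
        force_term = ((forces_predict[k] - forces_true[k]) ** 2).mean()
        total = total + energy_term + force_term
    # Average over the M data points
    return total / M
```

For a single two-atom data point with an energy error of 1 and a force error of 1 in every component, the energy term is 1/2 and the force term is 1, giving a loss of 1.5.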

Incorporating Dipole Moments

In addition to energies and forces, including dipole moments as training targets can further enhance the predictive capabilities of NNPs, particularly for properties related to molecular polarity and electrostatics.

Dipole Moment Loss

The dipole moment loss operates on the predicted partial atomic charges and the molecular geometry. It ensures that the predicted partial charges not only sum up to the correct total charge but also reproduce the reference dipole moment.

The loss has the following form:

$$ L = \frac{1}{M} \sum_{k=1}^{M} \left[ \left| \sum_{i=1}^{N_{k}} q_{i}^{k} - Q_{\text{ref}}^{k} \right| + \sum_{\alpha=1}^{3} \left| \sum_{i=1}^{N_k} q_{i}^{k} r_{i, \alpha}^{k} - p^{\alpha, k}_{\text{ref}} \right| \right] $$

with

  • $q_{i}^{k}$ as the predicted partial charge of atom $i$ in data point $k$
  • $Q_{\text{ref}}^{k}$ as the reference total charge of the data point $k$
  • $r_{i, \alpha}^{k}$ as the $\alpha$ component of the coordinate of atom $i$ in data point $k$
  • $p^{\alpha, k}_{\text{ref}}$ as the reference dipole moment component $\alpha$ for data point $k$

This loss function comprises two terms:

  • Total charge error: penalizes deviations of the sum of predicted partial charges from the reference total charge.
  • Dipole moment error: penalizes deviations of the dipole moment computed from the predicted charges and atomic positions from the reference dipole moment.

By incorporating this loss, the model learns to predict partial charges that are consistent with both the total charge and the dipole moment of the molecule, improving its ability to model electrostatic interactions accurately.
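A batched sketch of the total charge and dipole moment loss, assuming (as a hypothetical layout) that all atoms of a batch are stored in flat tensors together with an integer index mapping each atom to its data point. For charged molecules the dipole moment depends on the choice of origin, so positions and reference dipoles must refer to the same origin.

```python
import torch


def charge_dipole_loss(
    partial_charges: torch.Tensor,   # (N_total,) predicted q_i
    positions: torch.Tensor,         # (N_total, 3) atomic positions r_i
    molecule_index: torch.Tensor,    # (N_total,) int64, data point k of each atom
    total_charge_ref: torch.Tensor,  # (M,) reference total charges Q_ref
    dipole_ref: torch.Tensor,        # (M, 3) reference dipole moments p_ref
) -> torch.Tensor:
    M = total_charge_ref.shape[0]
    # Sum q_i over the atoms of each data point -> predicted total charge, (M,)
    total_charge = torch.zeros(
        M, dtype=partial_charges.dtype, device=partial_charges.device
    ).index_add(0, molecule_index, partial_charges)
    # Sum q_i * r_i over the atoms of each data point -> predicted dipole, (M, 3)
    dipole = torch.zeros(
        M, 3, dtype=positions.dtype, device=positions.device
    ).index_add(0, molecule_index, partial_charges.unsqueeze(-1) * positions)
    charge_error = (total_charge - total_charge_ref).abs()   # (M,)
    dipole_error = (dipole - dipole_ref).abs().sum(dim=-1)   # summed over alpha, (M,)
    # Average over the M data points
    return (charge_error + dipole_error).mean()
```

For a neutral two-atom system with charges +0.5 and -0.5 placed at the origin and at x = 1, the predicted dipole is (-0.5, 0, 0); the loss is zero if that matches the reference and 0.5 if the reference dipole is zero.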

Implementation Details

Force Computation via Automatic Differentiation

In our codebase, forces are computed using PyTorch's automatic differentiation. Within the CalculateProperties class, the predicted forces are obtained by differentiating the predicted energy with respect to atomic positions:

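import torch

# Positions must require gradients *before* the forward pass, so that the
# predicted energy is recorded as a function of the atomic positions
nnp_input.positions.requires_grad_(True)
# ... the forward pass computes per_molecule_energy_predict from nnp_input ...

# Compute dE/dr. Using grad_outputs = ones differentiates the sum of all
# molecular energies, which is valid because each molecule's energy depends
# only on the positions of its own atoms.
grad = torch.autograd.grad(
    outputs=per_molecule_energy_predict,
    inputs=nnp_input.positions,
    grad_outputs=torch.ones_like(per_molecule_energy_predict),
    create_graph=train_mode,  # keep the graph so the force loss can be backpropagated
    retain_graph=train_mode,
    allow_unused=True,
)[0]

# Forces are the negative gradient of energy
per_atom_force_predict = -grad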

Because the predicted forces are the exact negative gradient of the predicted energy, they are conservative by construction and consistent with the physical relation $F = -\nabla U$.
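The autograd mechanism can be checked in isolation. The sketch below replaces the NNP with a hypothetical harmonic bond energy, $E = \frac{1}{2} k (d - d_0)^2$ with $k = 2$ and $d_0 = 1$, whose analytic forces are known, and computes the forces as the negative gradient. During training, `create_graph=True` would additionally be required so that the force loss can be backpropagated to the model parameters.

```python
import torch


def toy_energy(positions: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for an NNP: harmonic bond between atoms 0 and 1."""
    k, d0 = 2.0, 1.0
    d = torch.linalg.norm(positions[1] - positions[0])
    return 0.5 * k * (d - d0) ** 2


# Positions require gradients *before* the energy is computed
positions = torch.tensor([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]], requires_grad=True)
energy = toy_energy(positions)

# Scalar energy, so no grad_outputs are needed
(grad,) = torch.autograd.grad(energy, positions)
forces = -grad
```

The bond is stretched by 0.5, so $dE/dd = k(d - d_0) = 1$: atom 1 is pulled back along $-x$ with unit force, atom 0 is pulled along $+x$, and the forces sum to zero.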

Loss Scaling and Aggregation

To balance the contributions of different loss components (energies, forces, charges, etc.), we scale them appropriately:

  • Per-atom aggregation: for properties defined per atom (e.g., forces), errors are computed per atom and then aggregated to a per-molecule error by summing over the atoms in each molecule.
  • Normalization by atom count (optional): the aggregated errors are divided by the number of atoms to obtain mean errors, so that molecules of different sizes contribute comparably to the loss. For per-atom properties with more than one component, such as the three Cartesian components of the force, the error is additionally divided by the number of components.
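A sketch of these two steps for forces, assuming all atoms of a batch are stored in flat tensors with an integer index mapping each atom to its molecule. It follows the description above but is not modelforge's `FromPerAtomToPerMoleculeSquaredError` implementation; the function and argument names are hypothetical.

```python
import torch


def per_molecule_force_error(
    force_predict: torch.Tensor,   # (N_total, 3)
    force_true: torch.Tensor,      # (N_total, 3)
    molecule_index: torch.Tensor,  # (N_total,) int64, molecule of each atom
    n_molecules: int,
    normalize: bool = True,
) -> torch.Tensor:
    # Per-atom squared error, summed over the force components -> (N_total,)
    per_atom = ((force_predict - force_true) ** 2).sum(dim=-1)
    # Aggregation: sum the per-atom errors of each molecule -> (n_molecules,)
    per_molecule = torch.zeros(
        n_molecules, dtype=per_atom.dtype, device=per_atom.device
    ).index_add(0, molecule_index, per_atom)
    if normalize:
        # Divide by the atom count and by the number of components (3)
        n_atoms = torch.bincount(molecule_index, minlength=n_molecules)
        per_molecule = per_molecule / (n_atoms.to(per_atom.dtype) * force_predict.shape[-1])
    return per_molecule
```

With `normalize=True` the result equals the per-molecule force term of the combined loss, $\frac{1}{3N_k}\sum_i \sum_\alpha (\cdot)^2$; the per-molecule values can then be weighted and averaged over the batch.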

The Loss class orchestrates this process, utilizing specific error calculation classes for different properties. For instance:

  • Energy Loss: uses PerMoleculeSquaredError with optional scaling by the number of atoms
  • Force Loss: uses FromPerAtomToPerMoleculeSquaredError to compute per-atom errors and aggregate them
  • Charge and Dipole Moment Losses: use the TotalChargeError and DipoleMomentError classes, respectively
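
from typing import Dict

import torch
from torch import nn


class Loss(nn.Module):
    # ... (constructor omitted; it defines loss_property, loss_functions and weights)
    # NNPInput is modelforge's batch data class
    def forward(self, predict_target: Dict[str, torch.Tensor], batch: NNPInput) -> Dict[str, torch.Tensor]:
        loss_dict = {}
        # Accumulator with the same shape as the per-molecule reference energies
        total_loss = torch.zeros_like(batch.metadata.E)

        for prop in self.loss_property:
            # Property-specific error class (e.g. PerMoleculeSquaredError for energies)
            loss_fn = self.loss_functions[prop]
            prop_loss = loss_fn(
                predict_target[f"{prop}_predict"],
                predict_target[f"{prop}_true"],
                batch,
            )
            # Apply the user-defined weight and accumulate into the total loss
            weighted_loss = self.weights[prop] * prop_loss
            total_loss += weighted_loss
            loss_dict[prop] = prop_loss

        loss_dict["total_loss"] = total_loss
        return loss_dict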

This flexible structure allows users to specify which properties to include in the loss function and assign custom weights to them.

Extending to Additional Properties

modelforge supports training on additional properties like total charge and dipole moment, which are essential for accurate electrostatic interactions:

  • Total Charge Loss: Ensures that the sum of predicted partial atomic charges matches the reference total charge
  • Dipole Moment Loss: Aligns the predicted molecular dipole moment (computed from predicted charges and positions) with the reference dipole moment

These properties are integrated into the loss function in the same way, each with its own error computation class and weight.

Metric Computation

During training and evaluation, we track various metrics using the MetricCollection class from torchmetrics:

from typing import List

from torch.nn import ModuleDict
from torchmetrics import MeanAbsoluteError, MeanMetric, MeanSquaredError, MetricCollection


def create_error_metrics(
    loss_properties: List[str],
    is_loss: bool = False,
) -> ModuleDict:
    if is_loss:
        # Running mean of each loss component and of the total loss
        metric_dict = ModuleDict(
            {prop: MetricCollection([MeanMetric()]) for prop in loss_properties}
        )
        metric_dict["total_loss"] = MetricCollection([MeanMetric()])
    else:
        # MAE and RMSE (MeanSquaredError with squared=False) for each property
        metric_dict = ModuleDict(
            {
                prop: MetricCollection(
                    [MeanAbsoluteError(), MeanSquaredError(squared=False)]
                )
                for prop in loss_properties
            }
        )
    return metric_dict

This setup facilitates comprehensive monitoring of model performance across different properties and datasets.
