Skip to content

Commit

Permalink
allow additional features in embedding (#204)
Browse files Browse the repository at this point in the history
* flexible embedding
* update
* fix imports
* bugfix
  • Loading branch information
wiederm authored Jul 25, 2024
1 parent 0aebaf3 commit 8e4e5b1
Show file tree
Hide file tree
Showing 21 changed files with 780 additions and 402 deletions.
23 changes: 18 additions & 5 deletions modelforge/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class NNPInput:
atomic_subsystem_indices: torch.Tensor
total_charge: torch.Tensor
pair_list: Optional[torch.Tensor] = None
partial_charge: Optional[torch.Tensor] = None

def to(
self,
Expand All @@ -88,15 +89,27 @@ def to(
self.positions = self.positions.to(device)
self.atomic_subsystem_indices = self.atomic_subsystem_indices.to(device)
self.total_charge = self.total_charge.to(device)
if self.pair_list is not None:
self.pair_list = self.pair_list.to(device)
self.pair_list = (
self.pair_list.to(device)
if self.pair_list is not None
else self.pair_list
)
self.partial_charge = (
self.partial_charge.to(device)
if self.partial_charge is not None
else self.partial_charge
)
if dtype:
self.positions = self.positions.to(dtype)
return self

def __post_init__(self):
    """Normalize tensor dtypes after dataclass construction.

    Integer-valued fields (atomic numbers, subsystem indices, total charge)
    are stored as int32; partial charges, when provided, are kept as a
    float tensor. Positions are left untouched here and converted in
    ``to()`` when a dtype is requested.
    """
    self.atomic_numbers = self.atomic_numbers.to(torch.int32)

    # BUGFIX: the original code assigned `self.atomic_numbers` to
    # `self.partial_charge`, and truth-tested the tensor directly
    # (`if self.partial_charge`), which raises a RuntimeError for
    # tensors with more than one element. Convert the partial charges
    # themselves, and only when they were actually provided.
    if self.partial_charge is not None:
        # Partial charges are real-valued, so use a float dtype.
        self.partial_charge = self.partial_charge.to(torch.float32)

    self.atomic_subsystem_indices = self.atomic_subsystem_indices.to(torch.int32)
    self.total_charge = self.total_charge.to(torch.int32)

Expand Down Expand Up @@ -673,9 +686,9 @@ def _from_hdf5(self) -> None:

if all(property_found):
# we want to exclude conformers with NaN values for any property of interest
configs_nan_by_prop: Dict[
str, np.ndarray
] = OrderedDict() # ndarray.size (n_configs, )
configs_nan_by_prop: Dict[str, np.ndarray] = (
OrderedDict()
) # ndarray.size (n_configs, )
for value in list(series_mol_data.keys()) + list(
series_atom_data.keys()
):
Expand Down
4 changes: 2 additions & 2 deletions modelforge/potential/ani.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Tuple
from .models import InputPreparation, BaseNetwork, CoreNetwork
from .models import ComputeInteractingAtomPairs, BaseNetwork, CoreNetwork

import torch
from loguru import logger as log
Expand Down Expand Up @@ -467,7 +467,7 @@ def __init__(
# number of elements in ANI2x
self.num_species = 7

log.debug("Initializing ANI model.")
log.debug("Initializing the ANI2x architecture.")
super().__init__()

# Initialize representation block
Expand Down
20 changes: 10 additions & 10 deletions modelforge/potential/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,7 @@ def generate_model(
raise NotImplementedError(f"Unsupported 'use' value: {use}")


class InputPreparation(torch.nn.Module):
class ComputeInteractingAtomPairs(torch.nn.Module):
def __init__(self, cutoff: unit.Quantity, only_unique_pairs: bool = True):
"""
A module for preparing input data, including the calculation of pair lists, distances (d_ij), and displacement vectors (r_ij) for molecular simulations.
Expand Down Expand Up @@ -891,7 +891,7 @@ def __init__(
self,
*,
postprocessing_parameter: Dict[str, Dict[str, bool]],
dataset_statistic,
dataset_statistic: Optional[Dict[str, float]],
cutoff: unit.Quantity,
):
"""
Expand All @@ -916,7 +916,7 @@ def __init__(
raise RuntimeError(
"The only_unique_pairs attribute is not set in the child class. Please set it to True or False before calling super().__init__."
)
self.input_preparation = InputPreparation(
self.compute_interacting_pairs = ComputeInteractingAtomPairs(
cutoff=_convert(cutoff), only_unique_pairs=self.only_unique_pairs
)

Expand Down Expand Up @@ -962,15 +962,15 @@ def load_state_dict(

super().load_state_dict(filtered_state_dict, strict=strict, assign=assign)

def prepare_input(self, data):
def prepare_pairwise_properties(self, data):

self.input_preparation._input_checks(data)
return self.input_preparation.prepare_inputs(data)
self.compute_interacting_pairs._input_checks(data)
return self.compute_interacting_pairs.prepare_inputs(data)

def compute(self, data, core_input):
return self.core_module(data, core_input)

def forward(self, data: NNPInput):
def forward(self, input_data: NNPInput):
"""
Executes the forward pass of the model.
This method performs input checks, prepares the inputs,
Expand All @@ -987,10 +987,10 @@ def forward(self, data: NNPInput):
The outputs computed by the core network.
"""

# perform input checks
core_input = self.prepare_input(data)
# compute all interacting pairs with distances
pairwise_properties = self.prepare_pairwise_properties(input_data)
# prepare the input for the forward pass
output = self.compute(data, core_input)
output = self.compute(input_data, pairwise_properties)
# perform postprocessing operations
processed_output = self.postprocessing(output)
return processed_output
Expand Down
57 changes: 28 additions & 29 deletions modelforge/potential/painn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
import torch.nn.functional as F
from loguru import logger as log
from openff.units import unit
from .models import InputPreparation, NNPInput, BaseNetwork, CoreNetwork
from .models import NNPInput, BaseNetwork, CoreNetwork

from .utils import Dense
from typing import List

if TYPE_CHECKING:
from .models import PairListOutputs
Expand Down Expand Up @@ -82,52 +83,54 @@ class PaiNNCore(CoreNetwork):

def __init__(
self,
max_Z: int = 100,
number_of_atom_features: int = 64,
number_of_radial_basis_functions: int = 16,
cutoff: unit.Quantity = 5 * unit.angstrom,
number_of_interaction_modules: int = 2,
shared_interactions: bool = False,
shared_filters: bool = False,
featurization_config: Dict[str, Union[List[str], int]],
number_of_radial_basis_functions: int,
cutoff: unit.Quantity,
number_of_interaction_modules: int,
shared_interactions: bool,
shared_filters: bool,
epsilon: float = 1e-8,
):
log.debug("Initializing PaiNN model.")
log.debug("Initializing the PaiNN architecture.")
super().__init__()

self.number_of_interaction_modules = number_of_interaction_modules
self.number_of_atom_features = number_of_atom_features
self.shared_filters = shared_filters

# embedding
from modelforge.potential.utils import Embedding

self.embedding_module = Embedding(max_Z, number_of_atom_features)
# featurize the atomic input
from modelforge.potential.utils import FeaturizeInput

self.featurize_input = FeaturizeInput(featurization_config)
number_of_per_atom_features = featurization_config[
"number_of_per_atom_features"
]
# initialize representation block
self.representation_module = PaiNNRepresentation(
cutoff,
number_of_radial_basis_functions,
number_of_interaction_modules,
number_of_atom_features,
number_of_per_atom_features,
shared_filters,
)

# initialize the interaction and mixing networks
self.interaction_modules = nn.ModuleList(
PaiNNInteraction(number_of_atom_features, activation=F.silu)
PaiNNInteraction(number_of_per_atom_features, activation=F.silu)
for _ in range(number_of_interaction_modules)
)
self.mixing_modules = nn.ModuleList(
PaiNNMixing(number_of_atom_features, activation=F.silu, epsilon=epsilon)
PaiNNMixing(number_of_per_atom_features, activation=F.silu, epsilon=epsilon)
for _ in range(number_of_interaction_modules)
)

self.energy_layer = nn.Sequential(
Dense(
number_of_atom_features, number_of_atom_features, activation=nn.ReLU()
number_of_per_atom_features,
number_of_per_atom_features,
activation=nn.ReLU(),
),
Dense(
number_of_atom_features,
number_of_per_atom_features,
1,
),
)
Expand All @@ -148,10 +151,8 @@ def _model_specific_input_preparation(
atomic_numbers=data.atomic_numbers,
atomic_subsystem_indices=data.atomic_subsystem_indices,
total_charge=data.total_charge,
atomic_embedding=self.embedding_module(
data.atomic_numbers
), # atom embedding
)
atomic_embedding=self.featurize_input(data),
) # add per-atom properties and embedding

return nnp_input

Expand Down Expand Up @@ -488,15 +489,14 @@ def forward(self, q: torch.Tensor, mu: torch.Tensor):
return q, mu


from .models import InputPreparation, NNPInput, BaseNetwork
from .models import ComputeInteractingAtomPairs, NNPInput, BaseNetwork
from typing import List


class PaiNN(BaseNetwork):
def __init__(
self,
max_Z: int,
number_of_atom_features: int,
featurization: Dict[str, Union[List[str], int]],
number_of_radial_basis_functions: int,
cutoff: Union[unit.Quantity, str],
number_of_interaction_modules: int,
Expand All @@ -518,8 +518,7 @@ def __init__(
)

self.core_module = PaiNNCore(
max_Z=max_Z,
number_of_atom_features=number_of_atom_features,
featurization_config=featurization,
number_of_radial_basis_functions=number_of_radial_basis_functions,
cutoff=_convert(cutoff),
number_of_interaction_modules=number_of_interaction_modules,
Expand All @@ -538,7 +537,7 @@ def _config_prior(self):
from modelforge.potential.utils import shared_config_prior

prior = {
"number_of_atom_features": tune.randint(2, 256),
"number_of_per_atom_features": tune.randint(2, 256),
"number_of_interaction_modules": tune.randint(1, 5),
"cutoff": tune.uniform(5, 10),
"number_of_radial_basis_functions": tune.randint(8, 32),
Expand Down
Loading

0 comments on commit 8e4e5b1

Please sign in to comment.