allow additional features in embedding #204

Merged: 22 commits, Jul 25, 2024
23 changes: 18 additions & 5 deletions modelforge/dataset/dataset.py
@@ -74,6 +74,7 @@ class NNPInput:
atomic_subsystem_indices: torch.Tensor
total_charge: torch.Tensor
pair_list: Optional[torch.Tensor] = None
+ partial_charge: Optional[torch.Tensor] = None

def to(
self,
@@ -88,15 +89,27 @@ def to(
self.positions = self.positions.to(device)
self.atomic_subsystem_indices = self.atomic_subsystem_indices.to(device)
self.total_charge = self.total_charge.to(device)
- if self.pair_list is not None:
-     self.pair_list = self.pair_list.to(device)
+ self.pair_list = (
+     self.pair_list.to(device)
+     if self.pair_list is not None
+     else self.pair_list
+ )
+ self.partial_charge = (
+     self.partial_charge.to(device)
+     if self.partial_charge is not None
+     else self.partial_charge
+ )
if dtype:
self.positions = self.positions.to(dtype)
return self

def __post_init__(self):
# Set dtype and convert units if necessary
self.atomic_numbers = self.atomic_numbers.to(torch.int32)

+ self.partial_charge = (
+     self.atomic_numbers.to(torch.int32) if self.partial_charge else None
+ )
self.atomic_subsystem_indices = self.atomic_subsystem_indices.to(torch.int32)
self.total_charge = self.total_charge.to(torch.int32)

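For illustration, a minimal sketch of carrying the new optional per-atom field (not part of the PR; the remaining constructor arguments and the exact to() signature are assumed from context, and the charges are attached after construction so the truthiness check in __post_init__ above is not tripped by a multi-element tensor):

import torch
from modelforge.dataset.dataset import NNPInput

batch = NNPInput(
    atomic_numbers=torch.tensor([8, 1, 1]),            # a single water molecule
    positions=torch.randn(3, 3),
    atomic_subsystem_indices=torch.tensor([0, 0, 0]),  # all three atoms belong to molecule 0
    total_charge=torch.tensor([0]),
)
batch.partial_charge = torch.tensor([-0.8, 0.4, 0.4])  # the new optional per-atom property
batch = batch.to(torch.device("cpu"))                  # partial_charge now follows device moves
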
@@ -673,9 +686,9 @@ def _from_hdf5(self) -> None:

if all(property_found):
# we want to exclude conformers with NaN values for any property of interest
- configs_nan_by_prop: Dict[
-     str, np.ndarray
- ] = OrderedDict()  # ndarray.size (n_configs, )
+ configs_nan_by_prop: Dict[str, np.ndarray] = (
+     OrderedDict()
+ )  # ndarray.size (n_configs, )
for value in list(series_mol_data.keys()) + list(
series_atom_data.keys()
):
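The hunk above only reflows the type annotation; for context, an illustrative NumPy sketch (not the PR's code, with made-up property names) of the NaN-exclusion idea it refers to:

import numpy as np
from collections import OrderedDict
from typing import Dict

per_property_values = {
    "energies": np.array([1.0, np.nan, 3.0]),
    "dipole_moment_norm": np.array([0.2, 0.1, np.nan]),
}
configs_nan_by_prop: Dict[str, np.ndarray] = OrderedDict()
for name, values in per_property_values.items():
    configs_nan_by_prop[name] = np.isnan(values)  # ndarray.size (n_configs, )

configs_nan = np.logical_or.reduce(list(configs_nan_by_prop.values()))
keep = ~configs_nan  # boolean mask of conformers with no NaN in any property of interest
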
4 changes: 2 additions & 2 deletions modelforge/potential/ani.py
@@ -1,6 +1,6 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Tuple
- from .models import InputPreparation, BaseNetwork, CoreNetwork
+ from .models import ComputeInteractingAtomPairs, BaseNetwork, CoreNetwork

import torch
from loguru import logger as log
@@ -467,7 +467,7 @@ def __init__(
# number of elements in ANI2x
self.num_species = 7

log.debug("Initializing ANI model.")
log.debug("Initializing the ANI2x architecture.")
super().__init__()

# Initialize representation block
20 changes: 10 additions & 10 deletions modelforge/potential/models.py
@@ -598,7 +598,7 @@ def generate_model(
raise NotImplementedError(f"Unsupported 'use' value: {use}")


- class InputPreparation(torch.nn.Module):
+ class ComputeInteractingAtomPairs(torch.nn.Module):
def __init__(self, cutoff: unit.Quantity, only_unique_pairs: bool = True):
"""
A module for preparing input data, including the calculation of pair lists, distances (d_ij), and displacement vectors (r_ij) for molecular simulations.
@@ -891,7 +891,7 @@ def __init__(
self,
*,
postprocessing_parameter: Dict[str, Dict[str, bool]],
- dataset_statistic,
+ dataset_statistic: Optional[Dict[str, float]],
cutoff: unit.Quantity,
):
"""
@@ -916,7 +916,7 @@ def __init__(
raise RuntimeError(
"The only_unique_pairs attribute is not set in the child class. Please set it to True or False before calling super().__init__."
)
- self.input_preparation = InputPreparation(
+ self.compute_interacting_pairs = ComputeInteractingAtomPairs(
cutoff=_convert(cutoff), only_unique_pairs=self.only_unique_pairs
)

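As a hypothetical usage note (class, argument, and method names are taken from this diff; exact call signatures are not confirmed here), the renamed pair-list module can also be constructed and applied on its own:

from openff.units import unit
from modelforge.potential.models import ComputeInteractingAtomPairs

compute_pairs = ComputeInteractingAtomPairs(
    cutoff=5.0 * unit.angstrom,  # neighbor cutoff used when enumerating pairs
    only_unique_pairs=False,     # False keeps both (i, j) and (j, i)
)
# given an NNPInput batch such as the one sketched in the dataset.py section above:
# compute_pairs._input_checks(batch)
# pairwise_properties = compute_pairs.prepare_inputs(batch)
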
@@ -962,15 +962,15 @@ def load_state_dict(

super().load_state_dict(filtered_state_dict, strict=strict, assign=assign)

- def prepare_input(self, data):
+ def prepare_pairwise_properties(self, data):

- self.input_preparation._input_checks(data)
- return self.input_preparation.prepare_inputs(data)
+ self.compute_interacting_pairs._input_checks(data)
+ return self.compute_interacting_pairs.prepare_inputs(data)

def compute(self, data, core_input):
return self.core_module(data, core_input)

- def forward(self, data: NNPInput):
+ def forward(self, input_data: NNPInput):
"""
Executes the forward pass of the model.
This method performs input checks, prepares the inputs,
Expand All @@ -987,10 +987,10 @@ def forward(self, data: NNPInput):
The outputs computed by the core network.
"""

- # perform input checks
- core_input = self.prepare_input(data)
+ # compute all interacting pairs with distances
+ pairwise_properties = self.prepare_pairwise_properties(input_data)
# prepare the input for the forward pass
- output = self.compute(data, core_input)
+ output = self.compute(input_data, pairwise_properties)
# perform postprocessing operations
processed_output = self.postprocessing(output)
return processed_output
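A hedged outline of the renamed pipeline (method names as in this diff; the potential argument stands in for any BaseNetwork subclass, and nnp_input for a prepared batch):

import torch
from modelforge.dataset.dataset import NNPInput

def run_potential(potential: torch.nn.Module, nnp_input: NNPInput):
    # the three stages that BaseNetwork.forward composes, spelled out
    pairwise_properties = potential.prepare_pairwise_properties(nnp_input)  # pair list, d_ij, r_ij
    output = potential.compute(nnp_input, pairwise_properties)              # core network pass
    return potential.postprocessing(output)                                 # postprocessing operations

# a single call, potential(nnp_input), performs the same three stages via forward().
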
57 changes: 28 additions & 29 deletions modelforge/potential/painn.py
@@ -5,9 +5,10 @@
import torch.nn.functional as F
from loguru import logger as log
from openff.units import unit
- from .models import InputPreparation, NNPInput, BaseNetwork, CoreNetwork
+ from .models import NNPInput, BaseNetwork, CoreNetwork

from .utils import Dense
+ from typing import List

if TYPE_CHECKING:
from .models import PairListOutputs
@@ -82,52 +83,54 @@ class PaiNNCore(CoreNetwork):

def __init__(
self,
- max_Z: int = 100,
- number_of_atom_features: int = 64,
- number_of_radial_basis_functions: int = 16,
- cutoff: unit.Quantity = 5 * unit.angstrom,
- number_of_interaction_modules: int = 2,
- shared_interactions: bool = False,
- shared_filters: bool = False,
+ featurization_config: Dict[str, Union[List[str], int]],
+ number_of_radial_basis_functions: int,
+ cutoff: unit.Quantity,
+ number_of_interaction_modules: int,
+ shared_interactions: bool,
+ shared_filters: bool,
epsilon: float = 1e-8,
):
log.debug("Initializing PaiNN model.")
log.debug("Initializing the PaiNN architecture.")
super().__init__()

self.number_of_interaction_modules = number_of_interaction_modules
- self.number_of_atom_features = number_of_atom_features
self.shared_filters = shared_filters

- # embedding
- from modelforge.potential.utils import Embedding
-
- self.embedding_module = Embedding(max_Z, number_of_atom_features)
+ # featurize the atomic input
+ from modelforge.potential.utils import FeaturizeInput
+
+ self.featurize_input = FeaturizeInput(featurization_config)
+ number_of_per_atom_features = featurization_config[
+     "number_of_per_atom_features"
+ ]
# initialize representation block
self.representation_module = PaiNNRepresentation(
cutoff,
number_of_radial_basis_functions,
number_of_interaction_modules,
- number_of_atom_features,
+ number_of_per_atom_features,
shared_filters,
)

# initialize the interaction and mixing networks
self.interaction_modules = nn.ModuleList(
- PaiNNInteraction(number_of_atom_features, activation=F.silu)
+ PaiNNInteraction(number_of_per_atom_features, activation=F.silu)
for _ in range(number_of_interaction_modules)
)
self.mixing_modules = nn.ModuleList(
- PaiNNMixing(number_of_atom_features, activation=F.silu, epsilon=epsilon)
+ PaiNNMixing(number_of_per_atom_features, activation=F.silu, epsilon=epsilon)
for _ in range(number_of_interaction_modules)
)

self.energy_layer = nn.Sequential(
Dense(
- number_of_atom_features, number_of_atom_features, activation=nn.ReLU()
+ number_of_per_atom_features,
+ number_of_per_atom_features,
+ activation=nn.ReLU(),
),
Dense(
- number_of_atom_features,
+ number_of_per_atom_features,
1,
),
)
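A hedged sketch of the featurization-driven construction introduced above. Only the "number_of_per_atom_features" key is visible in this diff; the other dictionary keys and values below are assumptions, and the remaining argument values simply reuse the defaults removed from the old signature:

from openff.units import unit
from modelforge.potential.painn import PaiNNCore

featurization_config = {
    "properties_to_featurize": ["atomic_number"],  # assumed key and value
    "maximum_atomic_number": 100,                  # assumed key (stands in for the removed max_Z default)
    "number_of_per_atom_features": 64,             # the only key read explicitly in the hunk above
}
core = PaiNNCore(
    featurization_config=featurization_config,
    number_of_radial_basis_functions=16,
    cutoff=5 * unit.angstrom,
    number_of_interaction_modules=2,
    shared_interactions=False,
    shared_filters=False,
)
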
Expand All @@ -148,10 +151,8 @@ def _model_specific_input_preparation(
atomic_numbers=data.atomic_numbers,
atomic_subsystem_indices=data.atomic_subsystem_indices,
total_charge=data.total_charge,
- atomic_embedding=self.embedding_module(
-     data.atomic_numbers
- ),  # atom embedding
- )
+ atomic_embedding=self.featurize_input(data),
+ )  # add per-atom properties and embedding

return nnp_input

@@ -488,15 +489,14 @@ def forward(self, q: torch.Tensor, mu: torch.Tensor):
return q, mu


- from .models import InputPreparation, NNPInput, BaseNetwork
+ from .models import ComputeInteractingAtomPairs, NNPInput, BaseNetwork
from typing import List


class PaiNN(BaseNetwork):
def __init__(
self,
- max_Z: int,
- number_of_atom_features: int,
+ featurization: Dict[str, Union[List[str], int]],
number_of_radial_basis_functions: int,
cutoff: Union[unit.Quantity, str],
number_of_interaction_modules: int,
@@ -518,8 +518,7 @@ def __init__(
)

self.core_module = PaiNNCore(
- max_Z=max_Z,
- number_of_atom_features=number_of_atom_features,
+ featurization_config=featurization,
number_of_radial_basis_functions=number_of_radial_basis_functions,
cutoff=_convert(cutoff),
number_of_interaction_modules=number_of_interaction_modules,
@@ -538,7 +537,7 @@ def _config_prior(self):
from modelforge.potential.utils import shared_config_prior

prior = {
"number_of_atom_features": tune.randint(2, 256),
"number_of_per_atom_features": tune.randint(2, 256),
"number_of_interaction_modules": tune.randint(1, 5),
"cutoff": tune.uniform(5, 10),
"number_of_radial_basis_functions": tune.randint(8, 32),