Skip to content

Commit

Permalink
Fix merge conflicts and update XML writeout testing from main branch,…
Browse files Browse the repository at this point in the history
… PR #246
  • Loading branch information
CalCraven committed Jul 20, 2022
2 parents d95fd5b + cc3f8bb commit d8f62e4
Show file tree
Hide file tree
Showing 8 changed files with 285 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ repos:
- id: trailing-whitespace
exclude: 'setup.cfg'
- repo: https://github.com/psf/black
rev: 22.3.0
rev: 22.6.0
hooks:
- id: black
args: [--line-length=80]
Expand Down
6 changes: 6 additions & 0 deletions gmso/core/atom_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,12 @@ def __eq__(self, other):
and self.description == other.description
)

def _etree_attrib(self):
attrib = super()._etree_attrib()
if self.overrides == set():
attrib.pop("overrides")
return attrib

def __repr__(self):
"""Return a formatted representation of the atom type."""
desc = (
Expand Down
120 changes: 120 additions & 0 deletions gmso/core/forcefield.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import itertools
import warnings
from collections import ChainMap
from pathlib import Path
from typing import Iterable

from lxml import etree
Expand Down Expand Up @@ -476,6 +477,125 @@ def __str__(self):
"""Return a string representation of the ForceField."""
return f"<ForceField {self.name}, id: {id(self)}>"

def xml(self, filename, overwrite=False):
"""Get an lxml ElementTree representation of this ForceField
Parameters
----------
filename: Union[str, pathlib.Path], default=None
The filename to write the XML file to
overwrite: bool, default=False
If True, overwrite an existing file if it exists
"""
ff_el = etree.Element(
"ForceField", attrib={"name": self.name, "version": self.version}
)

metadata = etree.SubElement(ff_el, "FFMetaData")
if self.scaling_factors.get("electrostatics14Scale"):
metadata.attrib["electrostatics14Scale"] = str(
self.scaling_factors.get("electrostatics14Scale")
)
if self.scaling_factors.get("nonBonded14Scale"):
metadata.attrib["nonBonded14Scale"] = str(
self.scaling_factors.get("nonBonded14Scale")
)

# ToDo: ParameterUnitsDefintions and DefaultUnits

etree.SubElement(
metadata,
"Units",
attrib={
"energy": "K*kb",
"distance": "nm",
"mass": "amu",
"charge": "coulomb",
},
)

at_groups = self.group_atom_types_by_expression()
for expr, atom_types in at_groups.items():
atypes = etree.SubElement(
ff_el, "AtomTypes", attrib={"expression": expr}
)
params_units_def = None
for atom_type in atom_types:
if params_units_def is None:
params_units_def = {}
for param, value in atom_type.parameters.items():
params_units_def[param] = value.units
etree.SubElement(
atypes,
"ParametersUnitDef",
attrib={
"parameter": param,
"unit": str(value.units),
},
)

atypes.append(atom_type.etree(units=params_units_def))

bond_types_groups = self.group_bond_types_by_expression()
angle_types_groups = self.group_angle_types_by_expression()
dihedral_types_groups = self.group_dihedral_types_by_expression()
improper_types_groups = self.group_improper_types_by_expression()

for tag, potential_group in [
("BondTypes", bond_types_groups),
("AngleTypes", angle_types_groups),
("DihedralTypes", dihedral_types_groups),
("ImproperTypes", improper_types_groups),
]:
for expr, potentials in potential_group.items():
potential_group = etree.SubElement(
ff_el, tag, attrib={"expression": expr}
)
params_units_def = None
for potential in potentials:
if params_units_def is None:
params_units_def = {}
for param, value in potential.parameters.items():
params_units_def[param] = value.units
etree.SubElement(
potential_group,
"ParametersUnitDef",
attrib={
"parameter": param,
"unit": str(value.units),
},
)

potential_group.append(potential.etree(params_units_def))

ff_etree = etree.ElementTree(element=ff_el)

if not isinstance(filename, Path):
filename = Path(filename)

if filename.suffix != ".xml":
from gmso.exceptions import ForceFieldError

raise ForceFieldError(
f"The filename {str(filename)} is not an XML file. "
f"Please provide filename with .xml extension"
)

if not overwrite and filename.exists():
raise FileExistsError(
f"File {filename} already exists. Consider "
f"using overwrite=True if you want to overwrite "
f"the existing file."
)

ff_etree.write(
str(filename),
pretty_print=True,
xml_declaration=True,
encoding="utf-8",
)

@classmethod
def from_xml(cls, xmls_or_etrees, strict=True, greedy=True):
"""Create a gmso.Forcefield object from XML File(s).
Expand Down
64 changes: 64 additions & 0 deletions gmso/core/parametric_potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
from typing import Any, Union

import unyt as u
from lxml import etree
from pydantic import Field, validator

from gmso.abc.abstract_potential import AbstractPotential
from gmso.utils.expression import PotentialExpression
from gmso.utils.misc import get_xml_representation


class ParametricPotential(AbstractPotential):
Expand Down Expand Up @@ -196,6 +199,67 @@ def clone(self, fast_copy=False):
**kwargs,
)

def _etree_attrib(self):
"""Return the XML equivalent representation of this ParametricPotential"""
attrib = {
key: get_xml_representation(value)
for key, value in self.dict(
by_alias=True,
exclude_none=True,
exclude={
"topology_",
"set_ref_",
"member_types_",
"potential_expression_",
"tags_",
},
).items()
if value != ""
}

return attrib

def etree(self, units=None):
"""Return an lxml.ElementTree for the parametric potential adhering to gmso XML schema"""

attrib = self._etree_attrib()

if hasattr(self, "member_types") and hasattr(self, "member_classes"):
if self.member_types:
iterating_attribute = self.member_types
prefix = "type"
elif self.member_classes:
iterating_attribute = self.member_classes
prefix = "class"
else:
raise GMSOError(
f"Cannot convert {self.__class__.__name__} into an XML."
f"Please specify member_classes or member_types attribute."
)
for idx, value in enumerate(iterating_attribute):
attrib[f"{prefix}{idx+1}"] = str(value)

xml_element = etree.Element(self.__class__.__name__, attrib=attrib)
params = etree.SubElement(xml_element, "Parameters")

for key, value in self.parameters.items():
value_unit = None
if units is not None:
value_unit = units[key]

etree.SubElement(
params,
"Parameter",
attrib={
"name": key,
"value": get_xml_representation(
value.in_units(value_unit) if value_unit else value
),
},
)

return xml_element

@classmethod
def from_template(cls, potential_template, parameters, name=None, **kwargs):
"""Create a potential object from the potential_template.
Expand Down
55 changes: 55 additions & 0 deletions gmso/core/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -1158,6 +1158,61 @@ def get_index(self, member):

return index

def _reindex_connection_types(self, ref):
"""Re-generate the indices of the connection types in the topology."""
if ref not in self._index_refs:
raise GMSOError(
f"cannot reindex {ref}. It should be one of "
f"{ANGLE_TYPE_DICT}, {BOND_TYPE_DICT}, "
f"{ANGLE_TYPE_DICT}, {DIHEDRAL_TYPE_DICT}, {IMPROPER_TYPE_DICT},"
f"{PAIRPOTENTIAL_TYPE_DICT}"
)
for i, ref_member in enumerate(self._set_refs[ref].keys()):
self._index_refs[ref][ref_member] = i

def get_forcefield(self):
"""Get an instance of gmso.ForceField out of this topology
Raises
------
GMSOError
If the topology is untyped
"""
if not self.is_typed():
raise GMSOError(
"Cannot create a ForceField from an untyped topology."
)
else:
from gmso import ForceField
from gmso.utils._constants import FF_TOKENS_SEPARATOR

ff = ForceField()
ff.name = self.name + "_ForceField"
ff.scaling_factors = {
"electrostatics14Scale": self.scaling_factors[1,2],
"nonBonded14Scale": self.scaling_factors[0,2],
}
for atom_type in self.atom_types:
ff.atom_types[atom_type.name] = atom_type.copy(
deep=True, exclude={"topology_", "set_ref_"}
)

ff_conn_types = {
BondType: ff.bond_types,
AngleType: ff.angle_types,
DihedralType: ff.dihedral_types,
ImproperType: ff.improper_types,
}

for connection_type in self.connection_types:
ff_conn_types[type(connection_type)][
FF_TOKENS_SEPARATOR.join(connection_type.member_types)
] = connection_type.copy(
deep=True, exclude={"topology_", "set_ref_"}
)

return ff

def iter_sites(self, key, value):
"""Iterate through this topology's sites based on certain attribute and their values.
Expand Down
21 changes: 21 additions & 0 deletions gmso/tests/test_forcefield.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from gmso.core.forcefield import ForceField
from gmso.core.improper_type import ImproperType
from gmso.exceptions import (
ForceFieldError,
ForceFieldParseError,
MissingAtomTypesError,
MissingPotentialError,
Expand Down Expand Up @@ -613,3 +614,23 @@ def test_forcefield_get_impropers_combinations(self):
)
assert imp1.name == imp2.name
assert imp1 is imp2

def test_write_xml(self, opls_ethane_foyer):
opls_ethane_foyer.xml("test_xml_writer.xml")
reloaded_xml = ForceField("test_xml_writer.xml")
get_names = lambda ff, param: [
typed for typed in getattr(ff, param).keys()
]
for param in [
"atom_types",
"bond_types",
"angle_types",
"dihedral_types",
]:
assert get_names(opls_ethane_foyer, param) == get_names(
reloaded_xml, param
)

def test_write_not_xml(self, opls_ethane_foyer):
with pytest.raises(ForceFieldError):
opls_ethane_foyer.xml("bad_path")
8 changes: 8 additions & 0 deletions gmso/tests/test_topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,3 +803,11 @@ def test_iter_sites_by_molecule(self, labeled_top):
for molecule_name in molecule_names:
for site in labeled_top.iter_sites_by_molecule(molecule_name):
assert site.molecule.name == molecule_name

def test_write_forcefield(self, typed_water_system):
forcefield = typed_water_system.get_forcefield()
assert "opls_111" in forcefield.atom_types
assert "opls_112" in forcefield.atom_types
top = Topology()
with pytest.raises(GMSOError):
top.get_forcefield()
10 changes: 10 additions & 0 deletions gmso/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,13 @@ def mask_with(iterable, window_size=1, mask="*"):

idx += 1
yield to_yield


def get_xml_representation(value):
"""Given a value, get its XML representation."""
if isinstance(value, u.unyt_quantity):
return str(value.value)
elif isinstance(value, set):
return ",".join(value)
else:
return str(value)

0 comments on commit d8f62e4

Please sign in to comment.