Skip to content

Commit

Permalink
Support unyt_array in __eq__/clone for PotentialExpression (#660
Browse files Browse the repository at this point in the history
)

* Support unyt_array in `__eq__`/`clone` for `PotentialExpression`

* WIP-Fix docstring; add more tests

* Additional tests; more documentation

* WIP- Move method to test dict equality inmodule
  • Loading branch information
umesh-timalsina authored May 25, 2022
1 parent 1b8eda1 commit a6255a4
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 3 deletions.
50 changes: 49 additions & 1 deletion gmso/tests/test_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import unyt as u

from gmso.tests.base_test import BaseTest
from gmso.utils.expression import PotentialExpression
from gmso.utils.expression import PotentialExpression, _are_equal_parameters


class TestExpression(BaseTest):
Expand Down Expand Up @@ -232,3 +232,51 @@ def test_clone(self):
)

assert expr == expr_clone

def test_clone_with_unyt_arrays(self):
expression = PotentialExpression(
expression="x**2 + y**2 + 2*x*y*theta",
independent_variables="theta",
parameters={
"x": [2.0, 4.5] * u.nm,
"y": [3.4, 4.5] * u.kcal / u.mol,
},
)

expression_clone = expression.clone()
assert expression_clone == expression

def test_expression_equality_different_params(self):
expr1 = PotentialExpression(
independent_variables="r",
parameters={"a": 2.0 * u.nm, "b": 3.0 * u.nm},
expression="a+r*b",
)

expr2 = PotentialExpression(
independent_variables="r",
parameters={"c": 2.0 * u.nm, "d": 3.0 * u.nm},
expression="c+r*d",
)

assert expr1 != expr2

def test_expression_equality_same_params_different_values(self):
expr1 = PotentialExpression(
independent_variables="r",
parameters={"a": 2.0 * u.nm, "b": 3.0 * u.nm},
expression="a+r*b",
)

expr2 = PotentialExpression(
independent_variables="r",
parameters={"a": 2.0 * u.nm, "b": 3.5 * u.nm},
expression="a+r*b",
)

assert expr1 != expr2

def test_are_equal_parameters(self):
u1 = {"a": 2.0 * u.nm, "b": 3.5 * u.nm}
u2 = {"c": 2.0 * u.nm, "d": 3.5 * u.nm}
assert _are_equal_parameters(u1, u2) is False
23 changes: 21 additions & 2 deletions gmso/utils/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,28 @@
import unyt as u

from gmso.utils.decorators import register_pydantic_json
from gmso.utils.misc import unyt_to_hashable

__all__ = ["PotentialExpression"]


def _are_equal_parameters(u1, u2):
"""Compare two parameters of unyt quantities/arrays.
This method compares two dictionaries (`u1` and `u2`) of
`unyt_quantities` and returns True if:
* u1 and u2 have the exact same key set
* for each key, the value in u1 and u2 have the same unyt quantity
"""
if u1.keys() != u2.keys():
return False
else:
for k, v in u1.items():
if not u.allclose_units(v, u2[k]):
return False

return True


@register_pydantic_json(method="json")
class PotentialExpression:
"""A general Expression class with parameters.
Expand Down Expand Up @@ -258,7 +275,7 @@ def __eq__(self, other):
return (
self.expression == other.expression
and self.independent_variables == other.independent_variables
and self.parameters == other.parameters
and _are_equal_parameters(self.parameters, other.parameters)
)

@staticmethod
Expand Down Expand Up @@ -295,6 +312,8 @@ def clone(self):
deepcopy(self._independent_variables),
{
k: u.unyt_quantity(v.value, v.units)
if v.value.shape == ()
else u.unyt_array(v.value, v.units)
for k, v in self._parameters.items()
}
if self._is_parametric
Expand Down

0 comments on commit a6255a4

Please sign in to comment.