Skip to content

JDFTx Inputs - boundary value checking #4410

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 111 additions & 7 deletions src/pymatgen/io/jdftx/generic_tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,13 @@ def _validate_repeat(self, tag: str, value: Any) -> None:
if not isinstance(value, list):
raise TypeError(f"The '{tag}' tag can repeat but is not a list: '{value}'")

def validate_value_bounds(
self,
tag: str,
value: Any,
) -> tuple[bool, str]:
return True, ""

@abstractmethod
def read(self, tag: str, value_str: str) -> Any:
"""Read and parse the value string for this tag.
Expand Down Expand Up @@ -365,7 +372,62 @@ def get_token_len(self) -> int:


@dataclass
class IntTag(AbstractTag):
class AbstractNumericTag(AbstractTag):
"""Abstract base class for numeric tags."""

lb: float | None = None # lower bound
ub: float | None = None # upper bound
lb_incl: bool = True # lower bound inclusive
ub_incl: bool = True # upper bound inclusive

def val_is_within_bounds(self, value: float) -> bool:
"""Check if the value is within the bounds.

Args:
value (float | int): The value to check.

Returns:
bool: True if the value is within the bounds, False otherwise.
"""
good = True
if self.lb is not None:
good = good and value >= self.lb if self.lb_incl else good and value > self.lb
if self.ub is not None:
good = good and value <= self.ub if self.ub_incl else good and value < self.ub
return good

def get_invalid_value_error_str(self, tag: str, value: float) -> str:
"""Raise a ValueError for the invalid value.

Args:
tag (str): The tag to raise the ValueError for.
value (float | int): The value to raise the ValueError for.
"""
err_str = f"Value '{value}' for tag '{tag}' is not within bounds"
if self.ub is not None:
err_str += f" {self.ub} >"
if self.ub_incl:
err_str += "="
err_str += " x "
if self.lb is not None:
err_str += ">"
if self.lb_incl:
err_str += "="
err_str += f" {self.lb}"
return err_str

def validate_value_bounds(
self,
tag: str,
value: Any,
) -> tuple[bool, str]:
if not self.val_is_within_bounds(value):
return False, self.get_invalid_value_error_str(tag, value)
return True, ""


@dataclass
class IntTag(AbstractNumericTag):
"""Tag for integer values in JDFTx input files.

Tag for integer values in JDFTx input files.
Expand Down Expand Up @@ -411,6 +473,8 @@ def write(self, tag: str, value: Any) -> str:
Returns:
str: The tag and its value as a string.
"""
if not self.val_is_within_bounds(value):
return ""
return self._write(tag, value)

def get_token_len(self) -> int:
Expand All @@ -423,14 +487,13 @@ def get_token_len(self) -> int:


@dataclass
class FloatTag(AbstractTag):
class FloatTag(AbstractNumericTag):
"""Tag for float values in JDFTx input files.

Tag for float values in JDFTx input files.
"""

prec: int | None = None
minval: float | None = None

def validate_value_type(self, tag: str, value: Any, try_auto_type_fix: bool = False) -> tuple[str, bool, Any]:
"""Validate the type of the value for this tag.
Expand Down Expand Up @@ -473,10 +536,7 @@ def write(self, tag: str, value: Any) -> str:
Returns:
str: The tag and its value as a string.
"""
# Returning an empty string instead of raising an error as value == self.minval
# will cause JDFTx to throw an error, but the internal infile dumps the value as
# as the minval if not set by the user.
if (self.minval is not None) and (not value > self.minval):
if not self.val_is_within_bounds(value):
return ""
# pre-convert to string: self.prec+3 is minimum room for:
# - sign, 1 integer left of decimal, decimal, and precision.
Expand Down Expand Up @@ -598,6 +658,50 @@ def _validate_single_entry(
types_checks.append(check)
return tags_checked, types_checks, updated_value

def _validate_bounds_single_entry(self, value: dict | list[dict]) -> tuple[list[str], list[bool], list[str]]:
if not isinstance(value, dict):
raise TypeError(f"The value '{value}' (of type {type(value)}) must be a dict for this TagContainer!")
tags_checked: list[str] = []
types_checks: list[bool] = []
reported_errors: list[str] = []
for subtag, subtag_value in value.items():
subtag_object = self.subtags[subtag]
check, err_str = subtag_object.validate_value_bounds(subtag, subtag_value)
tags_checked.append(subtag)
types_checks.append(check)
reported_errors.append(err_str)
return tags_checked, types_checks, reported_errors

def validate_value_bounds(self, tag: str, value: Any) -> tuple[bool, str]:
value_dict = value
if self.can_repeat:
self._validate_repeat(tag, value_dict)
results = [self._validate_bounds_single_entry(x) for x in value_dict]
tags_list_list: list[list[str]] = [result[0] for result in results]
is_valids_list_list: list[list[bool]] = [result[1] for result in results]
reported_errors_list: list[list[str]] = [result[2] for result in results]
is_valid_out = all(all(x) for x in is_valids_list_list)
errors_out = ",".join([",".join(x) for x in reported_errors_list])
if not is_valid_out:
warnmsg = "Invalid value(s) found for: "
for i, x in enumerate(is_valids_list_list):
if not all(x):
for j, y in enumerate(x):
if not y:
warnmsg += f"{tags_list_list[i][j]} ({reported_errors_list[i][j]}) "
warnings.warn(warnmsg, stacklevel=2)
else:
tags, is_valids, reported_errors = self._validate_bounds_single_entry(value_dict)
is_valid_out = all(is_valids)
errors_out = ",".join(reported_errors)
if not is_valid_out:
warnmsg = "Invalid value(s) found for: "
for ii, xx in enumerate(is_valids):
if not xx:
warnmsg += f"{tags[ii]} ({reported_errors[ii]}) "
warnings.warn(warnmsg, stacklevel=2)
return is_valid_out, f"{tag}: {errors_out}"

def validate_value_type(self, tag: str, value: Any, try_auto_type_fix: bool = False) -> tuple[str, bool, Any]:
"""Validate the type of the value for this tag.

Expand Down
36 changes: 32 additions & 4 deletions src/pymatgen/io/jdftx/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class JDFTXInfile(dict, MSONable):
Essentially a dictionary with some helper functions.
"""

path_parent: str | None = None # Only gets a value if JDFTXInfile is initializedf with from_file
path_parent: str | None = None # Only gets a value if JDFTXInfile is initialized with from_file

def __init__(self, params: dict[str, Any] | None = None) -> None:
"""
Expand Down Expand Up @@ -147,7 +147,7 @@ def _from_dict(cls, dct: dict[str, Any]) -> JDFTXInfile:
return cls.get_list_representation(temp)

@classmethod
def from_dict(cls, d: dict[str, Any]) -> JDFTXInfile:
def from_dict(cls, d: dict[str, Any], validate_value_boundaries=True) -> JDFTXInfile:
"""Create JDFTXInfile from a dictionary.

Args:
Expand All @@ -160,6 +160,8 @@ def from_dict(cls, d: dict[str, Any]) -> JDFTXInfile:
for k, v in d.items():
if k not in ("@module", "@class"):
instance[k] = v
if validate_value_boundaries:
instance.validate_boundaries()
return instance

def copy(self) -> JDFTXInfile:
Expand Down Expand Up @@ -213,6 +215,7 @@ def from_file(
dont_require_structure: bool = False,
sort_tags: bool = True,
assign_path_parent: bool = True,
validate_value_boundaries: bool = True,
) -> JDFTXInfile:
"""Read a JDFTXInfile object from a file.

Expand All @@ -235,6 +238,7 @@ def from_file(
dont_require_structure=dont_require_structure,
sort_tags=sort_tags,
path_parent=path_parent,
validate_value_boundaries=validate_value_boundaries,
)

@staticmethod
Expand Down Expand Up @@ -373,6 +377,7 @@ def from_str(
dont_require_structure: bool = False,
sort_tags: bool = True,
path_parent: Path | None = None,
validate_value_boundaries: bool = True,
) -> JDFTXInfile:
"""Read a JDFTXInfile object from a string.

Expand All @@ -382,6 +387,7 @@ def from_str(
sort_tags (bool, optional): Whether to sort the tags. Defaults to True.
path_parent (Path, optional): Path to the parent directory of the input file for include tags.
Defaults to None.
validate_value_boundaries (bool, optional): Whether to validate the value boundaries. Defaults to True.

Returns:
JDFTXInfile: The created JDFTXInfile object.
Expand Down Expand Up @@ -416,7 +422,10 @@ def from_str(
raise ValueError("This input file is missing required structure tags")
if sort_tags:
params = {tag: params[tag] for tag in __TAG_LIST__ if tag in params}
return cls(params)
instance = cls(params)
if validate_value_boundaries:
instance.validate_boundaries()
return instance

@classmethod
def to_jdftxstructure(cls, jdftxinfile: JDFTXInfile, sort_structure: bool = False) -> JDFTXStructure:
Expand Down Expand Up @@ -573,6 +582,25 @@ def validate_tags(
warnmsg += "(Check earlier warnings for more details)\n"
warnings.warn(warnmsg, stacklevel=2)

def validate_boundaries(self) -> None:
"""Validate the boundaries of the JDFTXInfile.

Validate the boundaries of the JDFTXInfile. This is a placeholder for future functionality.
"""
error_strs: list[str] = []
for tag in self:
tag_object = get_tag_object(tag)
is_valid, error_str = tag_object.validate_value_bounds(tag, self[tag])
if not is_valid:
error_strs.append(error_str)
if len(error_strs) > 0:
err_cat = "\n".join(error_strs)
raise ValueError(
f"The following boundary errors were found in the JDFTXInfile:\n{err_cat}\n"
"\n Hint - if you are reading from a JDFTX out file, you need to set validate_value_boundaries "
"to False, as JDFTx will dump values at non-inclusive boundaries (ie 0.0 for values strictly > 0.0)."
)

def strip_structure_tags(self) -> None:
"""Strip all structural tags from the JDFTXInfile.

Expand Down Expand Up @@ -614,7 +642,7 @@ def __setitem__(self, key: str, value: Any) -> None:
if self._is_numeric(value):
value = str(value)
if not tag_object.can_repeat:
value = [value]
value = [value] # Shortcut to avoid writing a separate block for non-repeatable tags
for v in value:
processed_value = tag_object.read(key, v) if isinstance(v, str) else v
params = self._store_value(params, tag_object, key, processed_value)
Expand Down
16 changes: 8 additions & 8 deletions src/pymatgen/io/jdftx/jdftxinfile_ref_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -800,11 +800,11 @@
"wolfeGradient": FloatTag(),
}
jdftxfluid_subtagdict = {
"epsBulk": FloatTag(minval=1.0),
"epsInf": FloatTag(),
"epsBulk": FloatTag(lb=0.0, lb_incl=False),
"epsInf": FloatTag(lb=1.0, lb_incl=True),
"epsLJ": FloatTag(),
"Nnorm": FloatTag(),
"pMol": FloatTag(),
"pMol": FloatTag(lb=0.0, lb_incl=True),
"poleEl": TagContainer(
can_repeat=True,
write_tagname=True,
Expand All @@ -814,13 +814,13 @@
"A0": FloatTag(write_tagname=False, optional=False),
},
),
"Pvap": FloatTag(minval=0.0),
"Pvap": FloatTag(lb=0.0, lb_incl=False),
"quad_nAlpha": FloatTag(),
"quad_nBeta": FloatTag(),
"quad_nGamma": FloatTag(),
"representation": TagContainer(subtags={"MuEps": FloatTag(), "Pomega": FloatTag(), "PsiAlpha": FloatTag()}),
"Res": FloatTag(minval=0.0),
"Rvdw": FloatTag(),
"Res": FloatTag(lb=0.0, lb_incl=False),
"Rvdw": FloatTag(lb=0.0, lb_incl=False),
"s2quadType": StrTag(
options=[
"10design60",
Expand All @@ -844,7 +844,7 @@
"Tetrahedron",
]
),
"sigmaBulk": FloatTag(minval=0.0),
"tauNuc": FloatTag(),
"sigmaBulk": FloatTag(lb=0.0, lb_incl=False),
"tauNuc": FloatTag(lb=0.0, lb_incl=False),
"translation": StrTag(options=["ConstantSpline", "Fourier", "LinearSpline"]),
}
4 changes: 3 additions & 1 deletion src/pymatgen/io/jdftx/jdftxoutfileslice.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,9 @@ def _set_internal_infile(self, text: list[str]) -> None:
break
if end_line_idx is None:
raise ValueError("Calculation did not begin for this out file slice.")
self.infile = JDFTXInfile.from_str("\n".join(text[start_line_idx:end_line_idx]))
self.infile = JDFTXInfile.from_str(
"\n".join(text[start_line_idx:end_line_idx]), validate_value_boundaries=False
)
self.constant_lattice = True
if "lattice-minimize" in self.infile:
latsteps = self.infile["lattice-minimize"]["nIterations"]
Expand Down
44 changes: 44 additions & 0 deletions tests/io/jdftx/test_generic_tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,3 +441,47 @@ def test_multiformattagcontainer():
"Check your inputs and/or MASTER_TAG_LIST!"
with pytest.raises(ValueError, match=re.escape(err_str)):
mftg._determine_format_option(tag, value)


def test_boundary_checking():
# Check that non-numeric tag returns valid always
tag = "barbie"
value = "notanumber"
valtag = StrTag()
assert valtag.validate_value_bounds(tag, value)[0] is True
# Check that numeric tags can return False
value = 0.0
valtag = FloatTag(lb=1.0)
assert valtag.validate_value_bounds(tag, value)[0] is False
valtag = FloatTag(lb=0.0, lb_incl=False)
assert valtag.validate_value_bounds(tag, value)[0] is False
valtag = FloatTag(ub=-1.0)
assert valtag.validate_value_bounds(tag, value)[0] is False
valtag = FloatTag(ub=0.0, ub_incl=False)
assert valtag.validate_value_bounds(tag, value)[0] is False
# Check that numeric tags can return True
valtag = FloatTag(lb=0.0, lb_incl=True)
assert valtag.validate_value_bounds(tag, value)[0] is True
valtag = FloatTag(ub=0.0, ub_incl=True)
assert valtag.validate_value_bounds(tag, value)[0] is True
valtag = FloatTag(lb=-1.0, ub=1.0)
assert valtag.validate_value_bounds(tag, value)[0] is True
# Check functionality for tagcontainers
tagcontainer = TagContainer(
subtags={
"ken": FloatTag(lb=0.0, lb_incl=True),
"allan": StrTag(),
"skipper": FloatTag(ub=1.0, ub_incl=True, lb=-1.0, lb_incl=False),
},
)
valid, errors = tagcontainer.validate_value_bounds(tag, {"ken": -1.0, "allan": "notanumber", "skipper": 2.0})
assert valid is False
assert "allan" not in errors
assert "ken" in errors
assert "x >= 0.0" in errors
assert "skipper" in errors
assert "1.0 >= x > -1.0" in errors
# Make sure tags will never write a value that is out of bounds
valtag = FloatTag(lb=-1.0, ub=1.0)
assert len(valtag.write(tag, 0.0))
assert not len(valtag.write(tag, 2.0))