Skip to content

Commit

Permalink
override Lobsterin.__contains__ to fix on py312
Browse files Browse the repository at this point in the history
  • Loading branch information
JaGeo authored and janosh committed May 1, 2024
1 parent da1f7c5 commit 7779aeb
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 32 deletions.
38 changes: 13 additions & 25 deletions pymatgen/io/lobster/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,29 +51,7 @@
)


class DotDict(UserDict):
# copied for python 3.12 compat from
# https://github.com/python/cpython/issues/105524#issuecomment-1610750842
def __getitem__(self, key):
subitem = self.data
for subkey in key.split("."):
try:
subitem = subitem[subkey]
except KeyError:
raise KeyError(f"`{key}` not found in configuration") from None
return subitem

def __contains__(self, key):
subitem = self.data
for subkey in key.split("."):
try:
subitem = subitem[subkey]
except KeyError:
return False
return True


class Lobsterin(DotDict, MSONable):
class Lobsterin(UserDict, MSONable):
"""
This class can handle and generate lobsterin files
Furthermore, it can also modify INCAR files for lobster, generate KPOINT files for fatband calculations in Lobster,
Expand Down Expand Up @@ -154,7 +132,7 @@ def __init__(self, settingsdict: dict):
raise KeyError("There are duplicates for the keywords!")
self.update(settingsdict)

def __setitem__(self, key, val):
def __setitem__(self, key, val) -> None:
"""
Add parameter-val pair to Lobsterin. Warns if parameter is not in list of
valid lobsterin tags. Also cleans the parameter and val by stripping
Expand All @@ -168,7 +146,7 @@ def __setitem__(self, key, val):

super().__setitem__(new_key, val.strip() if isinstance(val, str) else val)

def __getitem__(self, key):
def __getitem__(self, key) -> Any:
"""Implements getitem from dict to avoid problems with cases."""
normalized_key = next((k for k in self if key.strip().lower() == k.lower()), key)

Expand All @@ -178,6 +156,16 @@ def __getitem__(self, key):

return self.data[normalized_key]

def __contains__(self, key) -> bool:
"""Implements getitem from dict to avoid problems with cases."""
normalized_key = next((k for k in self if key.strip().lower() == k.lower()), key)

key_is_unknown = normalized_key.lower() not in map(str.lower, Lobsterin.AVAILABLE_KEYWORDS)
if key_is_unknown or normalized_key not in self.data:
return False

return True

def __delitem__(self, key):
new_key = next((key_here for key_here in self if key.strip().lower() == key_here.lower()), key)

Expand Down
18 changes: 11 additions & 7 deletions tests/io/lobster/test_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1614,7 +1614,7 @@ def test_initialize_from_dict(self):
lobsterin2 = Lobsterin({"cohpstartenergy": -15.0, "cohpstartEnergy": -20.0})
lobsterin2 = Lobsterin({"cohpstartenergy": -15.0})
# can only calculate nbands if basis functions are provided
with pytest.raises(ValueError, match="No basis functions are provided. The program cannot calculate nbands"):
with pytest.raises(ValueError, match="No basis functions are provided. The program cannot calculate nbands."):
lobsterin2._get_nbands(structure=Structure.from_file(f"{VASP_IN_DIR}/POSCAR_Fe3O4"))

def test_standard_settings(self):
Expand Down Expand Up @@ -2421,13 +2421,17 @@ def test_write_file(self):
filename=f"{TEST_DIR}/LCAOWaveFunctionAfterLSO1PlotOfSpin1Kpoint1band1.gz",
structure=Structure.from_file(f"{TEST_DIR}/POSCAR_O.gz"),
)
wave1.write_file(filename=f"{self.tmp_path}/wavecar_test.vasp", part="real")
assert os.path.isfile("wavecar_test.vasp")
real_wavecar_path = f"{self.tmp_path}/real-wavecar.vasp"
wave1.write_file(filename=real_wavecar_path, part="real")
assert os.path.isfile(real_wavecar_path)

wave1.write_file(filename=f"{self.tmp_path}/wavecar_test.vasp", part="imaginary")
assert os.path.isfile("wavecar_test.vasp")
wave1.write_file(filename=f"{self.tmp_path}/density.vasp", part="density")
assert os.path.isfile("density.vasp")
imag_wavecar_path = f"{self.tmp_path}/imaginary-wavecar.vasp"
wave1.write_file(filename=imag_wavecar_path, part="imaginary")
assert os.path.isfile(imag_wavecar_path)

density_wavecar_path = f"{self.tmp_path}/density-wavecar.vasp"
wave1.write_file(filename=density_wavecar_path, part="density")
assert os.path.isfile(density_wavecar_path)


class TestSitePotentials(PymatgenTest):
Expand Down

0 comments on commit 7779aeb

Please sign in to comment.