Skip to content

Commit

Permalink
drop Lobsterin inerheritance from UserDict, use simple dict instead a…
Browse files Browse the repository at this point in the history
…nd modify __getitem__ to get the salient __getitem__ behavior from UserDict
  • Loading branch information
janosh committed Apr 30, 2024
1 parent 75202b7 commit bc5323a
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 126 deletions.
12 changes: 6 additions & 6 deletions pymatgen/alchemy/transmuters.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,8 @@ def from_structures(cls, structures, transformations=None, extend_collection=0)
Returns:
StandardTransmuter
"""
trafo_struct = [TransformedStructure(s, []) for s in structures]
return cls(trafo_struct, transformations, extend_collection)
t_struct = [TransformedStructure(s, []) for s in structures]
return cls(t_struct, transformations, extend_collection)


class CifTransmuter(StandardTransmuter):
Expand Down Expand Up @@ -253,8 +253,8 @@ def __init__(self, cif_string, transformations=None, primitive=True, extend_coll
if read_data:
structure_data[-1].append(line)
for data in structure_data:
trafo_struct = TransformedStructure.from_cif_str("\n".join(data), [], primitive)
transformed_structures.append(trafo_struct)
t_struct = TransformedStructure.from_cif_str("\n".join(data), [], primitive)
transformed_structures.append(t_struct)
super().__init__(transformed_structures, transformations, extend_collection)

@classmethod
Expand Down Expand Up @@ -293,8 +293,8 @@ def __init__(self, poscar_string, transformations=None, extend_collection=False)
extend_collection: Whether to use more than one output structure
from one-to-many transformations.
"""
trafo_struct = TransformedStructure.from_poscar_str(poscar_string, [])
super().__init__([trafo_struct], transformations, extend_collection=extend_collection)
t_struct = TransformedStructure.from_poscar_str(poscar_string, [])
super().__init__([t_struct], transformations, extend_collection=extend_collection)

@classmethod
def from_filenames(cls, poscar_filenames, transformations=None, extend_collection=False) -> StandardTransmuter:
Expand Down
16 changes: 8 additions & 8 deletions pymatgen/analysis/structure_prediction/substitutor.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,10 @@ def pred_from_structures(
raise ValueError("the species in target_species are not allowed for the probability model you are using")

for permutation in itertools.permutations(target_species):
for s in structures_list:
for dct in structures_list:
# check if: species are in the domain,
# and the probability of subst. is above the threshold
els = s["structure"].elements
els = dct["structure"].elements
if (
len(els) == len(permutation)
and len(set(els) & set(self.get_allowed_species())) == len(els)
Expand All @@ -136,18 +136,18 @@ def pred_from_structures(

transf = SubstitutionTransformation(clean_subst)

if Substitutor._is_charge_balanced(transf.apply_transformation(s["structure"])):
ts = TransformedStructure(
s["structure"],
if Substitutor._is_charge_balanced(transf.apply_transformation(dct["structure"])):
t_struct = TransformedStructure(
dct["structure"],
[transf],
history=[{"source": s["id"]}],
history=[{"source": dct["id"]}],
other_parameters={
"type": "structure_prediction",
"proba": self._sp.cond_prob_list(permutation, els),
},
)
result.append(ts)
transmuter.append_transformed_structures([ts])
result.append(t_struct)
transmuter.append_transformed_structures([t_struct])

if remove_duplicates:
transmuter.apply_filter(RemoveDuplicatesFilter(symprec=self._symprec))
Expand Down
58 changes: 30 additions & 28 deletions pymatgen/io/lobster/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,14 +148,16 @@ def __setitem__(self, key, val):

super().__setitem__(new_key, val.strip() if isinstance(val, str) else val)

def __getitem__(self, item):
def __getitem__(self, key):
"""Implements getitem from dict to avoid problems with cases."""
new_item = next((key_here for key_here in self if item.strip().lower() == key_here.lower()), item)
normalized_key = next((k for k in self if key.strip().lower() == k.lower()), key)

if new_item.lower() not in [element.lower() for element in Lobsterin.AVAILABLE_KEYWORDS]:
raise KeyError("Key is currently not available")
if normalized_key.lower() not in [element.lower() for element in Lobsterin.AVAILABLE_KEYWORDS]:
raise KeyError(f"{key=} is currently not available")

return super().__getitem__(new_item)
if normalized_key in self.data:
return self.data[normalized_key]
raise KeyError(normalized_key)

def __delitem__(self, key):
new_key = next((key_here for key_here in self if key.strip().lower() == key_here.lower()), key)
Expand Down Expand Up @@ -566,30 +568,30 @@ def from_file(cls, lobsterin: str) -> Self:
lobsterin_dict: dict[str, Any] = {}

for datum in data:
# Remove all comments
if not datum.startswith(("!", "#", "//")):
pattern = r"\b[^!#//]+" # exclude comments after commands
if matched_pattern := re.findall(pattern, datum):
raw_datum = matched_pattern[0].replace("\t", " ") # handle tab in between and end of command
key_word = raw_datum.strip().split(" ") # extract keyword
if len(key_word) > 1:
# check which type of keyword this is, handle accordingly
if key_word[0].lower() not in [datum2.lower() for datum2 in Lobsterin.LISTKEYWORDS]:
if key_word[0].lower() not in [datum2.lower() for datum2 in Lobsterin.FLOAT_KEYWORDS]:
if key_word[0].lower() not in lobsterin_dict:
lobsterin_dict[key_word[0].lower()] = " ".join(key_word[1:])
else:
raise ValueError(f"Same keyword {key_word[0].lower()} twice!")
elif key_word[0].lower() not in lobsterin_dict:
lobsterin_dict[key_word[0].lower()] = float(key_word[1])
else:
raise ValueError(f"Same keyword {key_word[0].lower()} twice!")
elif key_word[0].lower() not in lobsterin_dict:
lobsterin_dict[key_word[0].lower()] = [" ".join(key_word[1:])]
if datum.startswith(("!", "#", "//")):
continue # ignore comments
pattern = r"\b[^!#//]+" # exclude comments after commands
if matched_pattern := re.findall(pattern, datum):
raw_datum = matched_pattern[0].replace("\t", " ") # handle tab in between and end of command
key_word = raw_datum.strip().split(" ") # extract keyword
key = key_word[0].lower()
if len(key_word) > 1:
# check which type of keyword this is, handle accordingly
if key not in [datum2.lower() for datum2 in Lobsterin.LISTKEYWORDS]:
if key not in [datum2.lower() for datum2 in Lobsterin.FLOAT_KEYWORDS]:
if key in lobsterin_dict:
raise ValueError(f"Same keyword {key} twice!")
lobsterin_dict[key] = " ".join(key_word[1:])
elif key in lobsterin_dict:
raise ValueError(f"Same keyword {key} twice!")
else:
lobsterin_dict[key_word[0].lower()].append(" ".join(key_word[1:]))
elif len(key_word) > 0:
lobsterin_dict[key_word[0].lower()] = True
lobsterin_dict[key] = float("nan" if key_word[1].strip() == "None" else key_word[1])
elif key not in lobsterin_dict:
lobsterin_dict[key] = [" ".join(key_word[1:])]
else:
lobsterin_dict[key].append(" ".join(key_word[1:]))
elif len(key_word) > 0:
lobsterin_dict[key] = True

return cls(lobsterin_dict)

Expand Down
62 changes: 31 additions & 31 deletions tests/alchemy/test_materials.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ def test_append_transformation(self):
[0, -2.2171384943, 3.1355090603],
]
struct = Structure(lattice, ["Si4+", "Si4+"], coords)
ts = TransformedStructure(struct, [])
ts.append_transformation(SupercellTransformation.from_scaling_factors(2, 1, 1))
alt = ts.append_transformation(
t_struct = TransformedStructure(struct, [])
t_struct.append_transformation(SupercellTransformation.from_scaling_factors(2, 1, 1))
alt = t_struct.append_transformation(
PartialRemoveSpecieTransformation("Si4+", 0.5, algo=PartialRemoveSpecieTransformation.ALGO_COMPLETE), 5
)
assert len(alt) == 2
Expand All @@ -66,36 +66,36 @@ def test_from_dict(self):
with open(f"{TEST_DIR}/transformations.json") as file:
dct = json.load(file)
dct["other_parameters"] = {"tags": ["test"]}
ts = TransformedStructure.from_dict(dct)
ts.other_parameters["author"] = "Will"
ts.append_transformation(SubstitutionTransformation({"Fe": "Mn"}))
assert ts.final_structure.reduced_formula == "MnPO4"
assert ts.other_parameters == {"author": "Will", "tags": ["test"]}
t_struct = TransformedStructure.from_dict(dct)
t_struct.other_parameters["author"] = "Will"
t_struct.append_transformation(SubstitutionTransformation({"Fe": "Mn"}))
assert t_struct.final_structure.reduced_formula == "MnPO4"
assert t_struct.other_parameters == {"author": "Will", "tags": ["test"]}

def test_undo_and_redo_last_change(self):
trafos = [
SubstitutionTransformation({"Li": "Na"}),
SubstitutionTransformation({"Fe": "Mn"}),
]
ts = TransformedStructure(self.structure, trafos)
assert ts.final_structure.reduced_formula == "NaMnPO4"
ts.undo_last_change()
assert ts.final_structure.reduced_formula == "NaFePO4"
ts.undo_last_change()
assert ts.final_structure.reduced_formula == "LiFePO4"
t_struct = TransformedStructure(self.structure, trafos)
assert t_struct.final_structure.reduced_formula == "NaMnPO4"
t_struct.undo_last_change()
assert t_struct.final_structure.reduced_formula == "NaFePO4"
t_struct.undo_last_change()
assert t_struct.final_structure.reduced_formula == "LiFePO4"
with pytest.raises(IndexError, match="No more changes to undo"):
ts.undo_last_change()
ts.redo_next_change()
assert ts.final_structure.reduced_formula == "NaFePO4"
ts.redo_next_change()
assert ts.final_structure.reduced_formula == "NaMnPO4"
t_struct.undo_last_change()
t_struct.redo_next_change()
assert t_struct.final_structure.reduced_formula == "NaFePO4"
t_struct.redo_next_change()
assert t_struct.final_structure.reduced_formula == "NaMnPO4"
with pytest.raises(IndexError, match="No more changes to redo"):
ts.redo_next_change()
t_struct.redo_next_change()
# Make sure that this works with filters.
f3 = ContainsSpecieFilter(["O2-"], strict_compare=True, AND=False)
ts.append_filter(f3)
ts.undo_last_change()
ts.redo_next_change()
t_struct.append_filter(f3)
t_struct.undo_last_change()
t_struct.redo_next_change()

def test_set_parameter(self):
trans = self.trans.set_parameter("author", "will")
Expand All @@ -113,19 +113,19 @@ def test_as_dict(self):
def test_snl(self):
self.trans.set_parameter("author", "will")
with pytest.warns(UserWarning) as warns:
snl = self.trans.to_snl([("will", "will@test.com")])
struct_nl = self.trans.to_snl([("will", "will@test.com")])

assert len(warns) == 1, "Warning not raised on type conversion with other_parameters"
assert len(warns) >= 1, f"Warning not raised on type conversion with other_parameters {len(warns)=}"
assert (
str(warns[0].message)
== "Data in TransformedStructure.other_parameters discarded during type conversion to SNL"
)

ts = TransformedStructure.from_snl(snl)
assert ts.history[-1]["@class"] == "SubstitutionTransformation"
t_struct = TransformedStructure.from_snl(struct_nl)
assert t_struct.history[-1]["@class"] == "SubstitutionTransformation"

hist = ("testname", "testURL", {"test": "testing"})
snl = StructureNL(ts.final_structure, [("will", "will@test.com")], history=[hist])
snl = TransformedStructure.from_snl(snl).to_snl([("notwill", "notwill@test.com")])
assert snl.history == [hist]
assert snl.authors == [("notwill", "notwill@test.com")]
struct_nl = StructureNL(t_struct.final_structure, [("will", "will@test.com")], history=[hist])
t_struct = TransformedStructure.from_snl(struct_nl).to_snl([("notwill", "notwill@test.com")])
assert t_struct.history == [hist]
assert t_struct.authors == [("notwill", "notwill@test.com")]
Loading

0 comments on commit bc5323a

Please sign in to comment.