From 2c76739ebf88be2d65b26155ce688199a6275a81 Mon Sep 17 00:00:00 2001 From: CompRhys Date: Wed, 24 Jul 2024 15:00:01 -0400 Subject: [PATCH] fea: get_formula_from_protostructure_label --- aviary/wren/utils.py | 39 ++++++++++++++++++++++++--------------- pyproject.toml | 1 + tests/test_wyckoff_ops.py | 24 +++++++++++------------- 3 files changed, 36 insertions(+), 28 deletions(-) diff --git a/aviary/wren/utils.py b/aviary/wren/utils.py index 4baae435..ed832cb9 100644 --- a/aviary/wren/utils.py +++ b/aviary/wren/utils.py @@ -206,9 +206,7 @@ def get_protostructure_label_from_spg_analyzer( equivalent_wyckoff_labels = [ # tuple of (wp multiplicity, element, wyckoff letter) (len(s), s[0].species_string, wyk_letter.translate(remove_digits)) - for s, wyk_letter in zip( - sym_struct.equivalent_sites, sym_struct.wyckoff_symbols - ) + for s, wyk_letter in zip(sym_struct.equivalent_sites, sym_struct.wyckoff_symbols) ] # Pre-sort by element and wyckoff letter to ensure continuous groups in groupby equivalent_wyckoff_labels = sorted( @@ -286,9 +284,7 @@ def get_protostructure_label_from_spglib( """ attempt_to_recover = False try: - spg_analyzer = SpacegroupAnalyzer( - struct, symprec=init_symprec, angle_tolerance=5 - ) + spg_analyzer = SpacegroupAnalyzer(struct, symprec=init_symprec, angle_tolerance=5) try: aflow_label_with_chemsys = get_protostructure_label_from_spg_analyzer( spg_analyzer, raise_errors @@ -429,6 +425,25 @@ def get_anonymous_formula_from_prototype_formula(prototype_formula: str) -> str: ) +def get_formula_from_protostructure_label(protostructure_label: str) -> str: + """Get a formula from a protostructure label.""" + aflow_label, chemsys = protostructure_label.split(":") + prototype_formula = aflow_label.split("_")[0] + prototype_formula = re.sub( + RE_ELEMENT_NO_SUFFIX, RE_SUBST_ONE_SUFFIX, prototype_formula + ) + anom_list = split_alpha_numeric(prototype_formula) + + return "".join( + [ + f"{el}{num}" if num != 1 else el + for el, num in zip( + chemsys.split("-"), map(int, anom_list["numeric"]), strict=True + ) + ] + ) + + def count_distinct_wyckoff_letters(protostructure_label: str) -> int: """Count number of distinct Wyckoff letters in protostructure_label. @@ -507,9 +522,7 @@ def count_crystal_sites(protostructure_label: str) -> int: return _count_from_dict(element_wyckoffs, wyckoff_multiplicity_dict, spg_num) -def _count_from_dict( - element_wyckoffs: list[str], lookup_dict: dict, spg_num: str -) -> int: +def _count_from_dict(element_wyckoffs: list[str], lookup_dict: dict, spg_num: str) -> int: """Count number of sites from protostructure_label.""" n_params = 0 @@ -541,9 +554,7 @@ def get_prototype_from_protostructure(protostructure_label: str) -> str: str: Canonicalized AFLOW-style prototype label """ aflow_label, _ = protostructure_label.split(":") - prototype_formula, pearson_symbol, spg_num, *element_wyckoffs = aflow_label.split( - "_" - ) + prototype_formula, pearson_symbol, spg_num, *element_wyckoffs = aflow_label.split("_") anonymous_formula = get_anonymous_formula_from_prototype_formula(prototype_formula) counts = [ @@ -657,9 +668,7 @@ def get_protostructures_from_aflow_label_and_composition( list[str]: List of possible protostructure labels that can be generated from combinations of the input aflow_label and composition. """ - anonymous_formula, pearson_symbol, spg_num, *element_wyckoffs = aflow_label.split( - "_" - ) + anonymous_formula, pearson_symbol, spg_num, *element_wyckoffs = aflow_label.split("_") ele_amt_dict = composition.get_el_amt_dict() proto_formula = get_prototype_formula_from_composition(composition) diff --git a/pyproject.toml b/pyproject.toml index 5a589cd9..a930018d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,7 @@ addopts = "-p no:warnings" no_implicit_optional = false [tool.ruff] +line-length = 90 target-version = "py39" extend-include = ["*.ipynb"] lint.select = [ diff --git a/tests/test_wyckoff_ops.py b/tests/test_wyckoff_ops.py index aa86b8a7..459240b4 100644 --- a/tests/test_wyckoff_ops.py +++ b/tests/test_wyckoff_ops.py @@ -14,6 +14,7 @@ count_distinct_wyckoff_letters, count_wyckoff_positions, get_anonymous_formula_from_prototype_formula, + get_formula_from_protostructure_label, get_protostructure_label_from_aflow, get_protostructure_label_from_spg_analyzer, get_protostructure_label_from_spglib, @@ -146,10 +147,7 @@ def test_get_prototype_from_protostructure(protostructure_label, expected): element_wyckoff = "_".join(wyckoffs) isopointal_element_wyckoffs = list( - { - element_wyckoff.translate(str.maketrans(trans)) - for trans in relab_dict[spg_num] - } + {element_wyckoff.translate(str.maketrans(trans)) for trans in relab_dict[spg_num]} ) protostructure_labels = [ @@ -157,11 +155,6 @@ def test_get_prototype_from_protostructure(protostructure_label, expected): for element_wyckoff in isopointal_element_wyckoffs ] - print(protostructure_label) - print(protostructure_labels) - print(get_prototype_from_protostructure(protostructure_label)) - print(expected) - assert all( get_prototype_from_protostructure(protostructure_label) == expected for protostructure_label in protostructure_labels @@ -208,10 +201,7 @@ def test_get_protostructures_from_aflow_label_and_composition( ( {"a": 1, "b": 1, "c": 1}, {"x": 1, "y": 1, "z": 1}, - [ - dict(zip(["a", "b", "c"], perm)) - for perm in permutations(["x", "y", "z"]) - ], + [dict(zip(["a", "b", "c"], perm)) for perm in permutations(["x", "y", "z"])], ), # Test case 3: No valid translations (different values) ({"a": 1, "b": 2}, {"x": 1, "y": 3}, []), @@ -253,6 +243,14 @@ def test_get_prototype_formula_from_composition(composition: str, expected: str) assert get_prototype_formula_from_composition(Composition(composition)) == expected +@pytest.mark.parametrize( + "protostructure_label, expected", + [("AB3C_oP20_62_c_cd_a:Ni-O-Yb", "NiO3Yb")], +) +def test_get_formula_from_protostructure_label(protostructure_label: str, expected: str): + assert get_formula_from_protostructure_label(protostructure_label) == expected + + @pytest.mark.parametrize( "anonymous_formula, prototype_formula", [("AB", "AB"), ("A2B", "AB2"), ("A3B2CD4", "AB2C3D4")],