-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
259 additions
and
34 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
"""AnnotateCrystals class.""" | ||
|
||
from typing import Dict, List, Optional, TypedDict | ||
|
||
import atomrdf as ardf | ||
|
||
from atomid.annotate import AnnotateCrystal | ||
|
||
|
||
class CrystalEntry(TypedDict): | ||
"""Type definition for a crystal entry.""" | ||
|
||
data_file: str | ||
system: AnnotateCrystal | ||
|
||
|
||
class AnnotateCrystals: | ||
"""Annotate multiple crystal structures.""" | ||
|
||
def __init__(self, data_files: Optional[List[str]] = None): | ||
self._kg = ardf.KnowledgeGraph() | ||
self.crystals_dict: Dict[int, CrystalEntry] = {} | ||
|
||
if data_files: | ||
for idx, data_file in enumerate(data_files, start=1): | ||
self.crystals_dict[idx] = { | ||
"data_file": data_file, | ||
"system": AnnotateCrystal(data_file, kg=self._kg), | ||
} | ||
|
||
@property | ||
def num_samples(self) -> int: | ||
"""Return the number of samples.""" | ||
return len(self.crystals_dict) | ||
|
||
@property | ||
def kg(self) -> ardf.KnowledgeGraph: | ||
"""Return the knowledge graph.""" | ||
return self._kg | ||
|
||
def add_data_file(self, data_file: str) -> None: | ||
"""Add a data file to the list of samples. | ||
Parameters | ||
---------- | ||
data_file : str | ||
The path to the data file. | ||
Returns | ||
------- | ||
None | ||
""" | ||
idx = self.num_samples + 1 | ||
self.crystals_dict[idx] = { | ||
"data_file": data_file, | ||
"system": AnnotateCrystal(data_file, kg=self._kg), | ||
} | ||
|
||
def get_crystal(self, idx: int) -> AnnotateCrystal: | ||
"""Return the AnnotateCrystal object for a given sample. | ||
Parameters | ||
---------- | ||
idx : int | ||
The index of the sample. | ||
Returns | ||
------- | ||
AnnotateCrystal | ||
The AnnotateCrystal object for the sample. | ||
""" | ||
return self.crystals_dict[idx]["system"] | ||
|
||
def get_data_file(self, idx: int) -> str: | ||
"""Return the data file for a given sample. | ||
Parameters | ||
---------- | ||
idx : int | ||
The index of the sample. | ||
Returns | ||
------- | ||
str | ||
The path to the data file. | ||
""" | ||
return self.crystals_dict[idx]["data_file"] | ||
|
||
def get_sample(self, idx: int) -> CrystalEntry: | ||
"""Return the dictionary for a given sample. | ||
Parameters | ||
---------- | ||
idx : int | ||
The index of the sample. | ||
Returns | ||
------- | ||
dict | ||
The dictionary for the sample. | ||
""" | ||
return self.crystals_dict[idx] | ||
|
||
def annotate_all_crystal_structures(self) -> None: | ||
"""Annotate the crystal structures for all samples. | ||
This method identifies the crystal structure for each sample and annotates it. | ||
Returns | ||
------- | ||
None | ||
""" | ||
for idx in self.crystals_dict: | ||
self.crystals_dict[idx]["system"].identify_crystal_structure() | ||
if self.crystals_dict[idx]["system"].crystal_type != "other": | ||
self.crystals_dict[idx]["system"].annotate_crystal_structure() | ||
else: | ||
print( | ||
f"Crystal structure not identified for sample {idx}" | ||
f"{self.crystals_dict[idx]['data_file']}" | ||
) | ||
print( | ||
"Manually annotate the crystal structure using the " | ||
"set_crystal_structure method." | ||
) | ||
|
||
def annotate_crystal_structure(self, idx: int, log: bool = False) -> None: | ||
"""Annotate the crystal structure for a given sample. | ||
This method identifies the crystal structure for a given sample and annotates it. | ||
Parameters | ||
---------- | ||
idx : int | ||
The index of the sample. | ||
log : bool, optional | ||
Whether to log the output. Defaults to False. | ||
Returns | ||
------- | ||
None | ||
""" | ||
self.crystals_dict[idx]["system"].identify_crystal_structure(log=log) | ||
if self.crystals_dict[idx]["system"].crystal_type is not None: | ||
self.crystals_dict[idx]["system"].annotate_crystal_structure() | ||
else: | ||
print( | ||
f"Crystal structure not identified for sample {idx}" | ||
f"{self.crystals_dict[idx]['data_file']}" | ||
) | ||
print( | ||
"Manually annotate the crystal structure using the " | ||
"set_crystal_structure method." | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
"""Tests for the AnnotateCrystals class.""" | ||
|
||
from typing import List | ||
|
||
import atomrdf as ardf | ||
import pytest | ||
from atomid.annotate import AnnotateCrystal | ||
from atomid.annotate_crystals import AnnotateCrystals | ||
|
||
"""Tests for the AnnotateCrystals class.""" | ||
|
||
|
||
@pytest.fixture | ||
def annotate_crystals() -> AnnotateCrystals: | ||
data_files: List[str] = [ | ||
"tests/data/bcc/Fe/defect/interstitial/initial/Fe_interstitial.poscar", | ||
"tests/data/fcc/Al/no_defect/initial/Al.poscar", | ||
] | ||
return AnnotateCrystals(data_files=data_files) | ||
|
||
|
||
class TestAnnotateCrystals: | ||
"""Tests for the AnnotateCrystals class.""" | ||
|
||
def test_num_samples(self, annotate_crystals: AnnotateCrystals) -> None: | ||
assert annotate_crystals.num_samples == 2 | ||
|
||
def test_kg(self, annotate_crystals: AnnotateCrystals) -> None: | ||
assert isinstance(annotate_crystals.kg, ardf.KnowledgeGraph) | ||
|
||
def test_add_data_file(self, annotate_crystals: AnnotateCrystals) -> None: | ||
new_file: str = "tests/data/hcp/Mg/no_defect/initial/Mg.poscar" | ||
annotate_crystals.add_data_file(new_file) | ||
assert annotate_crystals.num_samples == 3 | ||
assert annotate_crystals.get_data_file(3) == new_file | ||
|
||
def test_get_crystal(self, annotate_crystals: AnnotateCrystals) -> None: | ||
crystal: AnnotateCrystal = annotate_crystals.get_crystal(1) | ||
assert isinstance(crystal, AnnotateCrystal) | ||
|
||
def test_get_data_file(self, annotate_crystals: AnnotateCrystals) -> None: | ||
data_file: str = annotate_crystals.get_data_file(1) | ||
assert ( | ||
data_file | ||
== "tests/data/bcc/Fe/defect/interstitial/initial/Fe_interstitial.poscar" | ||
) | ||
|
||
def test_get_sample(self, annotate_crystals: AnnotateCrystals) -> None: | ||
sample = annotate_crystals.get_sample(1) | ||
assert ( | ||
sample["data_file"] | ||
== "tests/data/bcc/Fe/defect/interstitial/initial/Fe_interstitial.poscar" | ||
) | ||
assert isinstance(sample["system"], AnnotateCrystal) | ||
|
||
def test_annotate_all_crystal_structures( | ||
self, annotate_crystals: AnnotateCrystals | ||
) -> None: | ||
annotate_crystals.annotate_all_crystal_structures() | ||
for idx in annotate_crystals.crystals_dict: | ||
system: AnnotateCrystal = annotate_crystals.crystals_dict[idx]["system"] | ||
assert system.crystal_type is not None | ||
|
||
def test_annotate_crystal_structure( | ||
self, annotate_crystals: AnnotateCrystals | ||
) -> None: | ||
idx: int = 1 | ||
annotate_crystals.annotate_crystal_structure(idx) | ||
system: AnnotateCrystal = annotate_crystals.crystals_dict[idx]["system"] | ||
assert system.crystal_type is not None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters