Skip to content

Commit

Permalink
Read multiple files (#35)
Browse files Browse the repository at this point in the history
  • Loading branch information
NinadBhat authored Jun 26, 2024
1 parent 645f202 commit d21ce35
Show file tree
Hide file tree
Showing 4 changed files with 259 additions and 34 deletions.
49 changes: 26 additions & 23 deletions src/atomid/annotate.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Annotate crystal class."""

from typing import Optional
from typing import Any, Optional

import atomrdf as ardf
from ase.io import read as ase_read
Expand All @@ -26,7 +26,8 @@ def __init__(
self,
data_file: Optional[str] = None,
format: Optional[str] = None,
**kwargs: dict[str, str],
kg: Optional[ardf.KnowledgeGraph] = None,
**kwargs: Any,
) -> None:
"""Initialize the AnnotateCrystal object.
Expand All @@ -41,11 +42,19 @@ def __init__(
-------
None
"""
if data_file is not None and format is not None:
self.read_crystal_structure_file(data_file, format, **kwargs)
if kg is None:
self.kg = ardf.KnowledgeGraph()
else:
self.kg = kg

if data_file is not None:
if format is not None:
self.read_crystal_structure_file(data_file, format, **kwargs)
else:
self.read_crystal_structure_file(data_file, **kwargs)

def read_crystal_structure_file(
self, data_file: str, format: str, **kwargs: dict[str, str]
self, data_file: str, format: Optional[str] = None, **kwargs: dict[str, str]
) -> None:
"""Read the crystal structure file.
Expand All @@ -61,17 +70,9 @@ def read_crystal_structure_file(
None
"""
self.ase_crystal = ase_read(data_file, format=format, **kwargs)
kg = ardf.KnowledgeGraph()

crystal_structure = ardf.System.read.file(
filename=self.ase_crystal, format="ase", graph=kg
)

self.ovito_pipeline = import_file(data_file)

self.kg = kg
self.system = crystal_structure

def validate_parameters_for_crystal_annotation(self) -> None:
if hasattr(self, "lattice_constant") is False or self.lattice_constant is None:
self._raise_error(
Expand Down Expand Up @@ -125,7 +126,7 @@ def get_polyhedral_template_matching_data(self) -> DataCollection:
data = self.ovito_pipeline.compute()
return data

def identify_crystal_structure(self) -> None:
def identify_crystal_structure(self, log: bool = True) -> None:
"""Identify and annotate the crystal structure.
This method identifies the crystal structure using polyhedral template matching
Expand All @@ -148,8 +149,9 @@ def identify_crystal_structure(self) -> None:
self.lattice_constant = find_lattice_parameter(
interatomic_distance, structure_type_atoms, int(structure_id)
)
print(f"\033[92mCrystal structure: {crystal_type}\033[0m")
print(f"\033[92mLattice constant: {self.lattice_constant}\033[0m")
if log:
print(f"\033[92mCrystal structure: {crystal_type}\033[0m")
print(f"\033[92mLattice constant: {self.lattice_constant}\033[0m")
else:
# Warn the user that the crystal structure could not be identified
# and that the lattice constant cannot be determined
Expand Down Expand Up @@ -193,8 +195,6 @@ def annotate_crystal_structure(self) -> None:
None
"""
self.validate_parameters_for_crystal_annotation()

self.kg = ardf.KnowledgeGraph()
self.system = ardf.System.read.file(
self.ase_crystal,
format="ase",
Expand All @@ -208,6 +208,7 @@ def identify_point_defects(
reference_data_file: str,
ref_format: str,
method: Optional[str] = None,
log: bool = True,
**kwargs: dict[str, str],
) -> None:
"""Identify defects in the crystal structure using the reference data file.
Expand Down Expand Up @@ -243,13 +244,15 @@ def identify_point_defects(

self.defects.update(defects)
# Print identified defects
print("\033[92mIdentified defects:\033[0m")
if log:
print("\033[92mIdentified defects:\033[0m")
for defect, defect_info in defects.items():
if defect_info["count"] != 0:
print(
f"\033[92m{defect}:\033[0m Count: {defect_info['count']} "
f"Concentration: {defect_info['concentration']:.2f}"
)
if log:
print(
f"\033[92m{defect}:\033[0m Count: {defect_info['count']} "
f"Concentration: {defect_info['concentration']:.2f}"
)

def add_vacancy_information(self, concentration: float, number: int) -> None:
"""Add vacancy information to the system.
Expand Down
154 changes: 154 additions & 0 deletions src/atomid/annotate_crystals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
"""AnnotateCrystals class."""

from typing import Dict, List, Optional, TypedDict

import atomrdf as ardf

from atomid.annotate import AnnotateCrystal


class CrystalEntry(TypedDict):
    """Record kept for one sample: its source file and its annotator object."""

    # Path of the structure file the sample was created from.
    data_file: str
    # AnnotateCrystal instance constructed for that file.
    system: AnnotateCrystal


class AnnotateCrystals:
    """Annotate multiple crystal structures in one shared knowledge graph.

    Each data file is wrapped in its own ``AnnotateCrystal`` instance, and
    every instance is constructed with the same ``ardf.KnowledgeGraph``.
    Samples are addressed by 1-based integer indices that are assigned in the
    order the files are added.
    """

    # Values of ``crystal_type`` that mean "structure could not be identified".
    # NOTE(review): the sentinel left behind by
    # ``AnnotateCrystal.identify_crystal_structure`` is not visible from this
    # module; the two public annotate methods used to test for different ones
    # (``"other"`` vs ``None``), so both are accepted here -- confirm and narrow.
    _UNIDENTIFIED_TYPES = (None, "other")

    def __init__(self, data_files: Optional[List[str]] = None) -> None:
        """Create the shared knowledge graph and register the initial files.

        Parameters
        ----------
        data_files : list of str, optional
            Paths of the structure files to register. ``None`` or an empty
            list creates an object without samples.

        Returns
        -------
        None
        """
        self._kg = ardf.KnowledgeGraph()
        self.crystals_dict: Dict[int, CrystalEntry] = {}

        # Reuse add_data_file so that there is a single place that builds an
        # entry and decides its index.
        for data_file in data_files or []:
            self.add_data_file(data_file)

    @property
    def num_samples(self) -> int:
        """Return the number of samples."""
        return len(self.crystals_dict)

    @property
    def kg(self) -> ardf.KnowledgeGraph:
        """Return the knowledge graph shared by all samples."""
        return self._kg

    def add_data_file(self, data_file: str) -> None:
        """Add a data file to the list of samples.

        Parameters
        ----------
        data_file : str
            The path to the data file.

        Returns
        -------
        None
        """
        # Use the highest existing key rather than the entry count: if a key
        # was ever removed from ``crystals_dict``, ``num_samples + 1`` could
        # collide with an existing entry and silently overwrite it.
        idx = max(self.crystals_dict, default=0) + 1
        self.crystals_dict[idx] = {
            "data_file": data_file,
            "system": AnnotateCrystal(data_file, kg=self._kg),
        }

    def get_crystal(self, idx: int) -> AnnotateCrystal:
        """Return the AnnotateCrystal object for a given sample.

        Parameters
        ----------
        idx : int
            The index of the sample.

        Returns
        -------
        AnnotateCrystal
            The AnnotateCrystal object for the sample.

        Raises
        ------
        KeyError
            If no sample is stored under ``idx``.
        """
        return self.crystals_dict[idx]["system"]

    def get_data_file(self, idx: int) -> str:
        """Return the data file for a given sample.

        Parameters
        ----------
        idx : int
            The index of the sample.

        Returns
        -------
        str
            The path to the data file.

        Raises
        ------
        KeyError
            If no sample is stored under ``idx``.
        """
        return self.crystals_dict[idx]["data_file"]

    def get_sample(self, idx: int) -> CrystalEntry:
        """Return the dictionary for a given sample.

        Parameters
        ----------
        idx : int
            The index of the sample.

        Returns
        -------
        dict
            The dictionary for the sample.

        Raises
        ------
        KeyError
            If no sample is stored under ``idx``.
        """
        return self.crystals_dict[idx]

    def _identify_and_annotate(self, idx: int, log: bool) -> None:
        """Identify the structure of sample ``idx`` and annotate it if found."""
        entry = self.crystals_dict[idx]
        crystal = entry["system"]
        crystal.identify_crystal_structure(log=log)

        # getattr: tolerate the attribute not having been set at all when
        # identification fails (treated the same as an explicit ``None``).
        if getattr(crystal, "crystal_type", None) in self._UNIDENTIFIED_TYPES:
            print(
                f"Crystal structure not identified for sample {idx}: "
                f"{entry['data_file']}"
            )
            print(
                "Manually annotate the crystal structure using the "
                "set_crystal_structure method."
            )
            return

        crystal.annotate_crystal_structure()

    def annotate_all_crystal_structures(self, log: bool = True) -> None:
        """Annotate the crystal structures for all samples.

        This method identifies the crystal structure for each sample and
        annotates it. Samples whose structure cannot be identified are
        reported on standard output and left unannotated.

        Parameters
        ----------
        log : bool, optional
            Forwarded to ``AnnotateCrystal.identify_crystal_structure``.
            Defaults to True.

        Returns
        -------
        None
        """
        for idx in self.crystals_dict:
            self._identify_and_annotate(idx, log)

    def annotate_crystal_structure(self, idx: int, log: bool = False) -> None:
        """Annotate the crystal structure for a given sample.

        This method identifies the crystal structure for a given sample and
        annotates it. If the structure cannot be identified, a message is
        printed and the sample is left unannotated.

        Parameters
        ----------
        idx : int
            The index of the sample.
        log : bool, optional
            Whether to log the output. Defaults to False.

        Returns
        -------
        None

        Raises
        ------
        KeyError
            If no sample is stored under ``idx``.
        """
        self._identify_and_annotate(idx, log)
70 changes: 70 additions & 0 deletions tests/test_annotate_crystals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""Tests for the AnnotateCrystals class."""

from typing import List

import atomrdf as ardf
import pytest
from atomid.annotate import AnnotateCrystal
from atomid.annotate_crystals import AnnotateCrystals

"""Tests for the AnnotateCrystals class."""


@pytest.fixture
def annotate_crystals() -> AnnotateCrystals:
    """Return an ``AnnotateCrystals`` built from two POSCAR files in tests/data."""
    return AnnotateCrystals(
        data_files=[
            "tests/data/bcc/Fe/defect/interstitial/initial/Fe_interstitial.poscar",
            "tests/data/fcc/Al/no_defect/initial/Al.poscar",
        ]
    )


class TestAnnotateCrystals:
    """Exercise the public API of ``AnnotateCrystals`` through the fixture."""

    def test_num_samples(self, annotate_crystals: AnnotateCrystals) -> None:
        """The fixture registers two data files, hence two samples."""
        sample_count = annotate_crystals.num_samples
        assert sample_count == 2

    def test_kg(self, annotate_crystals: AnnotateCrystals) -> None:
        """The exposed graph is an atomrdf ``KnowledgeGraph``."""
        graph = annotate_crystals.kg
        assert isinstance(graph, ardf.KnowledgeGraph)

    def test_add_data_file(self, annotate_crystals: AnnotateCrystals) -> None:
        """Adding a file appends a third sample retrievable under index 3."""
        extra_path = "tests/data/hcp/Mg/no_defect/initial/Mg.poscar"
        annotate_crystals.add_data_file(extra_path)
        assert annotate_crystals.num_samples == 3
        assert annotate_crystals.get_data_file(3) == extra_path

    def test_get_crystal(self, annotate_crystals: AnnotateCrystals) -> None:
        """``get_crystal`` hands back the ``AnnotateCrystal`` of a sample."""
        assert isinstance(annotate_crystals.get_crystal(1), AnnotateCrystal)

    def test_get_data_file(self, annotate_crystals: AnnotateCrystals) -> None:
        """``get_data_file`` returns the path the first sample was built from."""
        expected_path = (
            "tests/data/bcc/Fe/defect/interstitial/initial/Fe_interstitial.poscar"
        )
        assert annotate_crystals.get_data_file(1) == expected_path

    def test_get_sample(self, annotate_crystals: AnnotateCrystals) -> None:
        """``get_sample`` returns the full entry: path plus annotator."""
        entry = annotate_crystals.get_sample(1)
        expected_path = (
            "tests/data/bcc/Fe/defect/interstitial/initial/Fe_interstitial.poscar"
        )
        assert entry["data_file"] == expected_path
        assert isinstance(entry["system"], AnnotateCrystal)

    def test_annotate_all_crystal_structures(
        self, annotate_crystals: AnnotateCrystals
    ) -> None:
        """After bulk annotation every sample has a crystal type set."""
        annotate_crystals.annotate_all_crystal_structures()
        for entry in annotate_crystals.crystals_dict.values():
            annotator: AnnotateCrystal = entry["system"]
            assert annotator.crystal_type is not None

    def test_annotate_crystal_structure(
        self, annotate_crystals: AnnotateCrystals
    ) -> None:
        """Annotating a single sample sets its crystal type."""
        sample_idx = 1
        annotate_crystals.annotate_crystal_structure(sample_idx)
        annotator: AnnotateCrystal = annotate_crystals.crystals_dict[sample_idx][
            "system"
        ]
        assert annotator.crystal_type is not None
20 changes: 9 additions & 11 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,9 @@ def test_read_crystal_structure_file(
self, sample_crystal_file: str, reference_crystal_file: str
) -> None:
annotate_crystal = AnnotateCrystal()
annotate_crystal.read_crystal_structure_file(
sample_crystal_file, format="vasp"
) # Adjust format if needed
annotate_crystal.read_crystal_structure_file(sample_crystal_file, format="vasp")

assert annotate_crystal.ase_crystal is not None
assert annotate_crystal.system is not None
assert annotate_crystal.kg is not None

@pytest.mark.parametrize(
Expand All @@ -53,9 +50,7 @@ def test_annotate_crystal_structure(
self, sample_crystal_file: str, reference_crystal_file: str
) -> None:
annotate_crystal = AnnotateCrystal()
annotate_crystal.read_crystal_structure_file(
sample_crystal_file, format="vasp"
) # Adjust format if needed
annotate_crystal.read_crystal_structure_file(sample_crystal_file, format="vasp")
annotate_crystal.identify_crystal_structure()
annotate_crystal.annotate_crystal_structure()

Expand All @@ -68,12 +63,13 @@ def test_identify_defects(
self, sample_crystal_file: str, reference_crystal_file: str
) -> None:
annotate_crystal = AnnotateCrystal()
annotate_crystal.read_crystal_structure_file(
sample_crystal_file, format="vasp"
) # Adjust format if needed
annotate_crystal.read_crystal_structure_file(sample_crystal_file, format="vasp")
annotate_crystal.identify_crystal_structure()
annotate_crystal.annotate_crystal_structure()

annotate_crystal.identify_point_defects(
reference_crystal_file, ref_format="vasp"
) # Adjust format if needed
)

assert isinstance(annotate_crystal.defects, dict)
assert "vacancies" in annotate_crystal.defects
Expand All @@ -89,6 +85,8 @@ def test_write_defects(
"""Test writing the defects to a file."""
annotate_crystal = AnnotateCrystal()
annotate_crystal.read_crystal_structure_file(sample_crystal_file, format="vasp")
annotate_crystal.identify_crystal_structure()
annotate_crystal.annotate_crystal_structure()
annotate_crystal.identify_point_defects(
reference_crystal_file, ref_format="vasp"
)
Expand Down

0 comments on commit d21ce35

Please sign in to comment.