Skip to content

Commit

Permalink
improved dataset: properties, io, and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
wolearyc committed Sep 26, 2024
1 parent 738d6f4 commit d118aee
Show file tree
Hide file tree
Showing 9 changed files with 517 additions and 44 deletions.
1 change: 1 addition & 0 deletions deps/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@ spglib >= 1.16.4;python_version=='3.12' # minimum working
tabulate >= 0.8.8;python_version=='3.10' # minimum working
tabulate >= 0.8.8;python_version=='3.11' # minimum working
tabulate >= 0.8.8;python_version=='3.12' # minimum working
tqdm
4 changes: 4 additions & 0 deletions ramannoodle/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ class InvalidFileException(Exception):
"""File cannot be read, likely due to due to invalid or unexpected format."""


class IncompatibleStructureException(Exception):
"""Supplied file is incompatible."""


class InvalidDOFException(Exception):
"""A supplied degree of freedom is invalid."""

Expand Down
35 changes: 35 additions & 0 deletions ramannoodle/io/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from ramannoodle.structure.reference import ReferenceStructure
import ramannoodle.io.vasp as vasp_io
from ramannoodle.polarizability.torch.dataset import PolarizabilityDataset

# These map between file formats and appropriate IO functions.
_PHONON_READERS = {
Expand All @@ -36,6 +37,10 @@
"outcar": vasp_io.outcar.read_structure_and_polarizability,
"vasprun.xml": vasp_io.vasprun.read_structure_and_polarizability,
}
_POLARIZABILITY_DATASET_READERS = {
"outcar": vasp_io.outcar.read_polarizability_dataset,
"vasprun.xml": vasp_io.vasprun.read_polarizability_dataset,
}
_POSITION_READERS = {
"poscar": vasp_io.poscar.read_positions,
"outcar": vasp_io.outcar.read_positions,
Expand Down Expand Up @@ -181,6 +186,36 @@ def read_structure_and_polarizability(
raise ValueError(f"unsupported format: {file_format}") from exc


def read_polarizability_dataset(
filepaths: str | Path | list[str] | list[Path],
file_format: str,
) -> PolarizabilityDataset:
"""Read polarizability dataset from files.
Parameters
----------
filepath
file_format
| Supports ``"outcar"``, ``"vasprun.xml"`` (see :ref:`Supported formats`)
Returns
-------
:
Raises
------
FileNotFoundError
InvalidFileException
File has an unexpected format.
IncompatibleFileException
File is incompatible with the dataset.
"""
try:
return _POLARIZABILITY_DATASET_READERS[file_format](filepaths)
except KeyError as exc:
raise ValueError(f"unsupported format: {file_format}") from exc


def read_positions(
filepath: str | Path,
file_format: str,
Expand Down
68 changes: 67 additions & 1 deletion ramannoodle/io/io_utils.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
"""Universal IO utility functions."""

from typing import TextIO
from typing import TextIO, Callable
from pathlib import Path

import numpy as np
from numpy.typing import NDArray
from tqdm import tqdm

from ramannoodle.exceptions import (
NoMatchingLineFoundException,
verify_ndarray_shape,
verify_positions,
verify_list_len,
IncompatibleStructureException,
)
from ramannoodle.globals import ATOM_SYMBOLS
from ramannoodle.polarizability.torch.dataset import PolarizabilityDataset


def _skip_file_until_line_contains(file: TextIO, content: str) -> str:
Expand Down Expand Up @@ -84,3 +87,66 @@ def verify_trajectory(
verify_ndarray_shape("positions_ts", positions_ts, (None, len(atomic_numbers), 3))
if (0 > positions_ts).any() or (positions_ts > 1.0).any():
raise ValueError("positions_ts has coordinates that are not between 0 and 1")


def _read_polarizability_dataset(
filepaths: str | Path | list[str] | list[Path],
read_structure_and_polarizability_fn: Callable[
[str | Path],
tuple[NDArray[np.float64], list[int], NDArray[np.float64], NDArray[np.float64]],
],
) -> PolarizabilityDataset:
"""Read polarizability dataset from OUTCAR files.
Parameters
----------
filepath
read_structure_and_polarizability_fn
Returns
-------
:
Raises
------
FileNotFoundError
InvalidFileException
File has an unexpected format.
IncompatibleFileException
File is incompatible with the dataset.
"""
filepaths = pathify_as_list(filepaths)

lattices: list[NDArray[np.float64]] = []
atomic_numbers_list: list[list[int]] = []
positions_list: list[NDArray[np.float64]] = []
polarizabilities: list[NDArray[np.float64]] = []
for file_index, filepath in tqdm(list(enumerate(filepaths)), unit="files"):
lattice, atomic_numbers, positions, polarizability = (
read_structure_and_polarizability_fn(filepath)
)
if file_index != 0:
if not np.isclose(lattices[0], lattice, atol=1e-5).all():
raise IncompatibleStructureException(
f"incompatible lattice: {filepath}"
)
if atomic_numbers_list[0] != atomic_numbers:
raise IncompatibleStructureException(
f"incompatible atomic numbers: {filepath}"
)
if positions_list[0].shape != positions.shape: # check, just to be safe
raise IncompatibleStructureException(
f"incompatible atomic positions: {filepath}"
)
lattices.append(lattice)
atomic_numbers_list.append(atomic_numbers)
positions_list.append(positions)
polarizabilities.append(polarizability)

return PolarizabilityDataset(
np.array(lattices),
atomic_numbers_list,
np.array(positions_list),
np.array(polarizabilities),
scale_mode="standard",
)
32 changes: 30 additions & 2 deletions ramannoodle/io/vasp/outcar.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,18 @@
import numpy as np
from numpy.typing import NDArray


from ramannoodle.io.io_utils import _skip_file_until_line_contains, pathify
from ramannoodle.io.io_utils import (
_skip_file_until_line_contains,
pathify,
_read_polarizability_dataset,
)
from ramannoodle.exceptions import InvalidFileException, NoMatchingLineFoundException
from ramannoodle.globals import ATOMIC_WEIGHTS, ATOMIC_NUMBERS
from ramannoodle.exceptions import get_type_error
from ramannoodle.dynamics.phonon import Phonons
from ramannoodle.dynamics.trajectory import Trajectory
from ramannoodle.structure.reference import ReferenceStructure
from ramannoodle.polarizability.torch.dataset import PolarizabilityDataset


# Utilities for OUTCAR. Warning: some of these functions partially read files.
Expand Down Expand Up @@ -394,6 +398,30 @@ def read_structure_and_polarizability(
return lattice, atomic_numbers, positions, polarizability


def read_polarizability_dataset(
filepaths: str | Path | list[str] | list[Path],
) -> PolarizabilityDataset:
"""Read polarizability dataset from OUTCAR files.
Parameters
----------
filepaths
Returns
-------
:
Raises
------
FileNotFoundError
InvalidFileException
File has an unexpected format.
IncompatibleFileException
File is incompatible with the dataset.
"""
return _read_polarizability_dataset(filepaths, read_structure_and_polarizability)


def read_ref_structure(filepath: str | Path) -> ReferenceStructure:
"""Read reference structure from a VASP OUTCAR file.
Expand Down
27 changes: 26 additions & 1 deletion ramannoodle/io/vasp/vasprun.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@
import numpy as np
from numpy.typing import NDArray

from ramannoodle.io.io_utils import pathify
from ramannoodle.io.io_utils import pathify, _read_polarizability_dataset
from ramannoodle.exceptions import InvalidFileException
from ramannoodle.globals import ATOMIC_WEIGHTS, ATOMIC_NUMBERS
from ramannoodle.dynamics.phonon import Phonons
from ramannoodle.dynamics.trajectory import Trajectory
from ramannoodle.structure.reference import ReferenceStructure
from ramannoodle.polarizability.torch.dataset import PolarizabilityDataset


def _get_root_element(file: TextIO) -> Element:
Expand Down Expand Up @@ -192,6 +193,30 @@ def read_structure_and_polarizability(
return lattice, atomic_numbers, positions, polarizability


def read_polarizability_dataset(
filepaths: str | Path | list[str] | list[Path],
) -> PolarizabilityDataset:
"""Read polarizability dataset from OUTCAR files.
Parameters
----------
filepaths
Returns
-------
:
Raises
------
FileNotFoundError
InvalidFileException
File has an unexpected format.
IncompatibleFileException
File is incompatible with the dataset.
"""
return _read_polarizability_dataset(filepaths, read_structure_and_polarizability)


def read_positions(filepath: str | Path) -> NDArray[np.float64]:
"""Read fractional positions from a vasprun.xml file.
Expand Down
83 changes: 77 additions & 6 deletions ramannoodle/polarizability/torch/dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Polarizability PyTorch dataset."""

import copy

import numpy as np
from numpy.typing import NDArray

Expand Down Expand Up @@ -117,19 +119,88 @@ def __init__( # pylint: disable=too-many-arguments
self._positions = torch.from_numpy(positions).type(default_type)
self._polarizabilities = torch.from_numpy(polarizabilities)

mean, stddev, scaled = _scale_and_flatten_polarizabilities(
_, _, scaled = _scale_and_flatten_polarizabilities(
self._polarizabilities, scale_mode=scale_mode
)
self._mean_polarizability = mean.type(default_type)
self._stddev_polarizability = stddev.type(default_type)
self._scaled_polarizabilities = scaled.type(default_type)

@property
def num_atoms(self) -> int:
"""Get number of atoms per sample."""
return self._positions.size(1)

@property
def num_samples(self) -> int:
"""Get number of samples."""
return self._positions.size(0)

@property
def atomic_numbers(self) -> Tensor:
"""Get (a copy of) atomic numbers.
Returns
-------
:
2D tensor with size [S,N] where S is the number of samples and N is the
number of atoms.
"""
return copy.copy(self._atomic_numbers)

@property
def positions(self) -> Tensor:
"""Get (a copy of) positions.
Returns
-------
:
3D tensor with size [S,N,3] where S is the number of samples and N is the
number of atoms.
"""
return self._positions.detach().clone()

@property
def polarizabilities(self) -> Tensor:
"""Get (a copy of) polarizabilities.
Returns
-------
:
3D tensor with size [S,3,3] where S is the number of samples.
"""
return self._polarizabilities.detach().clone()

@property
def scaled_polarizabilities(self) -> Tensor:
"""Get (a copy of) scaled polarizabilities.
Returns
-------
:
2D tensor with size [S,6] where S is the number of samples.
"""
return self._scaled_polarizabilities.detach().clone()

@property
def mean_polarizability(self) -> Tensor:
"""Get mean polarizability.
Return
------
:
2D tensor with size [3,3].
"""
return self._polarizabilities.mean(0, keepdim=True)

@property
def stddev_polarizability(self) -> Tensor:
"""Get standard deviation of polarizability."""
return self._polarizabilities.std(0, unbiased=False, keepdim=True)

def scale_polarizabilities(self, mean: Tensor, stddev: Tensor) -> None:
"""Standard-scale polarizabilities given a mean and standard deviation.
This method may be used to scale validation/test datasets according
This method may be used to scale validation or test datasets according
to the mean and standard deviation of the training set, as is best practice.
This method does **not** update ...
Parameters
----------
Expand Down Expand Up @@ -163,7 +234,7 @@ def scale_polarizabilities(self, mean: Tensor, stddev: Tensor) -> None:

def __len__(self) -> int:
"""Get number of samples."""
return len(self._positions)
return self.num_samples

def __getitem__(self, i: int) -> tuple[Tensor, Tensor, Tensor, Tensor]:
"""Get lattice, atomic numbers, positions, and scaled polarizabilities."""
Expand Down
Loading

0 comments on commit d118aee

Please sign in to comment.