Skip to content

Commit

Permalink
moved PolarizabilityDataset to its own package
Browse files Browse the repository at this point in the history
this should avoid circular imports
  • Loading branch information
wolearyc committed Sep 26, 2024
1 parent 7aa554b commit baa8795
Show file tree
Hide file tree
Showing 13 changed files with 129 additions and 111 deletions.
1 change: 1 addition & 0 deletions ramannoodle/dataset/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Datasets."""
1 change: 1 addition & 0 deletions ramannoodle/dataset/torch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Torch datasets."""
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import torch
from torch import Tensor
from torch.utils.data import Dataset
import ramannoodle.pmodel.torch.utils as rn_torch_utils
import ramannoodle.dataset.torch.utils as rn_torch_utils
except ModuleNotFoundError as exc:
raise get_torch_missing_error() from exc

Expand Down Expand Up @@ -42,10 +42,6 @@ def _scale_and_flatten_polarizabilities(
#. polarizability vectors -- Tensor with size [S,6].
"""
rn_torch_utils.verify_tensor_size(
"polarizabilities", polarizabilities, [None, 3, 3]
)

mean = polarizabilities.mean(0, keepdim=True)
stddev = polarizabilities.std(0, unbiased=False, keepdim=True)
if scale_mode == "standard":
Expand Down
110 changes: 110 additions & 0 deletions ramannoodle/dataset/torch/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
"""Some torch utilities."""

from typing import Sequence

from ramannoodle.exceptions import (
get_type_error,
get_torch_missing_error,
)

try:
import torch
from torch import Tensor
except ModuleNotFoundError as exc:
raise get_torch_missing_error() from exc

# pylint complains about torch.norm
# pylint: disable=not-callable


def polarizability_vectors_to_tensors(polarizability_vectors: Tensor) -> Tensor:
    """Expand 6-component polarizability vectors into symmetric 3x3 tensors.

    Parameters
    ----------
    polarizability_vectors
        Tensor with size [S,6].

    Returns
    -------
    :
        Tensor with size [S,3,3].

    """
    verify_tensor_size("polarizability_vectors", polarizability_vectors, (None, 6))
    # Entry (i, j) of the map names the vector component copied to tensor
    # position (i, j). The map is symmetric, so every output tensor is too.
    first_row = [0, 3, 4]
    second_row = [3, 1, 5]
    third_row = [4, 5, 2]
    component_map = torch.tensor([first_row, second_row, third_row])
    return polarizability_vectors[:, component_map]


def polarizability_tensors_to_vectors(polarizability_tensors: Tensor) -> Tensor:
    """Flatten 3x3 polarizability tensors into 6-component vectors.

    Parameters
    ----------
    polarizability_tensors
        Tensor with size [S,3,3] where S is the number of samples.

    Returns
    -------
    :
        Tensor with size [S,6].

    """
    verify_tensor_size("polarizability_tensors", polarizability_tensors, (None, 3, 3))
    # Components are read in the order (0,0), (1,1), (2,2), (0,1), (0,2), (1,2):
    # the three diagonal entries followed by the upper-triangle entries.
    row_indices = torch.tensor([0, 1, 2, 0, 0, 1])
    col_indices = torch.tensor([0, 1, 2, 1, 2, 2])
    return polarizability_tensors[:, row_indices, col_indices]


def _get_tensor_size_str(size: Sequence[int | None]) -> str:
"""Get a string representing a tensor size.
"_" indicates a dimension can be any size.
Parameters
----------
size
None indicates dimension can be any size.
"""
result = "["
for i in size:
if i is None:
result += "_,"
else:
result += f"{i},"
if len(size) == 1:
return result + "]"
return result[:-1] + "]"


def get_tensor_size_error(name: str, tensor: Tensor, desired_size: str) -> ValueError:
    """Get ValueError indicating a PyTorch Tensor has the wrong size.

    Parameters
    ----------
    name
        Name of the offending argument, used in the error message.
    tensor
        Object whose ``size()`` is reported as the actual size.
    desired_size
        Already-formatted string describing the expected size.

    Returns
    -------
    :
        ValueError reading "<name> has wrong size: <actual> != <desired>".
        The error is returned, not raised.

    Raises
    ------
    Exception
        The error built by ``get_type_error`` if ``tensor`` has no ``size``
        attribute.

    """
    try:
        actual_size = _get_tensor_size_str(tensor.size())
    except AttributeError as exc:
        # Report the caller-supplied name (as verify_tensor_size does) rather
        # than the generic literal "tensor", so the message names the argument.
        raise get_type_error(name, tensor, "Tensor") from exc
    return ValueError(f"{name} has wrong size: {actual_size} != {desired_size}")


def verify_tensor_size(name: str, tensor: Tensor, size: Sequence[int | None]) -> None:
    """Check that a PyTorch Tensor has the expected size.

    :meta private: We should avoid calling this function whenever possible (EATF).

    Parameters
    ----------
    name
        Name of the argument being checked, used in error messages.
    tensor
        Object to check.
    size
        int elements will be checked, None elements will not be.

    Raises
    ------
    ValueError
        If the number of dimensions differs from ``len(size)`` or any checked
        dimension does not match.
    Exception
        The error built by ``get_type_error`` if ``tensor`` lacks the
        attributes read here.

    """
    try:
        # Compare the rank first; per-dimension comparison only runs when the
        # ranks agree and stops at the first mismatching dimension.
        size_is_wrong = len(size) != tensor.ndim or any(
            expected is not None and actual != expected
            for actual, expected in zip(tensor.size(), size, strict=True)
        )
    except AttributeError as exc:
        raise get_type_error(name, tensor, "Tensor") from exc
    if size_is_wrong:
        raise get_tensor_size_error(name, tensor, _get_tensor_size_str(size))
2 changes: 1 addition & 1 deletion ramannoodle/io/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

TORCH_PRESENT = True
try:
from ramannoodle.pmodel.torch.dataset import PolarizabilityDataset
from ramannoodle.dataset.torch.dataset import PolarizabilityDataset
except UserError:
TORCH_PRESENT = False

Expand Down
2 changes: 1 addition & 1 deletion ramannoodle/io/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

TORCH_PRESENT = True
try:
from ramannoodle.pmodel.torch.dataset import PolarizabilityDataset
from ramannoodle.dataset.torch.dataset import PolarizabilityDataset
except UserError:
TORCH_PRESENT = False

Expand Down
2 changes: 1 addition & 1 deletion ramannoodle/io/vasp/outcar.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from ramannoodle.structure.reference import ReferenceStructure

try:
from ramannoodle.pmodel.torch.dataset import PolarizabilityDataset
from ramannoodle.dataset.torch.dataset import PolarizabilityDataset
except UserError:
pass

Expand Down
2 changes: 1 addition & 1 deletion ramannoodle/io/vasp/vasprun.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ramannoodle.structure.reference import ReferenceStructure

try:
from ramannoodle.pmodel.torch.dataset import PolarizabilityDataset
from ramannoodle.dataset.torch.dataset import PolarizabilityDataset
except UserError:
pass

Expand Down
11 changes: 7 additions & 4 deletions ramannoodle/pmodel/torch/gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@
from torch_geometric.nn.models.schnet import ShiftedSoftplus
from torch_geometric.utils import scatter
import ramannoodle.pmodel.torch.utils as rn_torch_utils
from ramannoodle.dataset.torch.utils import (
polarizability_vectors_to_tensors,
polarizability_tensors_to_vectors,
)

except (ModuleNotFoundError, UserError) as exc:
raise get_torch_missing_error() from exc

Expand Down Expand Up @@ -405,7 +410,7 @@ def _get_edge_polarizability_vectors(
+ a5 * torch.tensor([[0, 0, 0], [0, 1, 0], [0, 0, 0]])
+ a6 * torch.tensor([[0, 0, 0], [0, 0, 0], [0, 0, 1]])
)
return rn_torch_utils.polarizability_tensors_to_vectors(edge_polarizability)
return polarizability_tensors_to_vectors(edge_polarizability)


class PotGNN(
Expand Down Expand Up @@ -701,9 +706,7 @@ def calc_polarizabilities(
atomic_numbers,
torch.tensor(positions_subbatch).type(default_type),
)
polarizability = rn_torch_utils.polarizability_vectors_to_tensors(
polarizability
)
polarizability = polarizability_vectors_to_tensors(polarizability)
end_index = start_index + subbatch_size
polarizabilities[start_index:end_index] = polarizability.detach()
progress_bar.update(end_index - start_index)
Expand Down
2 changes: 1 addition & 1 deletion ramannoodle/pmodel/torch/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from torch.optim.optimizer import Optimizer

from ramannoodle.pmodel.torch.gnn import PotGNN
from ramannoodle.pmodel.torch.dataset import PolarizabilityDataset
from ramannoodle.dataset.torch.dataset import PolarizabilityDataset
except (ModuleNotFoundError, UserError) as exc:
raise get_torch_missing_error() from exc

Expand Down
96 changes: 1 addition & 95 deletions ramannoodle/pmodel/torch/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
"""Utility functions for PyTorch models."""

from typing import Sequence, Generator
from typing import Generator

import numpy as np
from numpy.typing import NDArray

from ramannoodle.exceptions import (
get_type_error,
get_torch_missing_error,
verify_ndarray_shape,
)
Expand All @@ -22,99 +21,6 @@
# pylint: disable=not-callable


def polarizability_vectors_to_tensors(polarizability_vectors: Tensor) -> Tensor:
    """Expand 6-component polarizability vectors into symmetric 3x3 tensors.

    Parameters
    ----------
    polarizability_vectors
        Tensor with size [S,6].

    Returns
    -------
    :
        Tensor with size [S,3,3].

    """
    verify_tensor_size("polarizability_vectors", polarizability_vectors, (None, 6))
    # Entry (i, j) of the map names the vector component copied to tensor
    # position (i, j). The map is symmetric, so every output tensor is too.
    first_row = [0, 3, 4]
    second_row = [3, 1, 5]
    third_row = [4, 5, 2]
    component_map = torch.tensor([first_row, second_row, third_row])
    return polarizability_vectors[:, component_map]


def polarizability_tensors_to_vectors(polarizability_tensors: Tensor) -> Tensor:
    """Flatten 3x3 polarizability tensors into 6-component vectors.

    Parameters
    ----------
    polarizability_tensors
        Tensor with size [S,3,3] where S is the number of samples.

    Returns
    -------
    :
        Tensor with size [S,6].

    """
    verify_tensor_size("polarizability_tensors", polarizability_tensors, (None, 3, 3))
    # Components are read in the order (0,0), (1,1), (2,2), (0,1), (0,2), (1,2):
    # the three diagonal entries followed by the upper-triangle entries.
    row_indices = torch.tensor([0, 1, 2, 0, 0, 1])
    col_indices = torch.tensor([0, 1, 2, 1, 2, 2])
    return polarizability_tensors[:, row_indices, col_indices]


def _get_tensor_size_str(size: Sequence[int | None]) -> str:
"""Get a string representing a tensor size.
"_" indicates a dimension can be any size.
Parameters
----------
size
None indicates dimension can be any size.
"""
result = "["
for i in size:
if i is None:
result += "_,"
else:
result += f"{i},"
if len(size) == 1:
return result + "]"
return result[:-1] + "]"


def get_tensor_size_error(name: str, tensor: Tensor, desired_size: str) -> ValueError:
    """Get ValueError indicating a PyTorch Tensor has the wrong size.

    Parameters
    ----------
    name
        Name of the offending argument, used in the error message.
    tensor
        Object whose ``size()`` is reported as the actual size.
    desired_size
        Already-formatted string describing the expected size.

    Returns
    -------
    :
        ValueError reading "<name> has wrong size: <actual> != <desired>".
        The error is returned, not raised.

    Raises
    ------
    Exception
        The error built by ``get_type_error`` if ``tensor`` has no ``size``
        attribute.

    """
    try:
        actual_size = _get_tensor_size_str(tensor.size())
    except AttributeError as exc:
        # Report the caller-supplied name (as verify_tensor_size does) rather
        # than the generic literal "tensor", so the message names the argument.
        raise get_type_error(name, tensor, "Tensor") from exc
    return ValueError(f"{name} has wrong size: {actual_size} != {desired_size}")


def verify_tensor_size(name: str, tensor: Tensor, size: Sequence[int | None]) -> None:
    """Check that a PyTorch Tensor has the expected size.

    :meta private: We should avoid calling this function whenever possible (EATF).

    Parameters
    ----------
    name
        Name of the argument being checked, used in error messages.
    tensor
        Object to check.
    size
        int elements will be checked, None elements will not be.

    Raises
    ------
    ValueError
        If the number of dimensions differs from ``len(size)`` or any checked
        dimension does not match.
    Exception
        The error built by ``get_type_error`` if ``tensor`` lacks the
        attributes read here.

    """
    try:
        # Compare the rank first; per-dimension comparison only runs when the
        # ranks agree and stops at the first mismatching dimension.
        size_is_wrong = len(size) != tensor.ndim or any(
            expected is not None and actual != expected
            for actual, expected in zip(tensor.size(), size, strict=True)
        )
    except AttributeError as exc:
        raise get_type_error(name, tensor, "Tensor") from exc
    if size_is_wrong:
        raise get_tensor_size_error(name, tensor, _get_tensor_size_str(size))


def get_rotations(targets: Tensor) -> Tensor:
"""Get rotation matrices from (1,0,0) to target vectors.
Expand Down
2 changes: 1 addition & 1 deletion test/tests/torch/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest

import ramannoodle.io.generic as generic_io
from ramannoodle.pmodel.torch.dataset import PolarizabilityDataset
from ramannoodle.dataset.torch.dataset import PolarizabilityDataset


@pytest.mark.parametrize(
Expand Down
3 changes: 2 additions & 1 deletion test/tests/torch/test_gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
from ramannoodle.pmodel.torch.utils import (
_radius_graph_pbc,
get_rotations,
polarizability_vectors_to_tensors,
)
from ramannoodle.dataset.torch.utils import polarizability_vectors_to_tensors


# import ramannoodle.io.vasp as vasp_io
# from ramannoodle.structure.structure_utils import apply_pbc
Expand Down

0 comments on commit baa8795

Please sign in to comment.