From 6a6a4c24466f8fa4dbc177d14ead8c5c1e44bd1d Mon Sep 17 00:00:00 2001 From: wolearyc Date: Tue, 17 Sep 2024 13:12:55 -0700 Subject: [PATCH] added [torch] install option --- .github/workflows/python-package.yml | 1 + README.md | 16 +++++-- deps/dev_requirements.txt | 5 ++ deps/requirements.txt | 31 ++++--------- deps/torch_geometric_requirements.txt | 36 +++++++-------- deps/torch_requirements.txt | 6 +-- pyproject.toml | 46 +++++++++++-------- ramannoodle/io/generic.py | 8 +++- ramannoodle/io/io_utils.py | 14 ++++-- ramannoodle/io/vasp/outcar.py | 8 +++- ramannoodle/io/vasp/vasprun.py | 8 +++- ramannoodle/polarizability/torch/dataset.py | 2 + .../polarizability/torch/dummy_dataset.py | 44 ++++++++++++++++++ 13 files changed, 151 insertions(+), 74 deletions(-) create mode 100644 deps/dev_requirements.txt create mode 100644 ramannoodle/polarizability/torch/dummy_dataset.py diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index e40f079..c96b747 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -27,6 +27,7 @@ jobs: run: | python -m pip install --upgrade pip pip install uv + uv pip install ${{ matrix.uv-arg }} --system -r deps/dev_requirements.txt uv pip install ${{ matrix.uv-arg }} --system -r deps/requirements.txt uv pip install ${{ matrix.uv-arg }} --system -r deps/torch_geometric_requirements.txt uv pip install ${{ matrix.uv-arg }} --system -r deps/torch_requirements.txt diff --git a/README.md b/README.md index 0eaa427..df7aa1e 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ **Ramannoodle** is a Python API for efficiently calculating Raman spectra from first principles calculations. Ramannoodle supports molecular-dynamics- and phonon-based Raman calculations and includes interfaces with VASP. -Ramannoodle is designed from the ground up to be: +Ramannoodle aims to be: 1. **EFFICIENT** @@ -33,9 +33,19 @@ Ramannoodle includes interfaces with: Ramannoodle can be installed via pip: -` +``` $ pip install ramannoodle -` +``` + +Due to idiosyncrasies with PyTorch's build system, installing ramannoodle's machine learning modules is slightly more involved. First, PyTorch must be installed ([pip commands](https://pytorch.org/get-started/locally/)). Then, corresponding torch-scatter and torch-sparse packages must be installed. Finally, Ramannoodle can then be installed with the appropriate options. + +For example, installation on a Linux system using PyTorch 2.4.1 (cpu implementation) is done as follows: + +``` +$ pip install torch==2.4.1+cpu --index-url https://download.pytorch.org/whl/cpu +$ pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.4.0+cpu.html +$ pip install ramannoodle[torch] +``` ## Documentation diff --git a/deps/dev_requirements.txt b/deps/dev_requirements.txt new file mode 100644 index 0000000..bd96f79 --- /dev/null +++ b/deps/dev_requirements.txt @@ -0,0 +1,5 @@ +flake8 == 7.1.0 +pre-commit == 3.7.1 +pylint == 3.2.6 +pytest == 8.3.1 +setuptools == 74.1.2 diff --git a/deps/requirements.txt b/deps/requirements.txt index 0349fcf..e3d7145 100644 --- a/deps/requirements.txt +++ b/deps/requirements.txt @@ -1,23 +1,12 @@ # numpy, scipy recommendations: https://scientific-python.org/specs/spec-0000/ -defusedxml >= 0.6.0;python_version=='3.10' # minimum working -defusedxml >= 0.6.0;python_version=='3.11' # minimum working -defusedxml >= 0.6.0;python_version=='3.12' # minimum working -flake8 == 7.1.0 -numpy >= 1.24.0;python_version=='3.10' # minimum recommended -numpy >= 1.24.0;python_version=='3.11' # minimum recommended -numpy >= 1.26.0;python_version=='3.12' # minimum working -pre-commit == 3.7.1 -pylint == 3.2.6 -pytest == 8.3.1 -scipy >= 1.10.0;python_version=='3.10' # minimum recommended -scipy >= 1.10.0;python_version=='3.11' # minimum recommended -scipy >= 1.11.2;python_version=='3.12' # minimum working -setuptools == 74.1.2 -spglib >= 1.16.4;python_version=='3.10' # minimum working -spglib >= 1.16.4;python_version=='3.11' # minimum working -spglib >= 1.16.4;python_version=='3.12' # minimum working -tabulate >= 0.8.8;python_version=='3.10' # minimum working -tabulate >= 0.8.8;python_version=='3.11' # minimum working -tabulate >= 0.8.8;python_version=='3.12' # minimum working -tqdm >= 2.0 +defusedxml >= 0.6.0 # min working +numpy >= 1.24.0;python_version=='3.10' # min recommended +numpy >= 1.24.0;python_version=='3.11' # min recommended +numpy >= 1.26.0;python_version=='3.12' # min working +scipy >= 1.10.0;python_version=='3.10' # min recommended +scipy >= 1.10.0;python_version=='3.11' # min recommended +scipy >= 1.11.2;python_version=='3.12' # min working +spglib >= 1.16.4 # min working +tabulate >= 0.8.8 # min working +tqdm >= 2.0 # min working diff --git a/deps/torch_geometric_requirements.txt b/deps/torch_geometric_requirements.txt index 9339aac..44bd803 100644 --- a/deps/torch_geometric_requirements.txt +++ b/deps/torch_geometric_requirements.txt @@ -1,20 +1,16 @@ -aiohttp >= 3.8.0;python_version=='3.10' -aiohttp >= 3.8.3;python_version=='3.11' -aiohttp >= 3.9.0;python_version=='3.12' -dill >= 0.3.4 -frozenlist >= 1.2.0;python_version=='3.10' -frozenlist >= 1.3.3;python_version=='3.11' -frozenlist >= 1.4.1;python_version=='3.12' -fsspec>= 2021.4.0;python_version=='3.10' -fsspec>= 2021.4.0;python_version=='3.11' -fsspec>=2021.4.0;python_version=='3.12' -jinja2 >= 3.0.2 -pyparsing >= 3.0.0 -scikit-learn >= 1.2.0;python_version=='3.10' -scikit-learn >= 1.2.0;python_version=='3.11' -scikit-learn >= 1.3.0;python_version=='3.12' and sys_platform=="darwin" -scikit-learn >= 1.3.0;python_version=='3.12' and sys_platform=="linux" -scikit-learn >= 1.4.0;python_version=='3.12' and sys_platform=="win32" -torch_geometric >= 2.3.0;python_version=='3.10' -torch_geometric >= 2.3.0;python_version=='3.11' -torch_geometric >= 2.3.0;python_version=='3.12' +aiohttp >= 3.8.0;python_version=='3.10' # min working +aiohttp >= 3.8.3;python_version=='3.11' # min working +aiohttp >= 3.9.0;python_version=='3.12' # min working +dill >= 0.3.4 # min working +frozenlist >= 1.2.0;python_version=='3.10' # min working +frozenlist >= 1.3.3;python_version=='3.11' # min working +frozenlist >= 1.4.1;python_version=='3.12' # min working +fsspec>= 2021.4.0;python_version=='3.10' # min working +jinja2 >= 3.0.2 # min working +pyparsing >= 3.0.0 # min working +scikit-learn >= 1.2.0;python_version=='3.10' # min working +scikit-learn >= 1.2.0;python_version=='3.11' # min working +scikit-learn >= 1.3.0;python_version=='3.12' and sys_platform=='darwin' # min working +scikit-learn >= 1.3.0;python_version=='3.12' and sys_platform=='linux' # min working +scikit-learn >= 1.4.0;python_version=='3.12' and sys_platform=='win32' # min working +torch_geometric >= 2.3.0 # min working diff --git a/deps/torch_requirements.txt b/deps/torch_requirements.txt index 04e31d5..838b208 100644 --- a/deps/torch_requirements.txt +++ b/deps/torch_requirements.txt @@ -1,4 +1,4 @@ --index-url https://download.pytorch.org/whl/cpu -torch==2.4.1;sys_platform=="darwin" -torch==2.4.1+cpu;sys_platform=="linux" -torch==2.4.1+cpu;sys_platform=="win32" +torch==2.4.1;sys_platform=='darwin' +torch==2.4.1+cpu;sys_platform=='linux' +torch==2.4.1+cpu;sys_platform=='win32' diff --git a/pyproject.toml b/pyproject.toml index be88ad4..18edb34 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,27 +14,37 @@ requires-python = ">=3.10" keywords = ["raman", "spectrum", "vasp", "dft", "phonons", "molecular", "dynamics", "polarizability" ] license = {text = "MIT"} dependencies = [ - "numpy >= 1.24.0;python_version=='3.10'", # minimum recommended - "numpy >= 1.24.0;python_version=='3.11'", # minimum recommended - "numpy >= 1.26.0;python_version=='3.12'", # minimum working - "scipy >= 1.10.0;python_version=='3.10'", # minimum recommended - "scipy >= 1.10.0;python_version=='3.11'", # minimum recommended - "scipy >= 1.11.2;python_version=='3.12'", # minimum working - "spglib >= 1.16.4;python_version=='3.10'", # minimum working - "spglib >= 1.16.4;python_version=='3.11'", # minimum working - "spglib >= 1.16.4;python_version=='3.12'", # minimum working - "defusedxml >= 0.6.0;python_version=='3.10'", # minimum working - "defusedxml >= 0.6.0;python_version=='3.11'", # minimum working - "defusedxml >= 0.6.0;python_version=='3.12'", # minimum working - "tabulate >= 0.8.8;python_version=='3.10'", # minimum working - "tabulate >= 0.8.8;python_version=='3.11'", # minimum working - "tabulate >= 0.8.8;python_version=='3.12'", # minimum working - "torch >= 2.4.0", - "torch-geometric >= 2.5.3", - "torch-sparse >= 0.6.18", + "defusedxml >= 0.6.0", # min working + "numpy >= 1.24.0;python_version=='3.10'", # min recommended + "numpy >= 1.24.0;python_version=='3.11'", # min recommended + "numpy >= 1.26.0;python_version=='3.12'", # min working + "scipy >= 1.10.0;python_version=='3.10'", # min recommended + "scipy >= 1.10.0;python_version=='3.11'", # min recommended + "scipy >= 1.11.2;python_version=='3.12'", # min working + "spglib >= 1.16.4", # min working + "tabulate >= 0.8.8", # min working + "tqdm >= 2.0", # min working ] [project.optional-dependencies] +torch = [ + "aiohttp >= 3.8.0;python_version=='3.10'", # min working + "aiohttp >= 3.8.3;python_version=='3.11'", # min working + "aiohttp >= 3.9.0;python_version=='3.12'", # min working + "dill >= 0.3.4", # min working + "frozenlist >= 1.2.0;python_version=='3.10'", # min working + "frozenlist >= 1.3.3;python_version=='3.11'", # min working + "frozenlist >= 1.4.1;python_version=='3.12'", # min working + "fsspec>= 2021.4.0", # min working + "jinja2 >= 3.0.2", # min working + "pyparsing >= 3.0.0", # min working + "scikit-learn >= 1.2.0;python_version=='3.10'", # min working + "scikit-learn >= 1.2.0;python_version=='3.11'", # min working + "scikit-learn >= 1.3.0;python_version=='3.12' and sys_platform=='darwin'", # min working + "scikit-learn >= 1.3.0;python_version=='3.12' and sys_platform=='linux'", # min working + "scikit-learn >= 1.4.0;python_version=='3.12' and sys_platform=='win32'", # min working + "torch_geometric >= 2.3.0", # min working +] [project.urls] Documentation = "https://ramannoodle.readthedocs.io/en/latest/" diff --git a/ramannoodle/io/generic.py b/ramannoodle/io/generic.py index b3cfe3a..c9c0d3a 100644 --- a/ramannoodle/io/generic.py +++ b/ramannoodle/io/generic.py @@ -18,7 +18,11 @@ from ramannoodle.structure.reference import ReferenceStructure import ramannoodle.io.vasp as vasp_io -from ramannoodle.polarizability.torch.dataset import PolarizabilityDataset + +try: + from ramannoodle.polarizability.torch import dataset +except ModuleNotFoundError: + import ramannoodle.polarizability.torch.dummy_dataset as dataset # type: ignore # These map between file formats and appropriate IO functions. _PHONON_READERS = { @@ -189,7 +193,7 @@ def read_structure_and_polarizability( def read_polarizability_dataset( filepaths: str | Path | list[str] | list[Path], file_format: str, -) -> PolarizabilityDataset: +) -> dataset.PolarizabilityDataset: """Read polarizability dataset from files. Parameters diff --git a/ramannoodle/io/io_utils.py b/ramannoodle/io/io_utils.py index e99d58f..f2c4f72 100644 --- a/ramannoodle/io/io_utils.py +++ b/ramannoodle/io/io_utils.py @@ -15,7 +15,11 @@ IncompatibleStructureException, ) from ramannoodle.globals import ATOM_SYMBOLS -from ramannoodle.polarizability.torch.dataset import PolarizabilityDataset + +try: + from ramannoodle.polarizability.torch import dataset +except ModuleNotFoundError: + import ramannoodle.polarizability.torch.dummy_dataset as dataset # type: ignore def _skip_file_until_line_contains(file: TextIO, content: str) -> str: @@ -95,7 +99,7 @@ def _read_polarizability_dataset( [str | Path], tuple[NDArray[np.float64], list[int], NDArray[np.float64], NDArray[np.float64]], ], -) -> PolarizabilityDataset: +) -> dataset.PolarizabilityDataset: """Read polarizability dataset from OUTCAR files. Parameters @@ -114,7 +118,11 @@ def _read_polarizability_dataset( File has an unexpected format. IncompatibleFileException File is incompatible with the dataset. + ModuleNotFoundError + Torch installation could not be found. """ + if not dataset.TORCH_PRESENT: + raise ModuleNotFoundError("torch installation not found") filepaths = pathify_as_list(filepaths) lattices: list[NDArray[np.float64]] = [] @@ -143,7 +151,7 @@ def _read_polarizability_dataset( positions_list.append(positions) polarizabilities.append(polarizability) - return PolarizabilityDataset( + return dataset.PolarizabilityDataset( np.array(lattices), atomic_numbers_list, np.array(positions_list), diff --git a/ramannoodle/io/vasp/outcar.py b/ramannoodle/io/vasp/outcar.py index 0781e60..ed89e7a 100644 --- a/ramannoodle/io/vasp/outcar.py +++ b/ramannoodle/io/vasp/outcar.py @@ -16,7 +16,11 @@ from ramannoodle.dynamics.phonon import Phonons from ramannoodle.dynamics.trajectory import Trajectory from ramannoodle.structure.reference import ReferenceStructure -from ramannoodle.polarizability.torch.dataset import PolarizabilityDataset + +try: + from ramannoodle.polarizability.torch import dataset +except ModuleNotFoundError: + import ramannoodle.polarizability.torch.dummy_dataset as dataset # type: ignore # Utilities for OUTCAR. Warning: some of these functions partially read files. @@ -400,7 +404,7 @@ def read_structure_and_polarizability( def read_polarizability_dataset( filepaths: str | Path | list[str] | list[Path], -) -> PolarizabilityDataset: +) -> dataset.PolarizabilityDataset: """Read polarizability dataset from OUTCAR files. Parameters diff --git a/ramannoodle/io/vasp/vasprun.py b/ramannoodle/io/vasp/vasprun.py index 742f860..a028a14 100644 --- a/ramannoodle/io/vasp/vasprun.py +++ b/ramannoodle/io/vasp/vasprun.py @@ -14,7 +14,11 @@ from ramannoodle.dynamics.phonon import Phonons from ramannoodle.dynamics.trajectory import Trajectory from ramannoodle.structure.reference import ReferenceStructure -from ramannoodle.polarizability.torch.dataset import PolarizabilityDataset + +try: + from ramannoodle.polarizability.torch import dataset +except ModuleNotFoundError: + import ramannoodle.polarizability.torch.dummy_dataset as dataset # type: ignore def _get_root_element(file: TextIO) -> Element: @@ -195,7 +199,7 @@ def read_structure_and_polarizability( def read_polarizability_dataset( filepaths: str | Path | list[str] | list[Path], -) -> PolarizabilityDataset: +) -> dataset.PolarizabilityDataset: """Read polarizability dataset from OUTCAR files. Parameters diff --git a/ramannoodle/polarizability/torch/dataset.py b/ramannoodle/polarizability/torch/dataset.py index b46a868..64cf3ba 100644 --- a/ramannoodle/polarizability/torch/dataset.py +++ b/ramannoodle/polarizability/torch/dataset.py @@ -12,6 +12,8 @@ from ramannoodle.exceptions import verify_ndarray_shape, verify_list_len, get_type_error import ramannoodle.polarizability.torch.utils as rn_torch_utils +TORCH_PRESENT = True + def _scale_and_flatten_polarizabilities( polarizabilities: Tensor, diff --git a/ramannoodle/polarizability/torch/dummy_dataset.py b/ramannoodle/polarizability/torch/dummy_dataset.py new file mode 100644 index 0000000..f22b67f --- /dev/null +++ b/ramannoodle/polarizability/torch/dummy_dataset.py @@ -0,0 +1,44 @@ +"""Dummy polarizability PyTorch dataset. + +Used when torch installation cannot be found. + +:meta private: +""" + +import numpy as np +from numpy.typing import NDArray + +TORCH_PRESENT = False + + +class PolarizabilityDataset: # pylint: disable=too-few-public-methods + """PyTorch dataset of atomic structures and polarizabilities. + + Polarizabilities are scaled and flattened into vectors containing the six + independent tensor components. + + Parameters + ---------- + lattices + | (Å) 3D array with shape (S,3,3) where S is the number of samples. + atomic_numbers + | List of length S containing lists of length N, where N is the number of atoms. + positions + | (fractional) 3D array with shape (S,N,3). + polarizabilities + | 3D array with shape (S,3,3). + scale_mode + | Supports ``"standard"`` (standard scaling), ``"stddev"`` (division by + | standard deviation), and ``"none"`` (no scaling). + + """ + + def __init__( # pylint: disable=too-many-arguments + self, + lattices: NDArray[np.float64], + atomic_numbers: list[list[int]], + positions: NDArray[np.float64], + polarizabilities: NDArray[np.float64], + scale_mode: str = "standard", + ): + raise ModuleNotFoundError("torch installation not found")