Skip to content

Commit

Permalink
Improvements to the Wyckoff manipulation functions (#83)
Browse files Browse the repository at this point in the history
* fix: remove the random 1::2 indexing of alternating lists to be clearer

* fix: cache wasn't available in earlier python versions

* fea: add function to get a random crystal from the protostructure string. add flaky tests but mark xfail until pyxtal can be made deterministic.

* fix: add pyxtal directly to test reqs

* test: check that the composition is the same but the structure is different

* fix: can UV resolve pyxtal?

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: --system flag for uv

* lint: ruff fixes

* fix: bump python to 3.9

* ffix: try numpy < 2
  • Loading branch information
CompRhys authored Jul 12, 2024
1 parent 2a7b51b commit ef7778c
Show file tree
Hide file tree
Showing 20 changed files with 210 additions and 53 deletions.
11 changes: 7 additions & 4 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,18 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: 3.8
python-version: 3.9
cache: pip
cache-dependency-path: pyproject.toml

- name: Install uv
run: pip install uv

- name: Install dependencies
run: |
pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu
pip install torch-scatter -f https://data.pyg.org/whl/torch-2.1.0+cpu.html
pip install .[test]
pip install torch==2.2.1 --index-url https://download.pytorch.org/whl/cpu
pip install torch-scatter -f https://data.pyg.org/whl/torch-2.2.1+cpu.html
uv pip install .[test] --system
- name: Run Tests
run: pytest --capture=no --cov .
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
[![GitHub last commit](https://img.shields.io/github/last-commit/comprhys/aviary?label=Last+Commit)](https://github.com/comprhys/aviary/commits)
[![Tests](https://github.com/CompRhys/aviary/actions/workflows/test.yml/badge.svg)](https://github.com/CompRhys/aviary/actions/workflows/test.yml)
[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/CompRhys/aviary/main.svg)](https://results.pre-commit.ci/latest/github/CompRhys/aviary/main)
[![This project supports Python 3.8+](https://img.shields.io/badge/Python-3.8+-blue.svg?logo=python&logoColor=white)](https://python.org/downloads)
[![This project supports Python 3.9+](https://img.shields.io/badge/Python-3.9+-blue.svg?logo=python&logoColor=white)](https://python.org/downloads)

</h4>

Expand All @@ -18,10 +18,10 @@ The aim of `aviary` is to contain multiple models for materials discovery under
Aviary requires [`torch-scatter`](https://github.com/rusty1s/pytorch_scatter). `pip install` it with

```sh
pip install torch-scatter -f https://data.pyg.org/whl/torch-2.1.0+cpu.html
pip install torch-scatter -f https://data.pyg.org/whl/torch-2.2.1+cpu.html
```

Make sure you replace `2.1.0` with your actual `torch.__version__` (`python -c 'import torch; print(torch.__version__)'`) and `cpu` with your CUDA version if applicable.
Make sure you replace `2.2.1` with your actual `torch.__version__` (`python -c 'import torch; print(torch.__version__)'`) and `cpu` with your CUDA version if applicable.

Then install `aviary` from source with

Expand Down
6 changes: 4 additions & 2 deletions aviary/cgcnn/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import functools
import itertools
import json
from typing import TYPE_CHECKING, Any, Sequence
from typing import TYPE_CHECKING, Any

import numpy as np
import torch
Expand All @@ -14,6 +14,8 @@
from aviary import PKG_DIR

if TYPE_CHECKING:
from collections.abc import Sequence

import pandas as pd
from pymatgen.core import Structure

Expand Down Expand Up @@ -123,7 +125,7 @@ def __repr__(self) -> str:
return f"{type(self).__name__}({df_repr}, task_dict={self.task_dict})"

# Cache loaded structures
@functools.lru_cache(maxsize=None) # noqa: B019
@functools.cache # noqa: B019
def __getitem__(self, idx: int):
"""Get an entry out of the Dataset.
Expand Down
5 changes: 4 additions & 1 deletion aviary/cgcnn/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Sequence
from typing import TYPE_CHECKING

import torch
import torch.nn.functional as F
Expand All @@ -10,6 +10,9 @@
from aviary.core import BaseModelClass
from aviary.networks import SimpleNetwork

if TYPE_CHECKING:
from collections.abc import Sequence


class CrystalGraphConvNet(BaseModelClass):
"""Create a crystal graph convolutional neural network for predicting total
Expand Down
4 changes: 3 additions & 1 deletion aviary/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import shutil
from abc import ABC
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Callable, Literal, Mapping
from typing import TYPE_CHECKING, Any, Callable, Literal

import numpy as np
import torch
Expand All @@ -19,6 +19,8 @@
from aviary import ROOT

if TYPE_CHECKING:
from collections.abc import Mapping

from torch.utils.data import DataLoader

from aviary.data import InMemoryDataLoader
Expand Down
4 changes: 3 additions & 1 deletion aviary/data.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Iterator
from typing import TYPE_CHECKING, Callable

import numpy as np

if TYPE_CHECKING:
from collections.abc import Iterator

from torch import Tensor


Expand Down
5 changes: 4 additions & 1 deletion aviary/networks.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from __future__ import annotations

from typing import Sequence
from typing import TYPE_CHECKING

from torch import Tensor, nn

if TYPE_CHECKING:
from collections.abc import Sequence


class SimpleNetwork(nn.Module):
"""Simple Feed Forward Neural Network."""
Expand Down
6 changes: 4 additions & 2 deletions aviary/roost/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import functools
import json
from typing import TYPE_CHECKING, Any, Sequence
from typing import TYPE_CHECKING, Any

import numpy as np
import torch
Expand All @@ -13,6 +13,8 @@
from aviary import PKG_DIR

if TYPE_CHECKING:
from collections.abc import Sequence

import pandas as pd


Expand Down Expand Up @@ -74,7 +76,7 @@ def __repr__(self) -> str:
return f"{type(self).__name__}({df_repr}, task_dict={self.task_dict})"

# Cache data for faster training
@functools.lru_cache(maxsize=None) # noqa: B019
@functools.cache # noqa: B019
def __getitem__(self, idx: int):
"""Get an entry out of the Dataset.
Expand Down
5 changes: 4 additions & 1 deletion aviary/roost/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Sequence
from typing import TYPE_CHECKING

import torch
import torch.nn.functional as F
Expand All @@ -10,6 +10,9 @@
from aviary.networks import ResidualNetwork, SimpleNetwork
from aviary.segments import MessageLayer, WeightedAttentionPooling

if TYPE_CHECKING:
from collections.abc import Sequence


class Roost(BaseModelClass):
"""The Roost model is comprised of a fully connected network
Expand Down
5 changes: 4 additions & 1 deletion aviary/segments.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from __future__ import annotations

from typing import Sequence
from typing import TYPE_CHECKING

import torch
from torch import LongTensor, Tensor, nn
from torch_scatter import scatter_add, scatter_max

from aviary.networks import SimpleNetwork

if TYPE_CHECKING:
from collections.abc import Sequence


class AttentionPooling(nn.Module):
"""Softmax attention layer. Currently unused."""
Expand Down
3 changes: 2 additions & 1 deletion aviary/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from contextlib import contextmanager
from datetime import datetime
from pickle import PickleError
from typing import TYPE_CHECKING, Any, Callable, Generator, Iterable, Literal
from typing import TYPE_CHECKING, Any, Callable, Literal

import numpy as np
import pandas as pd
Expand All @@ -32,6 +32,7 @@
from aviary.losses import robust_l1_loss, robust_l2_loss

if TYPE_CHECKING:
from collections.abc import Generator, Iterable
from types import ModuleType


Expand Down
6 changes: 4 additions & 2 deletions aviary/wren/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import json
import re
from itertools import groupby
from typing import TYPE_CHECKING, Any, Sequence
from typing import TYPE_CHECKING, Any

import numpy as np
import torch
Expand All @@ -15,6 +15,8 @@
from aviary.wren.utils import relab_dict, wyckoff_multiplicity_dict

if TYPE_CHECKING:
from collections.abc import Sequence

import pandas as pd


Expand Down Expand Up @@ -88,7 +90,7 @@ def __repr__(self) -> str:
df_repr = f"cols=[{', '.join(self.df.columns)}], len={len(self.df)}"
return f"{type(self).__name__}({df_repr}, task_dict={self.task_dict})"

@functools.lru_cache(maxsize=None) # noqa: B019
@functools.cache # noqa: B019
def __getitem__(self, idx: int):
"""Get an entry out of the Dataset.
Expand Down
5 changes: 4 additions & 1 deletion aviary/wren/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Sequence
from typing import TYPE_CHECKING

import torch
import torch.nn.functional as F
Expand All @@ -11,6 +11,9 @@
from aviary.networks import ResidualNetwork, SimpleNetwork
from aviary.segments import MessageLayer, WeightedAttentionPooling

if TYPE_CHECKING:
from collections.abc import Sequence


class Wren(BaseModelClass):
"""The Roost model is comprised of a fully connected network
Expand Down
Loading

0 comments on commit ef7778c

Please sign in to comment.