Skip to content

Commit

Permalink
fea: add due.dcite to address #67
Browse files Browse the repository at this point in the history
  • Loading branch information
CompRhys committed Aug 26, 2024
1 parent ee35ac5 commit 9359bdc
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 0 deletions.
2 changes: 2 additions & 0 deletions aviary/cgcnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import torch
import torch.nn.functional as F
from pymatgen.util.due import Doi, due
from torch import LongTensor, Tensor, nn
from torch_scatter import scatter_add, scatter_mean

Expand All @@ -14,6 +15,7 @@
from collections.abc import Sequence


@due.dcite(Doi("10.1103/PhysRevLett.120.145301"), description="CGCNN model")
class CrystalGraphConvNet(BaseModelClass):
"""Create a crystal graph convolutional neural network for predicting total
material properties.
Expand Down
2 changes: 2 additions & 0 deletions aviary/roost/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import torch
import torch.nn.functional as F
from pymatgen.util.due import Doi, due
from torch import LongTensor, Tensor, nn

from aviary.core import BaseModelClass
Expand All @@ -14,6 +15,7 @@
from collections.abc import Sequence


@due.dcite(Doi("10.1038/s41467-020-19964-7"), description="Roost model")
class Roost(BaseModelClass):
"""The Roost model is comprised of a fully connected network
and message passing graph layers.
Expand Down
2 changes: 2 additions & 0 deletions aviary/wren/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import torch
import torch.nn.functional as F
from pymatgen.util.due import Doi, due
from torch import LongTensor, Tensor, nn
from torch_scatter import scatter_mean

Expand All @@ -15,6 +16,7 @@
from collections.abc import Sequence


@due.dcite(Doi("10.1126/sciadv.abn4117"), description="Wren model")
class Wren(BaseModelClass):
"""The Roost model is comprised of a fully connected network
and message passing graph layers.
Expand Down
3 changes: 3 additions & 0 deletions aviary/wrenformer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import torch
import torch.nn.functional as F
from pymatgen.util.due import Doi, due
from torch import BoolTensor, Tensor, nn

from aviary.core import BaseModelClass, masked_max, masked_mean, masked_min, masked_std
Expand All @@ -13,6 +14,8 @@
from collections.abc import Sequence


@due.dcite(Doi("10.48550/arXiv.2308.14920"), description="Wrenformer model")
@due.dcite(Doi("10.1038/s41524-021-00545-1"), description="Crabnet model")
class Wrenformer(BaseModelClass):
"""Crabnet-inspired re-implementation of Wren as a transformer.
https://github.com/anthony-wang/CrabNet.
Expand Down

0 comments on commit 9359bdc

Please sign in to comment.