Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
926 changes: 926 additions & 0 deletions docs/tutorials/element_embeddings_integration.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions smact/io/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""External packages."""
113 changes: 113 additions & 0 deletions smact/io/elementembeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
"""Interface to ElementEmbeddings. See https://github.com/WMD-group/ElementEmbeddings."""

from __future__ import annotations

from enum import Enum
from typing import TYPE_CHECKING

from elementembeddings.composition import composition_featuriser as ee_composition_featuriser
from elementembeddings.composition import species_composition_featuriser as ee_species_composition_featuriser

try:
    from enum import StrEnum
except ImportError:
    # enum.StrEnum only exists on Python >= 3.11; provide a stand-in otherwise.

    class StrEnum(str, Enum):
        """Backport of Python 3.11's StrEnum for Python 3.10.

        Members are ``str`` instances, and ``str(member)`` yields the
        member's value instead of ``ClassName.member``.
        """

        @staticmethod
        def _generate_next_value_(name: str, start: int, count: int, last_values: list) -> str:
            """Return the lower-cased member name as the auto-generated value.

            This mirrors ``enum.StrEnum`` so that ``auto()`` and functional-API
            name lists give the same values whether the standard-library class
            or this backport is in use. Without it the inherited integer
            counter would be stringified ("1", "2", ...).
            """
            return name.lower()

        def __str__(self):
            """Return the member's value as a plain string."""
            return str(self.value)


if TYPE_CHECKING:
import pandas as pd
from elementembeddings.composition import CompositionalEmbedding
from elementembeddings.core import Embedding


# TODO: move this enum upstream into the ElementEmbeddings codebase.
class AllowedElementEmbeddings(StrEnum):
    """Names of the element embeddings implemented in ElementEmbeddings.

    Each member's value is identical to its name, and members are ``str``
    instances, so they can be passed wherever the plain embedding name is
    expected. ``magpie`` is the default ``embedding`` of
    ``composition_featuriser`` in this module.
    """

    # Keep this list in sync with the embeddings shipped by ElementEmbeddings.
    magpie = "magpie"
    mat2vec = "mat2vec"
    skipatom = "skipatom"
    cgnf = "cgnf"
    xenonpy = "xenonpy"
    random = "random"
    oliynyk = "oliynyk"
    matscholar = "matscholar"
    crystallm = "crystallm"
    megnet = "megnet"


class AllowedSpeciesEmbeddings(StrEnum):
    """Names of the species (element plus oxidation state) embeddings allowed.

    ``skipspecies`` is the default ``embedding`` of
    ``species_composition_featuriser`` in this module.
    """

    # Currently the only species embedding exposed here.
    skipspecies = "skipspecies"


# TODO: move this enum upstream into the ElementEmbeddings codebase.
class PoolingStats(StrEnum):
    """Names of the pooling statistics accepted by the featuriser wrappers.

    Each member's value is identical to its name. ``mean`` is the default
    ``stats`` argument of both featuriser functions in this module.
    """

    mean = "mean"
    variance = "variance"
    minpool = "minpool"
    maxpool = "maxpool"
    # NOTE: ``range`` and ``sum`` reuse builtin names, but only inside this
    # class namespace; the module-level builtins are unaffected.
    range = "range"
    sum = "sum"
    geometric_mean = "geometric_mean"
    harmonic_mean = "harmonic_mean"


def composition_featuriser(
    composition_data: pd.DataFrame | pd.Series | CompositionalEmbedding | list,
    formula_column: str = "formula",
    embedding: Embedding | AllowedElementEmbeddings = AllowedElementEmbeddings.magpie,
    stats: PoolingStats | list[PoolingStats] = PoolingStats.mean,
    inplace: bool = False,
) -> pd.DataFrame:
    """Featurise compositions via ElementEmbeddings' `composition_featuriser`.

    This is a thin pass-through: every argument is forwarded unchanged to
    ElementEmbeddings, and its result is returned as-is. See the
    ElementEmbeddings documentation for the exact semantics of each option.

    Args:
    ----
        composition_data: The compositions to featurise; forwarded as ``data``.
        formula_column: Forwarded as ``formula_column``. The default is "formula".
        embedding: An Embedding object or an AllowedElementEmbeddings member.
            The default is AllowedElementEmbeddings.magpie.
        stats: A PoolingStats member or a list of them.
            The default is PoolingStats.mean.
        inplace: Forwarded as ``inplace``. The default is False.

    Returns:
    -------
        pd.DataFrame: Whatever ElementEmbeddings' `composition_featuriser` returns.
    """
    # Map this wrapper's parameter names onto the upstream keyword names.
    forwarded = {
        "data": composition_data,
        "formula_column": formula_column,
        "embedding": embedding,
        "stats": stats,
        "inplace": inplace,
    }
    return ee_composition_featuriser(**forwarded)


def species_composition_featuriser(
    composition_data: list[dict[str, float]],
    embedding: AllowedSpeciesEmbeddings | str = AllowedSpeciesEmbeddings.skipspecies,
    stats: PoolingStats | list[PoolingStats] = PoolingStats.mean,
    to_dataframe: bool = False,
) -> list | pd.DataFrame:
    """Compute feature vectors for species compositions.

    The feature vector is built from the statistics named in the stats
    argument. All arguments are forwarded unchanged to ElementEmbeddings'
    `species_composition_featuriser`, and its result is returned as-is.

    Args:
    ----
        composition_data (list[dict[str, float]]): A list of composition
            dictionaries.
        embedding (Union[AllowedSpeciesEmbeddings, str], optional): An
            AllowedSpeciesEmbeddings member or a string.
            The default is AllowedSpeciesEmbeddings.skipspecies.
        stats (Union[PoolingStats, list], optional): The statistic, or list of
            statistics, to be computed. The default is PoolingStats.mean.
        to_dataframe (bool, optional): Whether to return the feature vectors
            as a DataFrame. The default is False.

    Returns:
    -------
        Union[pd.DataFrame, list]: A pandas DataFrame containing the feature
        vectors, or a list of feature vectors.
    """
    # Map this wrapper's parameter names onto the upstream keyword names.
    forwarded = {
        "data": composition_data,
        "embedding": embedding,
        "stats": stats,
        "to_dataframe": to_dataframe,
    }
    return ee_species_composition_featuriser(**forwarded)
44 changes: 41 additions & 3 deletions smact/screening.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import itertools
import os
import warnings
from enum import Enum
from itertools import combinations
from typing import TYPE_CHECKING

Expand All @@ -17,11 +18,31 @@
lookup_element_oxidation_states_custom as oxi_custom,
)
from smact.metallicity import metallicity_score
from smact.utils.composition import composition_dict_maker, formula_maker

try:
    from enum import StrEnum
except ImportError:
    # enum.StrEnum only exists on Python >= 3.11; provide a stand-in otherwise.

    class StrEnum(str, Enum):
        """Backport of Python 3.11's StrEnum for Python 3.10.

        Members are ``str`` instances, and ``str(member)`` yields the
        member's value instead of ``ClassName.member``.
        """

        @staticmethod
        def _generate_next_value_(name: str, start: int, count: int, last_values: list) -> str:
            """Return the lower-cased member name as the auto-generated value.

            This mirrors ``enum.StrEnum`` so that ``auto()`` and functional-API
            name lists give the same values whether the standard-library class
            or this backport is in use. Without it the inherited integer
            counter would be stringified ("1", "2", ...).
            """
            return name.lower()

        def __str__(self):
            """Return the member's value as a plain string."""
            return str(self.value)


if TYPE_CHECKING:
import pymatgen


class SmactFilterOutputs(StrEnum):
    """Allowed outputs of the `smact_filter` function.

    Selects the shape of the list returned by `smact_filter` through its
    ``return_output`` argument.
    """

    # List of tuples of symbols, oxidation states and stoichiometry values.
    default = "default"
    # List of formulas, one per composition.
    formula = "formula"
    # List of composition dictionaries, one per composition.
    # NOTE: the name reuses the builtin ``dict`` only inside this class namespace.
    dict = "dict"


def pauling_test(
oxidation_states: list[int],
electronegativities: list[float],
Expand Down Expand Up @@ -324,7 +345,8 @@ def smact_filter(
stoichs: list[list[int]] | None = None,
species_unique: bool = True,
oxidation_states_set: str = "icsd24",
) -> list[tuple[str, int, int]] | list[tuple[str, int]]:
return_output: SmactFilterOutputs = SmactFilterOutputs.default,
) -> list[tuple[str, int, int]] | list[tuple[str, int]] | list[str] | list[dict]:
"""Function that applies the charge neutrality and electronegativity
tests in one go for simple application in external scripts that
wish to apply the general 'smact test'.
Expand All @@ -340,6 +362,7 @@ def smact_filter(
stoichs (list[int]): A selection of valid stoichiometric ratios for each site.
species_unique (bool): Whether or not to consider elements in different oxidation states as unique in the results.
oxidation_states_set (string): A string to choose which set of oxidation states should be chosen. Options are 'smact14', 'icsd16',"icsd24", 'pymatgen_sp' and 'wiki' for the 2014 SMACT default, 2016 ICSD, 2024 ICSD, pymatgen structure predictor and Wikipedia (https://en.wikipedia.org/wiki/Template:List_of_oxidation_states_of_the_elements) oxidation states respectively. A filepath to an oxidation states text file can also be supplied as well.
return_output (SmactFilterOutputs): If set to 'default', the function will return a list of tuples containing the tuples of symbols, oxidation states and stoichiometry values. "Formula" returns a list of formulas and "dict" returns a list of dictionaries.

Returns:
-------
Expand Down Expand Up @@ -411,10 +434,25 @@ def smact_filter(
# Return list depending on whether we are interested in unique species combinations
# or just unique element combinations.
if species_unique:
return compositions
match return_output:
case SmactFilterOutputs.default:
return compositions
case SmactFilterOutputs.formula:
return [formula_maker(smact_filter_output=comp) for comp in compositions]
case SmactFilterOutputs.dict:
return [composition_dict_maker(smact_filter_output=comp) for comp in compositions]

else:
compositions = [(i[0], i[2]) for i in compositions]
return list(set(compositions))

compositions = list(set(compositions))
match return_output:
case SmactFilterOutputs.default:
return compositions
case SmactFilterOutputs.formula:
return [formula_maker(smact_filter_output=comp) for comp in compositions]
case SmactFilterOutputs.dict:
return [composition_dict_maker(smact_filter_output=comp) for comp in compositions]


# ---------------------------------------------------------------------
Expand Down
12 changes: 12 additions & 0 deletions smact/utils/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,15 @@ def formula_maker(smact_filter_output: tuple[str, int, int] | tuple[str, int]) -

"""
return comp_maker(smact_filter_output).reduced_formula


def composition_dict_maker(smact_filter_output: tuple[str, int, int] | tuple[str, int]) -> dict:
    """Turn one item of the smact.screening.smact_filter output into a composition dictionary.

    Args:
        smact_filter_output (tuple[str, int, int] | tuple[str, int]): An item in the list
            returned from smact_filter

    Returns:
        composition_dict (dict[str, float]): A composition dictionary
    """
    # Build the composition object first, then export it in dictionary form.
    composition = comp_maker(smact_filter_output)
    return composition.as_dict()
Loading
Loading