Skip to content

Commit 49fc74a

Browse files
author
Jason Munro
authored
Update thermo rester methods (#714)
* Update thermo rester methods * Flake 8 and docstring fix * More linting * Update thermo test * Fix thermo test * Temp xas test skip
1 parent fecf7fd commit 49fc74a

File tree

3 files changed

+43
-50
lines changed

3 files changed

+43
-50
lines changed

mp_api/client/routes/thermo.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
import warnings
22
from collections import defaultdict
3-
from typing import List, Optional, Tuple, Union
4-
5-
from emmet.core.thermo import ThermoDoc
6-
from pymatgen.analysis.phase_diagram import PhaseDiagram
7-
3+
from typing import Optional, List, Tuple, Union
84
from mp_api.client.core import BaseRester
95
from mp_api.client.core.utils import validate_ids
6+
from emmet.core.thermo import ThermoDoc, ThermoType
7+
from pymatgen.analysis.phase_diagram import PhaseDiagram
108

119

1210
class ThermoRester(BaseRester[ThermoDoc]):
@@ -40,6 +38,8 @@ def search(
4038
is_stable: Optional[bool] = None,
4139
material_ids: Optional[List[str]] = None,
4240
num_elements: Optional[Tuple[int, int]] = None,
41+
thermo_ids: Optional[List[str]] = None,
42+
thermo_types: Optional[List[ThermoType]] = None,
4343
total_energy: Optional[Tuple[float, float]] = None,
4444
uncorrected_energy: Optional[Tuple[float, float]] = None,
4545
sort_fields: Optional[List[str]] = None,
@@ -63,6 +63,9 @@ def search(
6363
(e.g., [Fe2O3, ABO3]).
6464
is_stable (bool): Whether the material is stable.
6565
material_ids (List[str]): List of Materials Project IDs to return data for.
66+
thermo_ids (List[str]): List of thermo IDs to return data for. This is a combination of the Materials
67+
Project ID and thermo type (e.g. mp-149_GGA_GGA+U).
68+
thermo_types (List[ThermoType]): List of thermo types to return data for (e.g. ThermoType.GGA_GGA_U).
6669
num_elements (Tuple[int,int]): Minimum and maximum number of elements in the material to consider.
6770
total_energy (Tuple[float,float]): Minimum and maximum corrected total energy in eV/atom to consider.
6871
uncorrected_energy (Tuple[float,float]): Minimum and maximum uncorrected total
@@ -95,6 +98,14 @@ def search(
9598
if material_ids:
9699
query_params.update({"material_ids": ",".join(validate_ids(material_ids))})
97100

101+
if thermo_ids:
102+
query_params.update({"thermo_ids": ",".join(validate_ids(thermo_ids))})
103+
104+
if thermo_types:
105+
query_params.update(
106+
{"thermo_types": ",".join([t.value for t in thermo_types])}
107+
)
108+
98109
if num_elements:
99110
if isinstance(num_elements, int):
100111
num_elements = (num_elements, num_elements)
@@ -141,19 +152,23 @@ def search(
141152
**query_params,
142153
)
143154

144-
def get_phase_diagram_from_chemsys(self, chemsys: str) -> PhaseDiagram:
155+
def get_phase_diagram_from_chemsys(
156+
self, chemsys: str, thermo_type: ThermoType = ThermoType.GGA_GGA_U
157+
) -> PhaseDiagram:
145158
"""
146159
Get a pre-computed phase diagram for a given chemsys.
147160
148161
Arguments:
149-
material_id (str): Materials project ID
162+
chemsys (str): A chemical system (e.g. Li-Fe-O)
163+
thermo_type (ThermoType): The thermo type for the phase diagram.
164+
Defaults to ThermoType.GGA_GGA_U.
150165
Returns:
151166
phase_diagram (PhaseDiagram): Pymatgen phase diagram object.
152167
"""
153-
168+
phase_diagram_id = f"{chemsys}_{thermo_type.value}"
154169
response = self._query_resource(
155170
fields=["phase_diagram"],
156-
suburl=f"phase_diagram/{chemsys}",
171+
suburl=f"phase_diagram/{phase_diagram_id}",
157172
use_document_model=False,
158173
num_chunks=1,
159174
chunk_size=1,

tests/test_thermo.py

Lines changed: 12 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from pymatgen.analysis.phase_diagram import PhaseDiagram
66

77
from mp_api.client.routes.thermo import ThermoRester
8+
from emmet.core.thermo import ThermoType
89

910

1011
@pytest.fixture
@@ -21,13 +22,15 @@ def rester():
2122
"all_fields",
2223
"fields",
2324
"equilibrium_reaction_energy",
25+
"thermo_types",
2426
]
2527

2628
sub_doc_fields = [] # type: list
2729

2830
alt_name_dict = {
2931
"formula": "formula_pretty",
3032
"material_ids": "material_id",
33+
"thermo_ids": "thermo_id",
3134
"total_energy": "energy_per_atom",
3235
"formation_energy": "formation_energy_per_atom",
3336
"uncorrected_energy": "uncorrected_energy_per_atom",
@@ -40,12 +43,11 @@ def rester():
4043
"material_ids": ["mp-149"],
4144
"formula": "SiO2",
4245
"chemsys": "Si-O",
46+
"thermo_ids": ["mp-149"],
4347
} # type: dict
4448

4549

46-
@pytest.mark.skipif(
47-
os.environ.get("MP_API_KEY", None) is None, reason="No API key found."
48-
)
50+
@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.")
4951
def test_client(rester):
5052
search_method = rester.search
5153

@@ -56,6 +58,7 @@ def test_client(rester):
5658
# Query API for each numeric and boolean parameter and check if returned
5759
for entry in param_tuples:
5860
param = entry[0]
61+
print(param)
5962
if param not in excluded_params:
6063
param_type = entry[1].__args__[0]
6164
q = None
@@ -66,55 +69,42 @@ def test_client(rester):
6669
param: (-100, 100),
6770
"chunk_size": 1,
6871
"num_chunks": 1,
69-
"fields": [
70-
project_field if project_field is not None else param
71-
],
72+
"fields": [project_field if project_field is not None else param],
7273
}
7374
elif param_type == typing.Tuple[float, float]:
7475
project_field = alt_name_dict.get(param, None)
7576
q = {
7677
param: (-100.12, 100.12),
7778
"chunk_size": 1,
7879
"num_chunks": 1,
79-
"fields": [
80-
project_field if project_field is not None else param
81-
],
80+
"fields": [project_field if project_field is not None else param],
8281
}
8382
elif param_type is bool:
8483
project_field = alt_name_dict.get(param, None)
8584
q = {
8685
param: False,
8786
"chunk_size": 1,
8887
"num_chunks": 1,
89-
"fields": [
90-
project_field if project_field is not None else param
91-
],
88+
"fields": [project_field if project_field is not None else param],
9289
}
9390
elif param in custom_field_tests:
9491
project_field = alt_name_dict.get(param, None)
9592
q = {
9693
param: custom_field_tests[param],
9794
"chunk_size": 1,
9895
"num_chunks": 1,
99-
"fields": [
100-
project_field if project_field is not None else param
101-
],
96+
"fields": [project_field if project_field is not None else param],
10297
}
10398

10499
doc = search_method(**q)[0].dict()
105100
for sub_field in sub_doc_fields:
106101
if sub_field in doc:
107102
doc = doc[sub_field]
108103

109-
assert (
110-
doc[project_field if project_field is not None else param]
111-
is not None
112-
)
104+
assert doc[project_field if project_field is not None else param] is not None
113105

114106

115107
def test_get_phase_diagram_from_chemsys():
116108
# Test that a phase diagram is returned
117109

118-
assert isinstance(
119-
ThermoRester().get_phase_diagram_from_chemsys("Hf-Pm"), PhaseDiagram
120-
)
110+
assert isinstance(ThermoRester().get_phase_diagram_from_chemsys("Hf-Pm"), PhaseDiagram)

tests/test_xas.py

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,8 @@ def rester():
4242
} # type: dict
4343

4444

45-
@pytest.mark.skipif(
46-
os.environ.get("MP_API_KEY", None) is None, reason="No API key found."
47-
)
45+
@pytest.mark.skip(reason="Temp skip until timeout update.")
46+
@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.")
4847
def test_client(rester):
4948
search_method = rester.search
5049

@@ -65,47 +64,36 @@ def test_client(rester):
6564
param: (-100, 100),
6665
"chunk_size": 1,
6766
"num_chunks": 1,
68-
"fields": [
69-
project_field if project_field is not None else param
70-
],
67+
"fields": [project_field if project_field is not None else param],
7168
}
7269
elif param_type == typing.Tuple[float, float]:
7370
project_field = alt_name_dict.get(param, None)
7471
q = {
7572
param: (-100.12, 100.12),
7673
"chunk_size": 1,
7774
"num_chunks": 1,
78-
"fields": [
79-
project_field if project_field is not None else param
80-
],
75+
"fields": [project_field if project_field is not None else param],
8176
}
8277
elif param_type is bool:
8378
project_field = alt_name_dict.get(param, None)
8479
q = {
8580
param: False,
8681
"chunk_size": 1,
8782
"num_chunks": 1,
88-
"fields": [
89-
project_field if project_field is not None else param
90-
],
83+
"fields": [project_field if project_field is not None else param],
9184
}
9285
elif param in custom_field_tests:
9386
project_field = alt_name_dict.get(param, None)
9487
q = {
9588
param: custom_field_tests[param],
9689
"chunk_size": 1,
9790
"num_chunks": 1,
98-
"fields": [
99-
project_field if project_field is not None else param
100-
],
91+
"fields": [project_field if project_field is not None else param],
10192
}
10293

10394
doc = search_method(**q)[0].dict()
10495
for sub_field in sub_doc_fields:
10596
if sub_field in doc:
10697
doc = doc[sub_field]
10798

108-
assert (
109-
doc[project_field if project_field is not None else param]
110-
is not None
111-
)
99+
assert doc[project_field if project_field is not None else param] is not None

0 commit comments

Comments
 (0)