1
1
from __future__ import annotations
2
2
3
+ from typing import TYPE_CHECKING , Literal , Sequence
4
+
3
5
import matplotlib .pyplot as plt
4
6
import pandas as pd
5
7
import plotly .graph_objects as go
17
19
from pymatviz .utils import df_ptable
18
20
19
21
22
+ if TYPE_CHECKING :
23
+ from pymatviz .ptable import CountMode
24
+
25
+
20
26
@pytest .fixture
21
27
def glass_formulas () -> list [str ]:
22
28
"""Output of:
@@ -26,10 +32,10 @@ def glass_formulas() -> list[str]:
26
32
load_dataset("matbench_glass").composition.head(20)
27
33
"""
28
34
return (
29
- "Al, Al(NiB)2, Al10Co21B19, Al10Co23B17, Al10Co27B13, Al10Co29B11, Al10Co31B9, "
30
- "Al10Co33B7, Al10Cr3Si7, Al10Fe23B17, Al10Fe27B13, Al10Fe31B9, Al10Fe33B7, "
31
- "Al10Ni23B17, Al10Ni27B13, Al10Ni29B11, Al10Ni31B9, Al10Ni33B7, Al11(CrSi2)3"
32
- ).split ("," )
35
+ "Al Al(NiB)2 Al10Co21B19 Al10Co23B17 Al10Co27B13 Al10Co29B11 Al10Co31B9 "
36
+ "Al10Co33B7 Al10Cr3Si7 Al10Fe23B17 Al10Fe27B13 Al10Fe31B9 Al10Fe33B7 "
37
+ "Al10Ni23B17 Al10Ni27B13 Al10Ni29B11 Al10Ni31B9 Al10Ni33B7 Al11(CrSi2)3"
38
+ ).split ()
33
39
34
40
35
41
@pytest .fixture
@@ -64,13 +70,13 @@ def steel_elem_counts(steel_formulas: pd.Series[Composition]) -> pd.Series[int]:
64
70
("reduced_composition" , {"Fe" : 13 , "O" : 27 , "P" : 3 }),
65
71
],
66
72
)
67
- def test_count_elements (count_mode , counts ) :
73
+ def test_count_elements (count_mode : CountMode , counts : dict [ str , float ]) -> None :
68
74
series = count_elements (["Fe2 O3" ] * 5 + ["Fe4 P4 O16" ] * 3 , count_mode = count_mode )
69
75
expected = pd .Series (counts , index = df_ptable .index , name = "count" ).fillna (0 )
70
76
assert series .equals (expected )
71
77
72
78
73
- def test_count_elements_by_atomic_nums ():
79
+ def test_count_elements_by_atomic_nums () -> None :
74
80
series_in = pd .Series (1 , index = range (1 , 119 ))
75
81
el_cts = count_elements (series_in )
76
82
expected = pd .Series (1 , index = df_ptable .index , name = "count" )
@@ -79,7 +85,7 @@ def test_count_elements_by_atomic_nums():
79
85
80
86
81
87
@pytest .mark .parametrize ("range_limits" , [(- 1 , 10 ), (100 , 200 )])
82
- def test_count_elements_bad_atomic_nums (range_limits ) :
88
+ def test_count_elements_bad_atomic_nums (range_limits : tuple [ int , int ]) -> None :
83
89
with pytest .raises (ValueError , match = "assumed to represent atomic numbers" ):
84
90
count_elements ({idx : 0 for idx in range (* range_limits )})
85
91
@@ -88,7 +94,7 @@ def test_count_elements_bad_atomic_nums(range_limits):
88
94
count_elements ({str (idx ): 0 for idx in range (* range_limits )})
89
95
90
96
91
- def test_hist_elemental_prevalence (glass_formulas ) :
97
+ def test_hist_elemental_prevalence (glass_formulas : list [ str ]) -> None :
92
98
ax = hist_elemental_prevalence (glass_formulas )
93
99
assert isinstance (ax , plt .Axes )
94
100
@@ -99,7 +105,9 @@ def test_hist_elemental_prevalence(glass_formulas):
99
105
hist_elemental_prevalence (glass_formulas , keep_top = 10 , bar_values = "count" )
100
106
101
107
102
- def test_ptable_heatmap (glass_formulas , glass_elem_counts ):
108
+ def test_ptable_heatmap (
109
+ glass_formulas : list [str ], glass_elem_counts : pd .Series [int ]
110
+ ) -> None :
103
111
ax = ptable_heatmap (glass_formulas )
104
112
assert isinstance (ax , plt .Axes )
105
113
@@ -139,8 +147,11 @@ def test_ptable_heatmap(glass_formulas, glass_elem_counts):
139
147
140
148
141
149
def test_ptable_heatmap_ratio (
142
- steel_formulas , glass_formulas , steel_elem_counts , glass_elem_counts
143
- ):
150
+ steel_formulas : list [str ],
151
+ glass_formulas : list [str ],
152
+ steel_elem_counts : pd .Series [int ],
153
+ glass_elem_counts : pd .Series [int ],
154
+ ) -> None :
144
155
# composition strings
145
156
ax = ptable_heatmap_ratio (glass_formulas , steel_formulas )
146
157
assert isinstance (ax , plt .Axes )
@@ -153,7 +164,7 @@ def test_ptable_heatmap_ratio(
153
164
ptable_heatmap_ratio (glass_elem_counts , steel_formulas )
154
165
155
166
156
- def test_ptable_heatmap_plotly (glass_formulas ) :
167
+ def test_ptable_heatmap_plotly (glass_formulas : list [ str ]) -> None :
157
168
fig = ptable_heatmap_plotly (glass_formulas )
158
169
assert isinstance (fig , go .Figure )
159
170
assert (
@@ -194,10 +205,16 @@ def test_ptable_heatmap_plotly(glass_formulas):
194
205
)
195
206
@pytest .mark .parametrize ("showscale" , [False , True ])
196
207
@pytest .mark .parametrize ("font_size" , [None , 14 ])
197
- @pytest .mark .parametrize ("font_colors" , [None , ("black" , "white" )])
208
+ @pytest .mark .parametrize ("font_colors" , [[ "red" ] , ("black" , "white" )])
198
209
def test_ptable_heatmap_plotly_kwarg_combos (
199
- glass_formulas , exclude_elements , heat_mode , showscale , font_size , font_colors , log
200
- ):
210
+ glass_formulas : list [str ],
211
+ exclude_elements : Sequence [str ],
212
+ heat_mode : Literal ["value" , "fraction" , "percent" ] | None ,
213
+ showscale : bool ,
214
+ font_size : int ,
215
+ font_colors : tuple [str ] | tuple [str , str ],
216
+ log : bool ,
217
+ ) -> None :
201
218
fig = ptable_heatmap_plotly (
202
219
glass_formulas ,
203
220
exclude_elements = exclude_elements ,
@@ -211,8 +228,10 @@ def test_ptable_heatmap_plotly_kwarg_combos(
211
228
212
229
213
230
@pytest .mark .parametrize (
214
- "clr_scl " , ["YlGn" , ["blue" , "red" ], [(0 , "blue" ), (1 , "red" )]]
231
+ "colorscale " , ["YlGn" , ["blue" , "red" ], [(0 , "blue" ), (1 , "red" )]]
215
232
)
216
- def test_ptable_heatmap_plotly_colorscale (glass_formulas , clr_scl ):
217
- fig = ptable_heatmap_plotly (glass_formulas , colorscale = clr_scl )
233
+ def test_ptable_heatmap_plotly_colorscale (
234
+ glass_formulas : list [str ], colorscale : str | list [tuple [float , str ]] | list [str ]
235
+ ) -> None :
236
+ fig = ptable_heatmap_plotly (glass_formulas , colorscale = colorscale )
218
237
assert isinstance (fig , go .Figure )
0 commit comments