Skip to content

Commit 0efb2df

Browse files
Move _infer_meta_data and _parse_size to utils (#6779)
* Move dataset plot functions to utils * move parse_size * move markersize * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 4aae7fd commit 0efb2df

File tree

2 files changed

+129
-129
lines changed

2 files changed

+129
-129
lines changed

xarray/plot/dataset_plot.py

Lines changed: 2 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -10,98 +10,12 @@
1010
from .utils import (
1111
_add_colorbar,
1212
_get_nice_quiver_magnitude,
13-
_is_numeric,
13+
_infer_meta_data,
14+
_parse_size,
1415
_process_cmap_cbar_kwargs,
1516
get_axis,
16-
label_from_attrs,
1717
)
1818

19-
# copied from seaborn
20-
_MARKERSIZE_RANGE = np.array([18.0, 72.0])
21-
22-
23-
def _infer_meta_data(ds, x, y, hue, hue_style, add_guide, funcname):
24-
dvars = set(ds.variables.keys())
25-
error_msg = " must be one of ({:s})".format(", ".join(dvars))
26-
27-
if x not in dvars:
28-
raise ValueError("x" + error_msg)
29-
30-
if y not in dvars:
31-
raise ValueError("y" + error_msg)
32-
33-
if hue is not None and hue not in dvars:
34-
raise ValueError("hue" + error_msg)
35-
36-
if hue:
37-
hue_is_numeric = _is_numeric(ds[hue].values)
38-
39-
if hue_style is None:
40-
hue_style = "continuous" if hue_is_numeric else "discrete"
41-
42-
if not hue_is_numeric and (hue_style == "continuous"):
43-
raise ValueError(
44-
f"Cannot create a colorbar for a non numeric coordinate: {hue}"
45-
)
46-
47-
if add_guide is None or add_guide is True:
48-
add_colorbar = True if hue_style == "continuous" else False
49-
add_legend = True if hue_style == "discrete" else False
50-
else:
51-
add_colorbar = False
52-
add_legend = False
53-
else:
54-
if add_guide is True and funcname not in ("quiver", "streamplot"):
55-
raise ValueError("Cannot set add_guide when hue is None.")
56-
add_legend = False
57-
add_colorbar = False
58-
59-
if (add_guide or add_guide is None) and funcname == "quiver":
60-
add_quiverkey = True
61-
if hue:
62-
add_colorbar = True
63-
if not hue_style:
64-
hue_style = "continuous"
65-
elif hue_style != "continuous":
66-
raise ValueError(
67-
"hue_style must be 'continuous' or None for .plot.quiver or "
68-
".plot.streamplot"
69-
)
70-
else:
71-
add_quiverkey = False
72-
73-
if (add_guide or add_guide is None) and funcname == "streamplot":
74-
if hue:
75-
add_colorbar = True
76-
if not hue_style:
77-
hue_style = "continuous"
78-
elif hue_style != "continuous":
79-
raise ValueError(
80-
"hue_style must be 'continuous' or None for .plot.quiver or "
81-
".plot.streamplot"
82-
)
83-
84-
if hue_style is not None and hue_style not in ["discrete", "continuous"]:
85-
raise ValueError("hue_style must be either None, 'discrete' or 'continuous'.")
86-
87-
if hue:
88-
hue_label = label_from_attrs(ds[hue])
89-
hue = ds[hue]
90-
else:
91-
hue_label = None
92-
hue = None
93-
94-
return {
95-
"add_colorbar": add_colorbar,
96-
"add_legend": add_legend,
97-
"add_quiverkey": add_quiverkey,
98-
"hue_label": hue_label,
99-
"hue_style": hue_style,
100-
"xlabel": label_from_attrs(ds[x]),
101-
"ylabel": label_from_attrs(ds[y]),
102-
"hue": hue,
103-
}
104-
10519

10620
def _infer_scatter_data(ds, x, y, hue, markersize, size_norm, size_mapping=None):
10721

@@ -134,47 +48,6 @@ def _infer_scatter_data(ds, x, y, hue, markersize, size_norm, size_mapping=None)
13448
return data
13549

13650

137-
# copied from seaborn
138-
def _parse_size(data, norm):
139-
140-
import matplotlib as mpl
141-
142-
if data is None:
143-
return None
144-
145-
data = data.values.flatten()
146-
147-
if not _is_numeric(data):
148-
levels = np.unique(data)
149-
numbers = np.arange(1, 1 + len(levels))[::-1]
150-
else:
151-
levels = numbers = np.sort(np.unique(data))
152-
153-
min_width, max_width = _MARKERSIZE_RANGE
154-
# width_range = min_width, max_width
155-
156-
if norm is None:
157-
norm = mpl.colors.Normalize()
158-
elif isinstance(norm, tuple):
159-
norm = mpl.colors.Normalize(*norm)
160-
elif not isinstance(norm, mpl.colors.Normalize):
161-
err = "``size_norm`` must be None, tuple, or Normalize object."
162-
raise ValueError(err)
163-
164-
norm.clip = True
165-
if not norm.scaled():
166-
norm(np.asarray(numbers))
167-
# limits = norm.vmin, norm.vmax
168-
169-
scl = norm(numbers)
170-
widths = np.asarray(min_width + scl * (max_width - min_width))
171-
if scl.mask.any():
172-
widths[scl.mask] = 0
173-
sizes = dict(zip(levels, widths))
174-
175-
return pd.Series(sizes)
176-
177-
17851
class _Dataset_PlotMethods:
17952
"""
18053
Enables use of xarray.plot functions as attributes on a Dataset.

xarray/plot/utils.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@
3030

3131
ROBUST_PERCENTILE = 2.0
3232

33+
# copied from seaborn
34+
_MARKERSIZE_RANGE = np.array([18.0, 72.0])
35+
3336

3437
def import_matplotlib_pyplot():
3538
"""import pyplot"""
@@ -1141,3 +1144,127 @@ def _adjust_legend_subtitles(legend):
11411144
# The sutbtitles should have the same font size
11421145
# as normal legend titles:
11431146
text.set_size(font_size)
1147+
1148+
1149+
def _infer_meta_data(ds, x, y, hue, hue_style, add_guide, funcname):
1150+
dvars = set(ds.variables.keys())
1151+
error_msg = " must be one of ({:s})".format(", ".join(dvars))
1152+
1153+
if x not in dvars:
1154+
raise ValueError("x" + error_msg)
1155+
1156+
if y not in dvars:
1157+
raise ValueError("y" + error_msg)
1158+
1159+
if hue is not None and hue not in dvars:
1160+
raise ValueError("hue" + error_msg)
1161+
1162+
if hue:
1163+
hue_is_numeric = _is_numeric(ds[hue].values)
1164+
1165+
if hue_style is None:
1166+
hue_style = "continuous" if hue_is_numeric else "discrete"
1167+
1168+
if not hue_is_numeric and (hue_style == "continuous"):
1169+
raise ValueError(
1170+
f"Cannot create a colorbar for a non numeric coordinate: {hue}"
1171+
)
1172+
1173+
if add_guide is None or add_guide is True:
1174+
add_colorbar = True if hue_style == "continuous" else False
1175+
add_legend = True if hue_style == "discrete" else False
1176+
else:
1177+
add_colorbar = False
1178+
add_legend = False
1179+
else:
1180+
if add_guide is True and funcname not in ("quiver", "streamplot"):
1181+
raise ValueError("Cannot set add_guide when hue is None.")
1182+
add_legend = False
1183+
add_colorbar = False
1184+
1185+
if (add_guide or add_guide is None) and funcname == "quiver":
1186+
add_quiverkey = True
1187+
if hue:
1188+
add_colorbar = True
1189+
if not hue_style:
1190+
hue_style = "continuous"
1191+
elif hue_style != "continuous":
1192+
raise ValueError(
1193+
"hue_style must be 'continuous' or None for .plot.quiver or "
1194+
".plot.streamplot"
1195+
)
1196+
else:
1197+
add_quiverkey = False
1198+
1199+
if (add_guide or add_guide is None) and funcname == "streamplot":
1200+
if hue:
1201+
add_colorbar = True
1202+
if not hue_style:
1203+
hue_style = "continuous"
1204+
elif hue_style != "continuous":
1205+
raise ValueError(
1206+
"hue_style must be 'continuous' or None for .plot.quiver or "
1207+
".plot.streamplot"
1208+
)
1209+
1210+
if hue_style is not None and hue_style not in ["discrete", "continuous"]:
1211+
raise ValueError("hue_style must be either None, 'discrete' or 'continuous'.")
1212+
1213+
if hue:
1214+
hue_label = label_from_attrs(ds[hue])
1215+
hue = ds[hue]
1216+
else:
1217+
hue_label = None
1218+
hue = None
1219+
1220+
return {
1221+
"add_colorbar": add_colorbar,
1222+
"add_legend": add_legend,
1223+
"add_quiverkey": add_quiverkey,
1224+
"hue_label": hue_label,
1225+
"hue_style": hue_style,
1226+
"xlabel": label_from_attrs(ds[x]),
1227+
"ylabel": label_from_attrs(ds[y]),
1228+
"hue": hue,
1229+
}
1230+
1231+
1232+
# copied from seaborn
1233+
def _parse_size(data, norm):
1234+
1235+
import matplotlib as mpl
1236+
1237+
if data is None:
1238+
return None
1239+
1240+
data = data.values.flatten()
1241+
1242+
if not _is_numeric(data):
1243+
levels = np.unique(data)
1244+
numbers = np.arange(1, 1 + len(levels))[::-1]
1245+
else:
1246+
levels = numbers = np.sort(np.unique(data))
1247+
1248+
min_width, max_width = _MARKERSIZE_RANGE
1249+
# width_range = min_width, max_width
1250+
1251+
if norm is None:
1252+
norm = mpl.colors.Normalize()
1253+
elif isinstance(norm, tuple):
1254+
norm = mpl.colors.Normalize(*norm)
1255+
elif not isinstance(norm, mpl.colors.Normalize):
1256+
err = "``size_norm`` must be None, tuple, or Normalize object."
1257+
raise ValueError(err)
1258+
1259+
norm.clip = True
1260+
if not norm.scaled():
1261+
norm(np.asarray(numbers))
1262+
# limits = norm.vmin, norm.vmax
1263+
1264+
scl = norm(numbers)
1265+
widths = np.asarray(min_width + scl * (max_width - min_width))
1266+
if scl.mask.any():
1267+
widths[scl.mask] = 0
1268+
sizes = dict(zip(levels, widths))
1269+
1270+
return pd.Series(sizes)

0 commit comments

Comments
 (0)