Skip to content

Commit 5f09deb

Browse files
authored
Properly support user-provided norm. (#2443)
* Properly support user-provided norm. Fixes #2381 * remove top level mpl import. * More accurate error message. * whats-new fixes.
1 parent cf1e6c7 commit 5f09deb

File tree

4 files changed

+87
-21
lines changed

4 files changed

+87
-21
lines changed

doc/whats-new.rst

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,15 @@ Breaking changes
4040

4141
Documentation
4242
~~~~~~~~~~~~~
43+
4344
Enhancements
4445
~~~~~~~~~~~~
4546

4647
- Added support for Python 3.7. (:issue:`2271`).
4748
By `Joe Hamman <https://github.com/jhamman>`_.
48-
4949
- Added :py:meth:`~xarray.CFTimeIndex.shift` for shifting the values of a
50-
CFTimeIndex by a specified frequency. (:issue:`2244`). By `Spencer Clark
51-
<https://github.com/spencerkclark>`_.
50+
CFTimeIndex by a specified frequency. (:issue:`2244`).
51+
By `Spencer Clark <https://github.com/spencerkclark>`_.
5252
- Added support for using ``cftime.datetime`` coordinates with
5353
:py:meth:`~xarray.DataArray.differentiate`,
5454
:py:meth:`~xarray.Dataset.differentiate`,
@@ -60,11 +60,14 @@ Bug fixes
6060
~~~~~~~~~
6161

6262
- Addition and subtraction operators used with a CFTimeIndex now preserve the
63-
index's type. (:issue:`2244`). By `Spencer Clark <https://github.com/spencerkclark>`_.
63+
index's type. (:issue:`2244`).
64+
By `Spencer Clark <https://github.com/spencerkclark>`_.
6465
- ``xarray.DataArray.roll`` correctly handles multidimensional arrays.
6566
(:issue:`2445`)
6667
By `Keisuke Fujii <https://github.com/fujiisoup>`_.
67-
68+
- ``xarray.plot()`` now properly accepts a ``norm`` argument and does not override
69+
the norm's ``vmin`` and ``vmax``. (:issue:`2381`)
70+
By `Deepak Cherian <https://github.com/dcherian>`_.
6871
- ``xarray.DataArray.std()`` now correctly accepts ``ddof`` keyword argument.
6972
(:issue:`2240`)
7073
By `Keisuke Fujii <https://github.com/fujiisoup>`_.

xarray/plot/plot.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,9 @@ def _plot2d(plotfunc):
562562
Adds colorbar to axis
563563
add_labels : Boolean, optional
564564
Use xarray metadata to label axes
565+
norm : ``matplotlib.colors.Normalize`` instance, optional
566+
If the ``norm`` has vmin or vmax specified, the corresponding kwarg
567+
must be None.
565568
vmin, vmax : floats, optional
566569
Values to anchor the colormap, otherwise they are inferred from the
567570
data and other keyword arguments. When a diverging dataset is inferred,
@@ -630,7 +633,7 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
630633
levels=None, infer_intervals=None, colors=None,
631634
subplot_kws=None, cbar_ax=None, cbar_kwargs=None,
632635
xscale=None, yscale=None, xticks=None, yticks=None,
633-
xlim=None, ylim=None, **kwargs):
636+
xlim=None, ylim=None, norm=None, **kwargs):
634637
# All 2d plots in xarray share this function signature.
635638
# Method signature below should be consistent.
636639

@@ -727,6 +730,7 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
727730
'extend': extend,
728731
'levels': levels,
729732
'filled': plotfunc.__name__ != 'contour',
733+
'norm': norm,
730734
}
731735

732736
cmap_params = _determine_cmap_params(**cmap_kwargs)
@@ -746,9 +750,6 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
746750
if 'pcolormesh' == plotfunc.__name__:
747751
kwargs['infer_intervals'] = infer_intervals
748752

749-
# This allows the user to pass in a custom norm coming via kwargs
750-
kwargs.setdefault('norm', cmap_params['norm'])
751-
752753
if 'imshow' == plotfunc.__name__ and isinstance(aspect, basestring):
753754
# forbid usage of mpl strings
754755
raise ValueError("plt.imshow's `aspect` kwarg is not available "
@@ -758,6 +759,7 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
758759
primitive = plotfunc(xval, yval, zval, ax=ax, cmap=cmap_params['cmap'],
759760
vmin=cmap_params['vmin'],
760761
vmax=cmap_params['vmax'],
762+
norm=cmap_params['norm'],
761763
**kwargs)
762764

763765
# Label the plot with metadata
@@ -809,7 +811,7 @@ def plotmethod(_PlotMethods_obj, x=None, y=None, figsize=None, size=None,
809811
levels=None, infer_intervals=None, subplot_kws=None,
810812
cbar_ax=None, cbar_kwargs=None,
811813
xscale=None, yscale=None, xticks=None, yticks=None,
812-
xlim=None, ylim=None, **kwargs):
814+
xlim=None, ylim=None, norm=None, **kwargs):
813815
"""
814816
The method should have the same signature as the function.
815817

xarray/plot/utils.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,10 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None,
172172
# vlim might be computed below
173173
vlim = None
174174

175+
# save state; needed later
176+
vmin_was_none = vmin is None
177+
vmax_was_none = vmax is None
178+
175179
if vmin is None:
176180
if robust:
177181
vmin = np.percentile(calc_data, ROBUST_PERCENTILE)
@@ -204,6 +208,28 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None,
204208
vmin += center
205209
vmax += center
206210

211+
# now check norm and harmonize with vmin, vmax
212+
if norm is not None:
213+
if norm.vmin is None:
214+
norm.vmin = vmin
215+
else:
216+
if not vmin_was_none and vmin != norm.vmin:
217+
raise ValueError('Cannot supply vmin and a norm'
218+
+ ' with a different vmin.')
219+
vmin = norm.vmin
220+
221+
if norm.vmax is None:
222+
norm.vmax = vmax
223+
else:
224+
if not vmax_was_none and vmax != norm.vmax:
225+
raise ValueError('Cannot supply vmax and a norm'
226+
+ ' with a different vmax.')
227+
vmax = norm.vmax
228+
229+
# if BoundaryNorm, then set levels
230+
if isinstance(norm, mpl.colors.BoundaryNorm):
231+
levels = norm.boundaries
232+
207233
# Choose default colormaps if not provided
208234
if cmap is None:
209235
if divergent:
@@ -212,7 +238,7 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None,
212238
cmap = OPTIONS['cmap_sequential']
213239

214240
# Handle discrete levels
215-
if levels is not None:
241+
if levels is not None and norm is None:
216242
if is_scalar(levels):
217243
if user_minmax:
218244
levels = np.linspace(vmin, vmax, levels)
@@ -227,8 +253,9 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None,
227253
if extend is None:
228254
extend = _determine_extend(calc_data, vmin, vmax)
229255

230-
if levels is not None:
231-
cmap, norm = _build_discrete_cmap(cmap, levels, extend, filled)
256+
if levels is not None or isinstance(norm, mpl.colors.BoundaryNorm):
257+
cmap, newnorm = _build_discrete_cmap(cmap, levels, extend, filled)
258+
norm = newnorm if norm is None else norm
232259

233260
return dict(vmin=vmin, vmax=vmax, cmap=cmap, extend=extend,
234261
levels=levels, norm=norm)

xarray/tests/test_plot.py

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -628,6 +628,26 @@ def test_divergentcontrol(self):
628628
assert cmap_params['vmax'] == 0.6
629629
assert cmap_params['cmap'] == "viridis"
630630

631+
def test_norm_sets_vmin_vmax(self):
632+
vmin = self.data.min()
633+
vmax = self.data.max()
634+
635+
for norm, extend in zip([mpl.colors.LogNorm(),
636+
mpl.colors.LogNorm(vmin + 1, vmax - 1),
637+
mpl.colors.LogNorm(None, vmax - 1),
638+
mpl.colors.LogNorm(vmin + 1, None)],
639+
['neither', 'both', 'max', 'min']):
640+
641+
test_min = vmin if norm.vmin is None else norm.vmin
642+
test_max = vmax if norm.vmax is None else norm.vmax
643+
644+
cmap_params = _determine_cmap_params(self.data, norm=norm)
645+
646+
assert cmap_params['vmin'] == test_min
647+
assert cmap_params['vmax'] == test_max
648+
assert cmap_params['extend'] == extend
649+
assert cmap_params['norm'] == norm
650+
631651

632652
@requires_matplotlib
633653
class TestDiscreteColorMap(object):
@@ -665,10 +685,10 @@ def test_build_discrete_cmap(self):
665685

666686
@pytest.mark.slow
667687
def test_discrete_colormap_list_of_levels(self):
668-
for extend, levels in [('max', [-1, 2, 4, 8, 10]), ('both',
669-
[2, 5, 10, 11]),
670-
('neither', [0, 5, 10, 15]), ('min',
671-
[2, 5, 10, 15])]:
688+
for extend, levels in [('max', [-1, 2, 4, 8, 10]),
689+
('both', [2, 5, 10, 11]),
690+
('neither', [0, 5, 10, 15]),
691+
('min', [2, 5, 10, 15])]:
672692
for kind in ['imshow', 'pcolormesh', 'contourf', 'contour']:
673693
primitive = getattr(self.darray.plot, kind)(levels=levels)
674694
assert_array_equal(levels, primitive.norm.boundaries)
@@ -682,10 +702,10 @@ def test_discrete_colormap_list_of_levels(self):
682702

683703
@pytest.mark.slow
684704
def test_discrete_colormap_int_levels(self):
685-
for extend, levels, vmin, vmax in [('neither', 7, None,
686-
None), ('neither', 7, None, 20),
687-
('both', 7, 4, 8), ('min', 10, 4,
688-
15)]:
705+
for extend, levels, vmin, vmax in [('neither', 7, None, None),
706+
('neither', 7, None, 20),
707+
('both', 7, 4, 8),
708+
('min', 10, 4, 15)]:
689709
for kind in ['imshow', 'pcolormesh', 'contourf', 'contour']:
690710
primitive = getattr(self.darray.plot, kind)(
691711
levels=levels, vmin=vmin, vmax=vmax)
@@ -711,6 +731,11 @@ def test_discrete_colormap_list_levels_and_vmin_or_vmax(self):
711731
assert primitive.norm.vmax == max(levels)
712732
assert primitive.norm.vmin == min(levels)
713733

734+
def test_discrete_colormap_provided_boundary_norm(self):
735+
norm = mpl.colors.BoundaryNorm([0, 5, 10, 15], 4)
736+
primitive = self.darray.plot.contourf(norm=norm)
737+
np.testing.assert_allclose(primitive.levels, norm.boundaries)
738+
714739

715740
class Common2dMixin(object):
716741
"""
@@ -1085,6 +1110,15 @@ def test_cmap_and_color_both(self):
10851110
with pytest.raises(ValueError):
10861111
self.plotmethod(colors='k', cmap='RdBu')
10871112

1113+
def test_colormap_error_norm_and_vmin_vmax(self):
1114+
norm = mpl.colors.LogNorm(0.1, 1e1)
1115+
1116+
with pytest.raises(ValueError):
1117+
self.darray.plot(norm=norm, vmin=2)
1118+
1119+
with pytest.raises(ValueError):
1120+
self.darray.plot(norm=norm, vmax=2)
1121+
10881122

10891123
@pytest.mark.slow
10901124
class TestContourf(Common2dMixin, PlotTestCase):

0 commit comments

Comments
 (0)