@@ -49,7 +49,7 @@ def _infer_meta_data(ds, x, y, hue, hue_style, add_guide, funcname):
4949 add_colorbar = False
5050 add_legend = False
5151 else :
52- if add_guide is True and funcname != "quiver" :
52+ if add_guide is True and funcname not in ( "quiver" , "streamplot" ) :
5353 raise ValueError ("Cannot set add_guide when hue is None." )
5454 add_legend = False
5555 add_colorbar = False
@@ -62,11 +62,23 @@ def _infer_meta_data(ds, x, y, hue, hue_style, add_guide, funcname):
6262 hue_style = "continuous"
6363 elif hue_style != "continuous" :
6464 raise ValueError (
65- "hue_style must be 'continuous' or None for .plot.quiver"
65+ "hue_style must be 'continuous' or None for .plot.quiver or "
66+ ".plot.streamplot"
6667 )
6768 else :
6869 add_quiverkey = False
6970
71+ if (add_guide or add_guide is None ) and funcname == "streamplot" :
72+ if hue :
73+ add_colorbar = True
74+ if not hue_style :
75+ hue_style = "continuous"
76+ elif hue_style != "continuous" :
77+ raise ValueError (
78+ "hue_style must be 'continuous' or None for .plot.quiver or "
79+ ".plot.streamplot"
80+ )
81+
7082 if hue_style is not None and hue_style not in ["discrete" , "continuous" ]:
7183 raise ValueError ("hue_style must be either None, 'discrete' or 'continuous'." )
7284
@@ -186,7 +198,7 @@ def _dsplot(plotfunc):
186198 x, y : str
187199 Variable names for x, y axis.
188200 u, v : str, optional
189- Variable names for quiver plots
201+ Variable names for quiver or streamplot plots
190202 hue: str, optional
191203 Variable by which to color scattered points
192204 hue_style: str, optional
@@ -338,8 +350,11 @@ def newplotfunc(
338350 else :
339351 cmap_params_subset = {}
340352
341- if (u is not None or v is not None ) and plotfunc .__name__ != "quiver" :
342- raise ValueError ("u, v are only allowed for quiver plots." )
353+ if (u is not None or v is not None ) and plotfunc .__name__ not in (
354+ "quiver" ,
355+ "streamplot" ,
356+ ):
357+ raise ValueError ("u, v are only allowed for quiver or streamplot plots." )
343358
344359 primitive = plotfunc (
345360 ds = ds ,
@@ -383,7 +398,7 @@ def newplotfunc(
383398 coordinates = "figure" ,
384399 )
385400
386- if plotfunc .__name__ == "quiver" :
401+ if plotfunc .__name__ in ( "quiver" , "streamplot" ) :
387402 title = ds [u ]._title_for_slice ()
388403 else :
389404 title = ds [x ]._title_for_slice ()
@@ -526,3 +541,54 @@ def quiver(ds, x, y, ax, u, v, **kwargs):
526541 kwargs .setdefault ("pivot" , "middle" )
527542 hdl = ax .quiver (* args , ** kwargs , ** cmap_params )
528543 return hdl
544+
545+
546+ @_dsplot
547+ def streamplot (ds , x , y , ax , u , v , ** kwargs ):
548+ """ Quiver plot with Dataset variables."""
549+ import matplotlib as mpl
550+
551+ if x is None or y is None or u is None or v is None :
552+ raise ValueError ("Must specify x, y, u, v for streamplot plots." )
553+
554+ # Matplotlib's streamplot has strong restrictions on what x and y can be, so need to
555+ # get arrays transposed the 'right' way around. 'x' cannot vary within 'rows', so
556+ # the dimension of x must be the second dimension. 'y' cannot vary with 'columns' so
557+ # the dimension of y must be the first dimension. If x and y are both 2d, assume the
558+ # user has got them right already.
559+ if len (ds [x ].dims ) == 1 :
560+ xdim = ds [x ].dims [0 ]
561+ if len (ds [y ].dims ) == 1 :
562+ ydim = ds [y ].dims [0 ]
563+ if xdim is not None and ydim is None :
564+ ydim = set (ds [y ].dims ) - set ([xdim ])
565+ if ydim is not None and xdim is None :
566+ xdim = set (ds [x ].dims ) - set ([ydim ])
567+
568+ x , y , u , v = broadcast (ds [x ], ds [y ], ds [u ], ds [v ])
569+
570+ if xdim is not None and ydim is not None :
571+ # Need to ensure the arrays are transposed correctly
572+ x = x .transpose (ydim , xdim )
573+ y = y .transpose (ydim , xdim )
574+ u = u .transpose (ydim , xdim )
575+ v = v .transpose (ydim , xdim )
576+
577+ args = [x .values , y .values , u .values , v .values ]
578+ hue = kwargs .pop ("hue" )
579+ cmap_params = kwargs .pop ("cmap_params" )
580+
581+ if hue :
582+ kwargs ["color" ] = ds [hue ].values
583+
584+ # TODO: Fix this by always returning a norm with vmin, vmax in cmap_params
585+ if not cmap_params ["norm" ]:
586+ cmap_params ["norm" ] = mpl .colors .Normalize (
587+ cmap_params .pop ("vmin" ), cmap_params .pop ("vmax" )
588+ )
589+
590+ kwargs .pop ("hue_style" )
591+ hdl = ax .streamplot (* args , ** kwargs , ** cmap_params )
592+
593+ # Return .lines so colorbar creation works properly
594+ return hdl .lines
0 commit comments