1111import  numpy  as  np 
1212import  pandas  as  pd 
1313
14+ from  ..core .alignment  import  broadcast 
1415from  .facetgrid  import  _easy_facetgrid 
1516from  .utils  import  (
1617    _add_colorbar ,
@@ -666,17 +667,6 @@ def newplotfunc(
666667            darray = darray , x = x , y = y , imshow = imshow_rgb , rgb = rgb 
667668        )
668669
669-         # better to pass the ndarrays directly to plotting functions 
670-         xval  =  darray [xlab ].values 
671-         yval  =  darray [ylab ].values 
672- 
673-         # check if we need to broadcast one dimension 
674-         if  xval .ndim  <  yval .ndim :
675-             xval  =  np .broadcast_to (xval , yval .shape )
676- 
677-         if  yval .ndim  <  xval .ndim :
678-             yval  =  np .broadcast_to (yval , xval .shape )
679- 
680670        # May need to transpose for correct x, y labels 
681671        # xlab may be the name of a coord, we have to check for dim names 
682672        if  imshow_rgb :
@@ -690,8 +680,17 @@ def newplotfunc(
690680        elif  darray [xlab ].dims [- 1 ] ==  darray .dims [0 ]:
691681            darray  =  darray .transpose (transpose_coords = True )
692682
693-         # Pass the data as a masked ndarray too 
694-         zval  =  darray .to_masked_array (copy = False )
683+         # better to pass the ndarrays directly to plotting functions 
684+         # Pass the data as a masked ndarray 
685+         if  darray [xlab ].ndim  ==  1  and  darray [ylab ].ndim  ==  1 :
686+             xval  =  darray [xlab ].values 
687+             yval  =  darray [ylab ].values 
688+             zval  =  darray .to_masked_array (copy = False )
689+         else :
690+             xval , yval , zval  =  map (
691+                 lambda  x : x .values , broadcast (darray [xlab ], darray [ylab ], darray )
692+             )
693+             zval  =  np .ma .masked_array (zval , mask = pd .isnull (zval ), copy = False )
695694
696695        # Replace pd.Intervals if contained in xval or yval. 
697696        xplt , xlab_extra  =  _resolve_intervals_2dplot (xval , plotfunc .__name__ )
0 commit comments