Skip to content

Commit 5901b38

Browse files
authored
Merge pull request matplotlib#7363 from bcongdon/scatter-error-msg-fix
Add appropriate error on color size mismatch in `scatter`
2 parents 10f1522 + 49e7156 commit 5901b38

File tree

2 files changed

+19
-13
lines changed

2 files changed

+19
-13
lines changed

lib/matplotlib/axes/_axes.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3952,6 +3952,7 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None,
39523952

39533953
# np.ma.ravel yields an ndarray, not a masked array,
39543954
# unless its argument is a masked array.
3955+
xy_shape = (np.shape(x), np.shape(y))
39553956
x = np.ma.ravel(x)
39563957
y = np.ma.ravel(y)
39573958
if x.size != y.size:
@@ -3974,7 +3975,7 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None,
39743975
else:
39753976
try:
39763977
c_array = np.asanyarray(c, dtype=float)
3977-
if c_array.size == x.size:
3978+
if c_array.shape in xy_shape:
39783979
c = np.ma.ravel(c_array)
39793980
else:
39803981
# Wrong size; it must not be intended for mapping.
@@ -3984,7 +3985,14 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None,
39843985
c_array = None
39853986

39863987
if c_array is None:
3987-
colors = c # must be acceptable as PathCollection facecolors
3988+
try:
3989+
# must be acceptable as PathCollection facecolors
3990+
colors = mcolors.to_rgba_array(c)
3991+
except ValueError:
3992+
# c not acceptable as PathCollection facecolor
3993+
msg = ("c of shape {0} not acceptable as a color sequence "
3994+
"for x with size {1}, y with size {2}")
3995+
raise ValueError(msg.format(c.shape, x.size, y.size))
39883996
else:
39893997
colors = None # use cmap, norm after collection is created
39903998

lib/matplotlib/tests/test_axes.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4800,15 +4800,13 @@ def test_fillbetween_cycle():
48004800

48014801

48024802
@cleanup
4803-
def test_log_margins():
4804-
plt.rcParams['axes.autolimit_mode'] = 'data'
4803+
def test_color_length_mismatch():
4804+
N = 5
4805+
x, y = np.arange(N), np.arange(N)
4806+
colors = np.arange(N+1)
48054807
fig, ax = plt.subplots()
4806-
margin = 0.05
4807-
ax.set_xmargin(margin)
4808-
ax.semilogx([1, 10], [1, 10])
4809-
xlim0, xlim1 = ax.get_xlim()
4810-
transform = ax.xaxis.get_transform()
4811-
xlim0t, xlim1t = transform.transform([xlim0, xlim1])
4812-
x0t, x1t = transform.transform([1, 10])
4813-
delta = (x1t - x0t) * margin
4814-
assert_allclose([xlim0t + delta, xlim1t - delta], [x0t, x1t])
4808+
with pytest.raises(ValueError):
4809+
ax.scatter(x, y, c=colors)
4810+
c_rgb = (0.5, 0.5, 0.5)
4811+
ax.scatter(x, y, c=c_rgb)
4812+
ax.scatter(x, y, c=[c_rgb] * N)

0 commit comments

Comments
 (0)