Skip to content

Commit

Permalink
Move cbook._check_shape() to _api.check_shape()
Browse files Browse the repository at this point in the history
  • Loading branch information
timhoffm committed Sep 17, 2020
1 parent f0be6b1 commit 9f9ea54
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 58 deletions.
40 changes: 40 additions & 0 deletions lib/matplotlib/_api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import itertools


def check_in_list(_values, *, _print_supported_values=True, **kwargs):
Expand Down Expand Up @@ -31,3 +32,42 @@ def check_in_list(_values, *, _print_supported_values=True, **kwargs):
f"supported values are {', '.join(map(repr, values))}")
else:
raise ValueError(f"{val!r} is not a valid value for {key}")


def check_shape(_shape, **kwargs):
"""
For each *key, value* pair in *kwargs*, check that *value* has the shape
*_shape*, if not, raise an appropriate ValueError.
*None* in the shape is treated as a "free" size that can have any length.
e.g. (None, 2) -> (N, 2)
The values checked must be numpy arrays.
Examples
--------
To check for (N, 2) shaped arrays
>>> _api.check_shape((None, 2), arg=arg, other_arg=other_arg)
"""
target_shape = _shape
for k, v in kwargs.items():
data_shape = v.shape

if len(target_shape) != len(data_shape) or any(
t not in [s, None]
for t, s in zip(target_shape, data_shape)
):
dim_labels = iter(itertools.chain(
'MNLIJKLH',
(f"D{i}" for i in itertools.count())))
text_shape = ", ".join((str(n)
if n is not None
else next(dim_labels)
for n in target_shape))

raise ValueError(
f"{k!r} must be {len(target_shape)}D "
f"with shape ({text_shape}). "
f"Your input has shape {v.shape}."
)
39 changes: 0 additions & 39 deletions lib/matplotlib/cbook/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2281,45 +2281,6 @@ def type_name(tp):
type_name(type(v))))


def _check_shape(_shape, **kwargs):
"""
For each *key, value* pair in *kwargs*, check that *value* has the shape
*_shape*, if not, raise an appropriate ValueError.
*None* in the shape is treated as a "free" size that can have any length.
e.g. (None, 2) -> (N, 2)
The values checked must be numpy arrays.
Examples
--------
To check for (N, 2) shaped arrays
>>> _api.check_in_list((None, 2), arg=arg, other_arg=other_arg)
"""
target_shape = _shape
for k, v in kwargs.items():
data_shape = v.shape

if len(target_shape) != len(data_shape) or any(
t not in [s, None]
for t, s in zip(target_shape, data_shape)
):
dim_labels = iter(itertools.chain(
'MNLIJKLH',
(f"D{i}" for i in itertools.count())))
text_shape = ", ".join((str(n)
if n is not None
else next(dim_labels)
for n in target_shape))

raise ValueError(
f"{k!r} must be {len(target_shape)}D "
f"with shape ({text_shape}). "
f"Your input has shape {v.shape}."
)


def _check_getitem(_mapping, **kwargs):
"""
*kwargs* must consist of a single *key, value* pair. If *key* is in
Expand Down
4 changes: 2 additions & 2 deletions lib/matplotlib/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import numpy as np

import matplotlib as mpl
from . import _path, cbook
from . import _api, _path, cbook
from .cbook import _to_unmasked_float_array, simple_linear_interpolation
from .bezier import BezierSegment

Expand Down Expand Up @@ -129,7 +129,7 @@ def __init__(self, vertices, codes=None, _interpolation_steps=1,
and codes as read-only arrays.
"""
vertices = _to_unmasked_float_array(vertices)
cbook._check_shape((None, 2), vertices=vertices)
_api.check_shape((None, 2), vertices=vertices)

if codes is not None:
codes = np.asarray(codes, self.code_type)
Expand Down
21 changes: 21 additions & 0 deletions lib/matplotlib/tests/test_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import re

import numpy as np
import pytest

from matplotlib import _api


@pytest.mark.parametrize('target,test_shape',
[((None, ), (1, 3)),
((None, 3), (1,)),
((None, 3), (1, 2)),
((1, 5), (1, 9)),
((None, 2, None), (1, 3, 1))
])
def test_check_shape(target, test_shape):
error_pattern = (f"^'aardvark' must be {len(target)}D.*" +
re.escape(f'has shape {test_shape}'))
data = np.zeros(test_shape)
with pytest.raises(ValueError, match=error_pattern):
_api.check_shape(target, aardvark=data)
17 changes: 0 additions & 17 deletions lib/matplotlib/tests/test_cbook.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import itertools
import pickle
import re

from weakref import ref
from unittest.mock import patch, Mock
Expand Down Expand Up @@ -675,22 +674,6 @@ def divisors(n):
check(x, rstride=rstride, cstride=cstride)


@pytest.mark.parametrize('target,test_shape',
[((None, ), (1, 3)),
((None, 3), (1,)),
((None, 3), (1, 2)),
((1, 5), (1, 9)),
((None, 2, None), (1, 3, 1))
])
def test_check_shape(target, test_shape):
error_pattern = (f"^'aardvark' must be {len(target)}D.*" +
re.escape(f'has shape {test_shape}'))
data = np.zeros(test_shape)
with pytest.raises(ValueError,
match=error_pattern):
cbook._check_shape(target, aardvark=data)


def test_setattr_cm():
class A:

Expand Down

0 comments on commit 9f9ea54

Please sign in to comment.