Skip to content

Commit

Permalink
Merge #962
Browse files Browse the repository at this point in the history
962: Improved wraps and check r=hgrecco a=hgrecco

- [x] Closes #711, #723
- [x] Executed ``black -t py36 . && isort -rc . && flake8`` with no errors
- [x] The change is fully covered by automated unit tests
- [x] Documented in docs/ as appropriate
- [x] Added an entry to the CHANGES file


Co-authored-by: Hernan <hernan.grecco@gmail.com>
  • Loading branch information
bors[bot] and hgrecco authored Dec 28, 2019
2 parents 68a3e91 + 8c82eae commit f36cc74
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 50 deletions.
7 changes: 7 additions & 0 deletions CHANGES
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,13 @@ Pint Changelog
0.10 (unreleased)
-----------------

- Improvements to wraps and check:
- fail upon decoration (not execution) by checking wrapped function signature against
wraps/check arguments.
(might BREAK test code)
- wraps only accepts strings and Units (not quantities) to avoid confusion with magnitude.
(might BREAK code not conforming to documentation)
- when strict=True, strings that can be parsed to quantities are accepted as arguments.
- Add revolutions per second (rps)
- Improved compatbility for upcast types like xarray's DataArray or Dataset, to which
Pint Quantities now fully defer for arithmetic and NumPy operations. A collection of
Expand Down
115 changes: 76 additions & 39 deletions pint/registry_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,13 +139,20 @@ def _converter(ureg, values, strict):
)
else:
if strict:
raise ValueError(
"A wrapped function using strict=True requires "
"quantity for all arguments with not None units. "
"(error found for {}, {})".format(
args_as_uc[ndx][0], new_values[ndx]
if isinstance(values[ndx], str):
# if the value is a string, we try to parse it
tmp_value = ureg.parse_expression(values[ndx])
new_values[ndx] = ureg._convert(
tmp_value._magnitude, tmp_value._units, args_as_uc[ndx][0]
)
else:
raise ValueError(
"A wrapped function using strict=True requires "
"quantity or a string for all arguments with not None units. "
"(error found for {}, {})".format(
args_as_uc[ndx][0], new_values[ndx]
)
)
)

return new_values, values_by_name

Expand Down Expand Up @@ -179,41 +186,68 @@ def wraps(ureg, ret, args, strict=True):
The value returned by the wrapped function will be converted to the units
specified in `ret`.
Use None to skip argument conversion.
Set strict to False, to accept also numerical values.
Parameters
----------
ureg :
ureg : UnitRegistry
a UnitRegistry instance.
ret :
output units.
args :
iterable of input units.
ret : iterable of str or iterable of Unit
Units of each of the return values. Use `None` to skip argument conversion.
args : iterable of str or iterable of Unit
Units of each of the input arguments. Use `None` to skip argument conversion.
strict : bool
Indicates that only quantities are accepted. (Default value = True)
Returns
-------
callable
the wrapped function.
the wrapper function.
Raises
------
TypeError
if the number of given arguments does not match the number of function parameters.
if the any of the provided arguments is not a unit a string or Quantity
"""

if not isinstance(args, (list, tuple)):
args = (args,)

for arg in args:
if arg is not None and not isinstance(arg, (ureg.Unit, str)):
raise TypeError(
"wraps arguments must by of type str or Unit, not %s (%s)"
% (type(arg), arg)
)

converter = _parse_wrap_args(args)

if isinstance(ret, (list, tuple)):
container, ret = (
True,
ret.__class__([_to_units_container(arg, ureg) for arg in ret]),
)
is_ret_container = isinstance(ret, (list, tuple))
if is_ret_container:
for arg in ret:
if arg is not None and not isinstance(arg, (ureg.Unit, str)):
raise TypeError(
"wraps 'ret' argument must by of type str or Unit, not %s (%s)"
% (type(arg), arg)
)
ret = ret.__class__([_to_units_container(arg, ureg) for arg in ret])
else:
container, ret = False, _to_units_container(ret, ureg)
if ret is not None and not isinstance(ret, (ureg.Unit, str)):
raise TypeError(
"wraps 'ret' argument must by of type str or Unit, not %s (%s)"
% (type(ret), ret)
)
ret = _to_units_container(ret, ureg)

def decorator(func):

count_params = len(signature(func).parameters)
if len(args) != count_params:
raise TypeError(
"%s takes %i parameters, but %i units were passed"
% (func.__name__, count_params, len(args))
)

assigned = tuple(
attr for attr in functools.WRAPPER_ASSIGNMENTS if hasattr(func, attr)
)
Expand All @@ -232,7 +266,7 @@ def wrapper(*values, **kw):

result = func(*new_values, **kw)

if container:
if is_ret_container:
out_units = (
_replace_units(r, values_by_name) if is_ref else r
for (r, is_ref) in ret
Expand All @@ -258,35 +292,42 @@ def check(ureg, *args):
"""Decorator to for quantity type checking for function inputs.
Use it to ensure that the decorated function input parameters match
the expected type of pint quantity.
the expected dimension of pint quantity.
Use None to skip argument checking.
The wrapper function raises:
- `pint.DimensionalityError` if an argument doesn't match the required dimensions.
Parameters
----------
ureg :
ureg : UnitRegistry
a UnitRegistry instance.
args :
iterable of input units.
*args :
*args : iterable of str or iterable of UnitContainer
Dimensions of each of the input arguments. Use `None` to skip argument conversion.
Returns
-------
type
callable
the wrapped function.
Raises
------
pint.DimensionalityError
if the parameters don't match dimensions
TypeError
if the number of given dimensions does not match the number of function parameters.
ValueError
if the any of the provided dimensions cannot be parsed as a dimension.
"""
dimensions = [
ureg.get_dimensionality(dim) if dim is not None else None for dim in args
]

def decorator(func):

count_params = len(signature(func).parameters)
if len(dimensions) != count_params:
raise TypeError(
"%s takes %i parameters, but %i dimensions were passed"
% (func.__name__, count_params, len(dimensions))
)

assigned = tuple(
attr for attr in functools.WRAPPER_ASSIGNMENTS if hasattr(func, attr)
)
Expand All @@ -297,11 +338,7 @@ def decorator(func):
@functools.wraps(func, assigned=assigned, updated=updated)
def wrapper(*args, **kwargs):
list_args, empty = _apply_defaults(func, args, kwargs)
if len(dimensions) > len(list_args):
raise TypeError(
"%s takes %i parameters, but %i dimensions were passed"
% (func.__name__, len(list_args), len(dimensions))
)

for dim, value in zip(dimensions, list_args):

if dim is None:
Expand Down
26 changes: 15 additions & 11 deletions pint/testsuite/test_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,9 @@ def func(x):

ureg = self.ureg

self.assertRaises(TypeError, ureg.wraps, (3 * ureg.meter, [None]))
self.assertRaises(TypeError, ureg.wraps, (None, [3 * ureg.meter]))

f0 = ureg.wraps(None, [None])(func)
self.assertEqual(f0(3.0), 3.0)

Expand All @@ -451,6 +454,16 @@ def func(x):
self.assertEqual(f1b(3.0 * ureg.meter), 3.0)
self.assertRaises(DimensionalityError, f1b, 3 * ureg.second)

f1c = ureg.wraps("meter", [ureg.meter])(func)
self.assertEqual(f1c(3.0 * ureg.centimeter), 0.03 * ureg.meter)
self.assertEqual(f1c(3.0 * ureg.meter), 3.0 * ureg.meter)
self.assertRaises(DimensionalityError, f1c, 3 * ureg.second)

f1d = ureg.wraps(ureg.meter, [ureg.meter])(func)
self.assertEqual(f1d(3.0 * ureg.centimeter), 0.03 * ureg.meter)
self.assertEqual(f1d(3.0 * ureg.meter), 3.0 * ureg.meter)
self.assertRaises(DimensionalityError, f1d, 3 * ureg.second)

f1 = ureg.wraps(None, "meter")(func)
self.assertRaises(ValueError, f1, 3.0)
self.assertEqual(f1(3.0 * ureg.centimeter), 0.03)
Expand Down Expand Up @@ -565,17 +578,8 @@ def gfunc(x, y):
1 * ureg.meter / ureg.second ** 2,
)

g2 = ureg.check("[speed]")(gfunc)
self.assertRaises(DimensionalityError, g2, 3.0, 1)
self.assertRaises(TypeError, g2, 2 * ureg.parsec)
self.assertRaises(DimensionalityError, g2, 2 * ureg.parsec, 1.0)
self.assertEqual(g2(2.0 * ureg.km / ureg.hour, 2), 1 * ureg.km / ureg.hour)

g3 = ureg.check("[speed]", "[time]", "[mass]")(gfunc)
self.assertRaises(TypeError, g3, 1 * ureg.parsec, 1 * ureg.angstrom)
self.assertRaises(
TypeError, g3, 1 * ureg.parsec, 1 * ureg.angstrom, 1 * ureg.kilogram
)
self.assertRaises(TypeError, ureg.check("[speed]"), gfunc)
self.assertRaises(TypeError, ureg.check("[speed]", "[time]", "[mass]"), gfunc)

def test_to_ref_vs_to(self):
self.ureg.autoconvert_offset_to_baseunit = True
Expand Down

0 comments on commit f36cc74

Please sign in to comment.