Skip to content

Commit 23c44ee

Browse files
bors[bot]jthielen
andauthored
Merge #941
941: Fix infinite recursion with NumPy array/scalar raised to quantity power r=hgrecco a=jthielen After running the current master branch against MetPy's test suite, I found that #905 introduced an infinite recursion error when the base of the exponent was a NumPy array or scalar by naively deferring to standard exponentiation. There were also issues when both the base and power were Quantities. This PR resolves those issues and adds basic tests for those cases. - ~~Closes~~ (no associated issue) - [x] Executed ``black -t py36 . && isort -rc . && flake8`` with no errors - [x] The change is fully covered by automated unit tests - ~~Documented in docs/ as appropriate~~ - ~~Added an entry to the CHANGES file~~ (fixup to previous change: #905) Co-authored-by: Jon Thielen <github@jont.cc>
2 parents 4c3114c + a08cd4a commit 23c44ee

File tree

3 files changed

+18
-4
lines changed

3 files changed

+18
-4
lines changed

pint/numpy_func.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,10 @@ def _frexp(x, *args, **kwargs):
396396

397397
@implements("power", "ufunc")
398398
def _power(x1, x2):
399-
return x1 ** x2
399+
if _is_quantity(x1):
400+
return x1 ** x2
401+
else:
402+
return x2.__rpow__(x1)
400403

401404

402405
def _add_subtract_handle_non_quantity_zero(x1, x2):

pint/quantity.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1258,6 +1258,7 @@ def __pow__(self, other):
12581258
if other == 1:
12591259
return self
12601260
elif other == 0:
1261+
exponent = 0
12611262
units = UnitsContainer()
12621263
else:
12631264
if not self._is_multiplicative:
@@ -1267,13 +1268,15 @@ def __pow__(self, other):
12671268
raise OffsetUnitCalculusError(self._units)
12681269

12691270
if getattr(other, "dimensionless", False):
1270-
units = new_self._units ** other.to_root_units().magnitude
1271+
exponent = other.to_root_units().magnitude
1272+
units = new_self._units ** exponent
12711273
elif not getattr(other, "dimensionless", True):
12721274
raise DimensionalityError(other._units, "dimensionless")
12731275
else:
1274-
units = new_self._units ** other
1276+
exponent = _to_magnitude(other, self.force_ndarray)
1277+
units = new_self._units ** exponent
12751278

1276-
magnitude = new_self._magnitude ** _to_magnitude(other, self.force_ndarray)
1279+
magnitude = new_self._magnitude ** exponent
12771280
return self.__class__(magnitude, units)
12781281

12791282
@check_implemented

pint/testsuite/test_numpy.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,14 @@ def test_power(self):
402402
q2_cp = copy.copy(q)
403403
self.assertRaises(DimensionalityError, op_, q_cp, q2_cp)
404404

405+
self.assertQuantityEqual(
406+
np.power(self.q, self.Q_(2)), self.Q_([[1, 4], [9, 16]], "m**2")
407+
)
408+
self.assertQuantityEqual(
409+
self.q ** self.Q_(2), self.Q_([[1, 4], [9, 16]], "m**2")
410+
)
411+
self.assertNDArrayEqual(arr ** self.Q_(2), np.array([0, 1, 4]))
412+
405413
@unittest.expectedFailure
406414
@helpers.requires_numpy()
407415
def test_exponentiation_array_exp_2(self):

0 commit comments

Comments
 (0)