Skip to content

Commit 8d76ceb

Browse files
committed
Merge pull request numpy#3187 from ericfode/float16pow
Float16pow
2 parents 3fd6b62 + 8b42156 commit 8d76ceb

File tree

2 files changed

+28
-7
lines changed

2 files changed

+28
-7
lines changed

numpy/core/src/scalarmathmodule.c.src

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -494,16 +494,25 @@ half_ctype_remainder(npy_half a, npy_half b, npy_half *out) {
494494
/**end repeat**/
495495

496496
/**begin repeat
497-
* #name = half, float, double, longdouble#
498-
* #type = npy_half, npy_float, npy_double, npy_longdouble#
497+
* #name = float, double, longdouble#
498+
* #type = npy_float, npy_double, npy_longdouble#
499499
*/
500500
static npy_@name@ (*_basic_@name@_pow)(@type@ a, @type@ b);
501501

502502
static void
503-
@name@_ctype_power(@type@ a, @type@ b, @type@ *out) {
503+
@name@_ctype_power(@type@ a, @type@ b, @type@ *out)
504+
{
504505
*out = _basic_@name@_pow(a, b);
505506
}
506507
/**end repeat**/
508+
static void
509+
half_ctype_power(npy_half a, npy_half b, npy_half *out)
510+
{
511+
const npy_float af = npy_half_to_float(a);
512+
const npy_float bf = npy_half_to_float(b);
513+
const npy_float outf = _basic_float_pow(af,bf);
514+
*out = npy_float_to_half(outf);
515+
}
507516

508517
/**begin repeat
509518
* #name = byte, ubyte, short, ushort, int, uint,
@@ -1130,7 +1139,6 @@ static PyObject *
11301139
int first;
11311140

11321141
@type@ out = @zero@;
1133-
11341142
switch(_@name@_convert2_to_ctypes(a, &arg1, b, &arg2)) {
11351143
case 0:
11361144
break;
@@ -1724,7 +1732,6 @@ get_functions(void)
17241732
i += 3;
17251733
j++;
17261734
}
1727-
_basic_half_pow = funcdata[j - 1];
17281735
_basic_float_pow = funcdata[j];
17291736
_basic_double_pow = funcdata[j + 1];
17301737
_basic_longdouble_pow = funcdata[j + 2];

numpy/core/tests/test_scalarmath.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def test_type_create(self, level=1):
4444

4545
class TestPower(TestCase):
4646
def test_small_types(self):
47-
for t in [np.int8, np.int16]:
47+
for t in [np.int8, np.int16, np.float16]:
4848
a = t(3)
4949
b = a ** 4
5050
assert_(b == 81, "error with %r: got %r" % (t,b))
@@ -58,7 +58,21 @@ def test_large_types(self):
5858
assert_(b == 6765201, msg)
5959
else:
6060
assert_almost_equal(b, 6765201, err_msg=msg)
61-
61+
def test_mixed_types(self):
62+
typelist = [np.int8,np.int16,np.float16,
63+
np.float32,np.float64,np.int8,
64+
np.int16,np.int32,np.int64]
65+
for t1 in typelist:
66+
for t2 in typelist:
67+
a = t1(3)
68+
b = t2(2)
69+
result = a**b
70+
msg = ("error with %r and %r:"
71+
"got %r, expected %r") % (t1, t2, result, 9)
72+
if np.issubdtype(np.dtype(result), np.integer):
73+
assert_(result == 9, msg)
74+
else:
75+
assert_almost_equal(result, 9, err_msg=msg)
6276

6377
class TestComplexDivision(TestCase):
6478
def test_zero_division(self):

0 commit comments

Comments
 (0)