Skip to content

Commit

Permalink
implement dpnp.logsumexp and dpnp.reduce_hypot (#1648)
Browse files Browse the repository at this point in the history
* implement logsumexp and reduce_hypot

* fix pre-commit

* address comments
  • Loading branch information
vtavana authored Dec 22, 2023
1 parent 9e8323e commit 20513fb
Show file tree
Hide file tree
Showing 7 changed files with 314 additions and 14 deletions.
2 changes: 2 additions & 0 deletions doc/reference/math.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ Trigonometric functions
dpnp.unwrap
dpnp.deg2rad
dpnp.rad2deg
dpnp.reduce_hypot


Hyperbolic functions
Expand Down Expand Up @@ -94,6 +95,7 @@ Exponents and logarithms
dpnp.log1p
dpnp.logaddexp
dpnp.logaddexp2
dpnp.logsumexp


Other special functions
Expand Down
157 changes: 157 additions & 0 deletions dpnp/dpnp_iface_trigonometric.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,12 @@
"""


import dpctl.tensor as dpt
import numpy

import dpnp
from dpnp.dpnp_algo import *
from dpnp.dpnp_array import dpnp_array
from dpnp.dpnp_utils import *

from .dpnp_algo.dpnp_elementwise_common import (
Expand Down Expand Up @@ -98,9 +100,11 @@
"log1p",
"log2",
"logaddexp",
"logsumexp",
"rad2deg",
"radians",
"reciprocal",
"reduce_hypot",
"rsqrt",
"sin",
"sinh",
Expand Down Expand Up @@ -989,6 +993,10 @@ def hypot(
Otherwise the function will be executed sequentially on CPU.
Input array data types are limited by supported real-valued data types.
See Also
--------
:obj:`dpnp.reduce_hypot` : The square root of the sum of squares of elements in the input array.
Examples
--------
>>> import dpnp as np
Expand Down Expand Up @@ -1303,6 +1311,7 @@ def logaddexp(
--------
:obj:`dpnp.log` : Natural logarithm, element-wise.
:obj:`dpnp.exp` : Exponential, element-wise.
:obj:`dpnp.logsumdexp` : Logarithm of the sum of exponentials of elements in the input array.
Examples
--------
Expand Down Expand Up @@ -1331,6 +1340,81 @@ def logaddexp(
)


def logsumexp(x, axis=None, out=None, dtype=None, keepdims=False):
"""
Calculates the logarithm of the sum of exponentials of elements in the input array.
Parameters
----------
x : {dpnp_array, usm_ndarray}
Input array, expected to have a real-valued data type.
axis : int or tuple of ints, optional
Axis or axes along which values must be computed. If a tuple
of unique integers, values are computed over multiple axes.
If ``None``, the result is computed over the entire array.
Default: ``None``.
out : {dpnp_array, usm_ndarray}, optional
If provided, the result will be inserted into this array. It should
be of the appropriate shape and dtype.
dtype : data type, optional
Data type of the returned array. If ``None``, the default data
type is inferred from the "kind" of the input array data type.
* If `x` has a real-valued floating-point data type,
the returned array will have the default real-valued
floating-point data type for the device where input
array `x` is allocated.
* If `x` has a boolean or integral data type, the returned array
will have the default floating point data type for the device
where input array `x` is allocated.
* If `x` has a complex-valued floating-point data type,
an error is raised.
If the data type (either specified or resolved) differs from the
data type of `x`, the input array elements are cast to the
specified data type before computing the result. Default: ``None``.
keepdims : bool
If ``True``, the reduced axes (dimensions) are included in the result
as singleton dimensions, so that the returned array remains
compatible with the input arrays according to Array Broadcasting
rules. Otherwise, if ``False``, the reduced axes are not included in
the returned array. Default: ``False``.
Returns
-------
out : dpnp.ndarray
An array containing the results. If the result was computed over
the entire array, a zero-dimensional array is returned. The returned
array has the data type as described in the `dtype` parameter
description above.
Note
----
This function is equivalent of `numpy.logaddexp.reduce`.
See Also
--------
:obj:`dpnp.log` : Natural logarithm, element-wise.
:obj:`dpnp.exp` : Exponential, element-wise.
:obj:`dpnp.logaddexp` : Logarithm of the sum of exponentiations of the inputs, element-wise.
Examples
--------
>>> import dpnp as np
>>> a = np.ones(10)
>>> np.logsumexp(a)
array(3.30258509)
>>> np.log(np.sum(np.exp(a)))
array(3.30258509)
"""

dpt_array = dpnp.get_usm_ndarray(x)
result = dpnp_array._create_from_usm_ndarray(
dpt.logsumexp(dpt_array, axis=axis, dtype=dtype, keepdims=keepdims)
)

return dpnp.get_result_array(result, out, casting="same_kind")


def reciprocal(x1, **kwargs):
"""
Return the reciprocal of the argument, element-wise.
Expand Down Expand Up @@ -1363,6 +1447,79 @@ def reciprocal(x1, **kwargs):
return call_origin(numpy.reciprocal, x1, **kwargs)


def reduce_hypot(x, axis=None, out=None, dtype=None, keepdims=False):
"""
Calculates the square root of the sum of squares of elements in the input array.
Parameters
----------
x : {dpnp_array, usm_ndarray}
Input array, expected to have a real-valued data type.
axis : int or tuple of ints, optional
Axis or axes along which values must be computed. If a tuple
of unique integers, values are computed over multiple axes.
If ``None``, the result is computed over the entire array.
Default: ``None``.
out : {dpnp_array, usm_ndarray}, optional
If provided, the result will be inserted into this array. It should
be of the appropriate shape and dtype.
dtype : data type, optional
Data type of the returned array. If ``None``, the default data
type is inferred from the "kind" of the input array data type.
* If `x` has a real-valued floating-point data type,
the returned array will have the default real-valued
floating-point data type for the device where input
array `x` is allocated.
* If `x` has a boolean or integral data type, the returned array
will have the default floating point data type for the device
where input array `x` is allocated.
* If `x` has a complex-valued floating-point data type,
an error is raised.
If the data type (either specified or resolved) differs from the
data type of `x`, the input array elements are cast to the
specified data type before computing the result. Default: ``None``.
keepdims : bool
If ``True``, the reduced axes (dimensions) are included in the result
as singleton dimensions, so that the returned array remains
compatible with the input arrays according to Array Broadcasting
rules. Otherwise, if ``False``, the reduced axes are not included in
the returned array. Default: ``False``.
Returns
-------
out : dpnp.ndarray
An array containing the results. If the result was computed over
the entire array, a zero-dimensional array is returned. The returned
array has the data type as described in the `dtype` parameter
description above.
Note
----
This function is equivalent of `numpy.hypot.reduce`.
See Also
--------
:obj:`dpnp.hypot` : Given the "legs" of a right triangle, return its hypotenuse.
Examples
--------
>>> import dpnp as np
>>> a = np.ones(10)
>>> np.reduce_hypot(a)
array(3.16227766)
>>> np.sqrt(np.sum(np.square(a)))
array(3.16227766)
"""

dpt_array = dpnp.get_usm_ndarray(x)
result = dpnp_array._create_from_usm_ndarray(
dpt.reduce_hypot(dpt_array, axis=axis, dtype=dtype, keepdims=keepdims)
)

return dpnp.get_result_array(result, out, casting="same_kind")


def rsqrt(
x,
/,
Expand Down
13 changes: 10 additions & 3 deletions tests/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,17 @@ def assert_dtype_allclose(
list_64bit_types = [numpy.float64, numpy.complex128]
is_inexact = lambda x: dpnp.issubdtype(x.dtype, dpnp.inexact)
if is_inexact(dpnp_arr) or is_inexact(numpy_arr):
tol = 8 * max(
dpnp.finfo(dpnp_arr).resolution,
numpy.finfo(numpy_arr.dtype).resolution,
tol_dpnp = (
dpnp.finfo(dpnp_arr).resolution
if is_inexact(dpnp_arr)
else -dpnp.inf
)
tol_numpy = (
numpy.finfo(numpy_arr.dtype).resolution
if is_inexact(numpy_arr)
else -dpnp.inf
)
tol = 8 * max(tol_dpnp, tol_numpy)
assert_allclose(dpnp_arr.asnumpy(), numpy_arr, atol=tol, rtol=tol)
if check_type:
numpy_arr_dtype = numpy_arr.dtype
Expand Down
84 changes: 84 additions & 0 deletions tests/test_mathematical.py
Original file line number Diff line number Diff line change
Expand Up @@ -1752,6 +1752,90 @@ def test_invalid_out(self, out):
assert_raises(TypeError, numpy.hypot, a.asnumpy(), 2, out)


class TestLogSumExp:
@pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True))
@pytest.mark.parametrize("axis", [None, 2, -1, (0, 1)])
@pytest.mark.parametrize("keepdims", [True, False])
def test_logsumexp(self, dtype, axis, keepdims):
a = dpnp.ones((3, 4, 5, 6, 7), dtype=dtype)
res = dpnp.logsumexp(a, axis=axis, keepdims=keepdims)
exp_dtype = dpnp.default_float_type(a.device)
exp = numpy.logaddexp.reduce(
dpnp.asnumpy(a), axis=axis, keepdims=keepdims, dtype=exp_dtype
)

assert_dtype_allclose(res, exp)

@pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True))
@pytest.mark.parametrize("axis", [None, 2, -1, (0, 1)])
@pytest.mark.parametrize("keepdims", [True, False])
def test_logsumexp_out(self, dtype, axis, keepdims):
a = dpnp.ones((3, 4, 5, 6, 7), dtype=dtype)
exp_dtype = dpnp.default_float_type(a.device)
exp = numpy.logaddexp.reduce(
dpnp.asnumpy(a), axis=axis, keepdims=keepdims, dtype=exp_dtype
)
dpnp_out = dpnp.empty(exp.shape, dtype=exp_dtype)
res = dpnp.logsumexp(a, axis=axis, out=dpnp_out, keepdims=keepdims)

assert res is dpnp_out
assert_dtype_allclose(res, exp)

@pytest.mark.parametrize(
"in_dtype", get_all_dtypes(no_bool=True, no_complex=True)
)
@pytest.mark.parametrize("out_dtype", get_all_dtypes(no_bool=True))
def test_logsumexp_dtype(self, in_dtype, out_dtype):
a = dpnp.ones(100, dtype=in_dtype)
res = dpnp.logsumexp(a, dtype=out_dtype)
exp = numpy.logaddexp.reduce(dpnp.asnumpy(a))
exp = exp.astype(out_dtype)

assert_allclose(res, exp, rtol=1e-06)


class TestReduceHypot:
@pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True))
@pytest.mark.parametrize("axis", [None, 2, -1, (0, 1)])
@pytest.mark.parametrize("keepdims", [True, False])
def test_reduce_hypot(self, dtype, axis, keepdims):
a = dpnp.ones((3, 4, 5, 6, 7), dtype=dtype)
res = dpnp.reduce_hypot(a, axis=axis, keepdims=keepdims)
exp_dtype = dpnp.default_float_type(a.device)
exp = numpy.hypot.reduce(
dpnp.asnumpy(a), axis=axis, keepdims=keepdims, dtype=exp_dtype
)

assert_dtype_allclose(res, exp)

@pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True))
@pytest.mark.parametrize("axis", [None, 2, -1, (0, 1)])
@pytest.mark.parametrize("keepdims", [True, False])
def test_reduce_hypot_out(self, dtype, axis, keepdims):
a = dpnp.ones((3, 4, 5, 6, 7), dtype=dtype)
exp_dtype = dpnp.default_float_type(a.device)
exp = numpy.hypot.reduce(
dpnp.asnumpy(a), axis=axis, keepdims=keepdims, dtype=exp_dtype
)
dpnp_out = dpnp.empty(exp.shape, dtype=exp_dtype)
res = dpnp.reduce_hypot(a, axis=axis, out=dpnp_out, keepdims=keepdims)

assert res is dpnp_out
assert_dtype_allclose(res, exp)

@pytest.mark.parametrize(
"in_dtype", get_all_dtypes(no_bool=True, no_complex=True)
)
@pytest.mark.parametrize("out_dtype", get_all_dtypes(no_bool=True))
def test_reduce_hypot_dtype(self, in_dtype, out_dtype):
a = dpnp.ones(99, dtype=in_dtype)
res = dpnp.reduce_hypot(a, dtype=out_dtype)
exp = numpy.hypot.reduce(dpnp.asnumpy(a))
exp = exp.astype(out_dtype)

assert_allclose(res, exp, rtol=1e-06)


class TestMaximum:
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
def test_maximum(self, dtype):
Expand Down
34 changes: 25 additions & 9 deletions tests/test_strides.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import dpnp

from .helper import get_all_dtypes
from .helper import assert_dtype_allclose, get_all_dtypes


def _getattr(ex, str_):
Expand Down Expand Up @@ -99,17 +99,33 @@ def test_strides_1arg(func_name, dtype, shape):


@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True))
def test_strides_rsqrt(dtype):
a = numpy.arange(1, 11, dtype=dtype)
b = a[::2]
def test_rsqrt(dtype):
a = numpy.arange(1, 11, dtype=dtype)[::2]
dpa = dpnp.arange(1, 11, dtype=dtype)[::2]

dpa = dpnp.arange(1, 11, dtype=dtype)
dpb = dpa[::2]
result = dpnp.rsqrt(dpa)
expected = 1 / numpy.sqrt(a)
assert_dtype_allclose(result, expected)

result = dpnp.rsqrt(dpb)
expected = 1 / numpy.sqrt(b)

assert_allclose(result, expected, rtol=1e-06)
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True))
def test_logsumexp(dtype):
a = numpy.arange(10, dtype=dtype)[::2]
dpa = dpnp.arange(10, dtype=dtype)[::2]

result = dpnp.logsumexp(dpa)
expected = numpy.logaddexp.reduce(a)
assert_allclose(result, expected)


@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True))
def test_reduce_hypot(dtype):
a = numpy.arange(10, dtype=dtype)[::2]
dpa = dpnp.arange(10, dtype=dtype)[::2]

result = dpnp.reduce_hypot(dpa)
expected = numpy.hypot.reduce(a)
assert_allclose(result, expected)


@pytest.mark.parametrize(
Expand Down
Loading

0 comments on commit 20513fb

Please sign in to comment.