Skip to content

Commit

Permalink
Add implementation of dpnp.fix function (#1971)
Browse files Browse the repository at this point in the history
* Implement dpnp.fix()

* Add tests to cover function

* Update dpnp/dpnp_iface_mathematical.py

Co-authored-by: vtavana <120411540+vtavana@users.noreply.github.com>

---------

Co-authored-by: vtavana <120411540+vtavana@users.noreply.github.com>
  • Loading branch information
antonwolfy and vtavana authored Aug 12, 2024
1 parent 689aeeb commit e4acd3e
Show file tree
Hide file tree
Showing 11 changed files with 349 additions and 4 deletions.
1 change: 1 addition & 0 deletions dpnp/backend/extensions/ufunc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ set(_elementwise_sources
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/degrees.cpp
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/fabs.cpp
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/fix.cpp
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/float_power.cpp
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/fmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/fmin.cpp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

#include "degrees.hpp"
#include "fabs.hpp"
#include "fix.hpp"
#include "float_power.hpp"
#include "fmax.hpp"
#include "fmin.hpp"
Expand All @@ -45,6 +46,7 @@ void init_elementwise_functions(py::module_ m)
{
init_degrees(m);
init_fabs(m);
init_fix(m);
init_float_power(m);
init_fmax(m);
init_fmin(m);
Expand Down
125 changes: 125 additions & 0 deletions dpnp/backend/extensions/ufunc/elementwise_functions/fix.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
//*****************************************************************************
// Copyright (c) 2024, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// - Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
// - Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************

#include <sycl/sycl.hpp>

#include "dpctl4pybind11.hpp"

#include "fix.hpp"
#include "kernels/elementwise_functions/fix.hpp"
#include "populate.hpp"

// include a local copy of elementwise common header from dpctl tensor:
// dpctl/tensor/libtensor/source/elementwise_functions/elementwise_functions.hpp
// TODO: replace by including dpctl header once available
#include "../../elementwise_functions/elementwise_functions.hpp"

// dpctl tensor headers
#include "kernels/elementwise_functions/common.hpp"
#include "utils/type_dispatch.hpp"

namespace dpnp::extensions::ufunc
{
namespace py = pybind11;
namespace py_int = dpnp::extensions::py_internal;

namespace impl
{
namespace ew_cmn_ns = dpctl::tensor::kernels::elementwise_common;
namespace td_ns = dpctl::tensor::type_dispatch;

/**
* @brief A factory to define pairs of supported types for which
* sycl::fix<T> function is available.
*
* @tparam T Type of input vector `a` and of result vector `y`.
*/
template <typename T>
struct OutputType
{
using value_type =
typename std::disjunction<td_ns::TypeMapResultEntry<T, sycl::half>,
td_ns::TypeMapResultEntry<T, float>,
td_ns::TypeMapResultEntry<T, double>,
td_ns::DefaultResultEntry<void>>::result_type;
};

using dpnp::kernels::fix::FixFunctor;

template <typename argT,
typename resT = argT,
unsigned int vec_sz = 4,
unsigned int n_vecs = 2,
bool enable_sg_loadstore = true>
using ContigFunctor = ew_cmn_ns::UnaryContigFunctor<argT,
resT,
FixFunctor<argT, resT>,
vec_sz,
n_vecs,
enable_sg_loadstore>;

template <typename argTy, typename resTy, typename IndexerT>
using StridedFunctor = ew_cmn_ns::
UnaryStridedFunctor<argTy, resTy, IndexerT, FixFunctor<argTy, resTy>>;

using ew_cmn_ns::unary_contig_impl_fn_ptr_t;
using ew_cmn_ns::unary_strided_impl_fn_ptr_t;

static unary_contig_impl_fn_ptr_t fix_contig_dispatch_vector[td_ns::num_types];
static int fix_output_typeid_vector[td_ns::num_types];
static unary_strided_impl_fn_ptr_t
fix_strided_dispatch_vector[td_ns::num_types];

MACRO_POPULATE_DISPATCH_VECTORS(fix);
} // namespace impl

void init_fix(py::module_ m)
{
using arrayT = dpctl::tensor::usm_ndarray;
using event_vecT = std::vector<sycl::event>;
{
impl::populate_fix_dispatch_vectors();
using impl::fix_contig_dispatch_vector;
using impl::fix_output_typeid_vector;
using impl::fix_strided_dispatch_vector;

auto fix_pyapi = [&](const arrayT &src, const arrayT &dst,
sycl::queue &exec_q,
const event_vecT &depends = {}) {
return py_int::py_unary_ufunc(
src, dst, exec_q, depends, fix_output_typeid_vector,
fix_contig_dispatch_vector, fix_strided_dispatch_vector);
};
m.def("_fix", fix_pyapi, "", py::arg("src"), py::arg("dst"),
py::arg("sycl_queue"), py::arg("depends") = py::list());

auto fix_result_type_pyapi = [&](const py::dtype &dtype) {
return py_int::py_unary_ufunc_result_type(dtype,
fix_output_typeid_vector);
};
m.def("_fix_result_type", fix_result_type_pyapi);
}
}
} // namespace dpnp::extensions::ufunc
35 changes: 35 additions & 0 deletions dpnp/backend/extensions/ufunc/elementwise_functions/fix.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
//*****************************************************************************
// Copyright (c) 2024, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// - Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
// - Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************

#pragma once

#include <pybind11/pybind11.h>

namespace py = pybind11;

namespace dpnp::extensions::ufunc
{
void init_fix(py::module_ m);
} // namespace dpnp::extensions::ufunc
49 changes: 49 additions & 0 deletions dpnp/backend/kernels/elementwise_functions/fix.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
//*****************************************************************************
// Copyright (c) 2024, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// - Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
// - Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************

#pragma once

#include <sycl/sycl.hpp>

namespace dpnp::kernels::fix
{
template <typename argT, typename resT>
struct FixFunctor
{
// is function constant for given argT
using is_constant = typename std::false_type;
// constant value, if constant
// constexpr resT constant_value = resT{};
// is function defined for sycl::vec
using supports_vec = typename std::false_type;
// do both argT and resT support subgroup store/load operation
using supports_sg_loadstore = typename std::true_type;

resT operator()(const argT &x) const
{
return (x >= 0.0) ? sycl::floor(x) : sycl::ceil(x);
}
};
} // namespace dpnp::kernels::fix
68 changes: 68 additions & 0 deletions dpnp/dpnp_iface_mathematical.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@
"divide",
"ediff1d",
"fabs",
"fix",
"float_power",
"floor",
"floor_divide",
Expand Down Expand Up @@ -533,6 +534,7 @@ def around(x, /, decimals=0, out=None):
:obj:`dpnp.round` : Equivalent function; see for details.
:obj:`dpnp.ndarray.round` : Equivalent function.
:obj:`dpnp.rint` : Round elements of the array to the nearest integer.
:obj:`dpnp.fix` : Round to nearest integer towards zero, element-wise.
:obj:`dpnp.ceil` : Compute the ceiling of the input, element-wise.
:obj:`dpnp.floor` : Return the floor of the input, element-wise.
:obj:`dpnp.trunc` : Return the truncated value of the input, element-wise.
Expand Down Expand Up @@ -578,6 +580,8 @@ def around(x, /, decimals=0, out=None):
--------
:obj:`dpnp.floor` : Return the floor of the input, element-wise.
:obj:`dpnp.trunc` : Return the truncated value of the input, element-wise.
:obj:`dpnp.rint` : Round elements of the array to the nearest integer.
:obj:`dpnp.fix` : Round to nearest integer towards zero, element-wise.
Examples
--------
Expand Down Expand Up @@ -1371,6 +1375,64 @@ def ediff1d(x1, to_end=None, to_begin=None):
)


_FIX_DOCSTRING = """
Round to nearest integer towards zero.
Round an array of floats element-wise to nearest integer towards zero.
The rounded values are returned as floats.
For full documentation refer to :obj:`numpy.fix`.
Parameters
----------
x : {dpnp.ndarray, usm_ndarray}
An array of floats to be rounded.
out : {None, dpnp.ndarray, usm_ndarray}, optional
Output array to populate.
Array must have the correct shape and the expected data type.
Default: ``None``.
order : {"C", "F", "A", "K"}, optional
Memory layout of the newly output array, if parameter `out` is ``None``.
Default: ``"K"``.
Returns
-------
out : dpnp.ndarray
An array with the rounded values and with the same dimensions as the input.
The returned array will have the default floating point data type for the
device where `a` is allocated.
If `out` is ``None`` then a float array is returned with the rounded values.
Otherwise the result is stored there and the return value `out` is
a reference to that array.
See Also
--------
:obj:`dpnp.round` : Round to given number of decimals.
:obj:`dpnp.rint` : Round elements of the array to the nearest integer.
:obj:`dpnp.trunc` : Return the truncated value of the input, element-wise.
:obj:`dpnp.floor` : Return the floor of the input, element-wise.
:obj:`dpnp.ceil` : Return the ceiling of the input, element-wise.
Examples
--------
>>> import dpnp as np
>>> np.fix(np.array(3.14))
array(3.)
>>> np.fix(np.array(3))
array(3.)
>>> a = np.array([2.1, 2.9, -2.1, -2.9])
>>> np.fix(a)
array([ 2., 2., -2., -2.])
"""

fix = DPNPUnaryFunc(
"fix",
ufi._fix_result_type,
ufi._fix,
_FIX_DOCSTRING,
)


_FLOAT_POWER_DOCSTRING = """
Calculates `x1_i` raised to `x2_i` for each element `x1_i` of the input array
`x1` with the respective element `x2_i` of the input array `x2`.
Expand Down Expand Up @@ -1504,6 +1566,8 @@ def ediff1d(x1, to_end=None, to_begin=None):
--------
:obj:`dpnp.ceil` : Compute the ceiling of the input, element-wise.
:obj:`dpnp.trunc` : Return the truncated value of the input, element-wise.
:obj:`dpnp.rint` : Round elements of the array to the nearest integer.
:obj:`dpnp.fix` : Round to nearest integer towards zero, element-wise.
Notes
-----
Expand Down Expand Up @@ -3048,6 +3112,7 @@ def prod(
See Also
--------
:obj:`dpnp.round` : Evenly round to the given number of decimals.
:obj:`dpnp.fix` : Round to nearest integer towards zero, element-wise.
:obj:`dpnp.ceil` : Compute the ceiling of the input, element-wise.
:obj:`dpnp.floor` : Return the floor of the input, element-wise.
:obj:`dpnp.trunc` : Return the truncated value of the input, element-wise.
Expand Down Expand Up @@ -3103,6 +3168,7 @@ def prod(
:obj:`dpnp.ndarray.round` : Equivalent function.
:obj:`dpnp.rint` : Round elements of the array to the nearest integer.
:obj:`dpnp.ceil` : Compute the ceiling of the input, element-wise.
:obj:`dpnp.fix` : Round to nearest integer towards zero, element-wise.
:obj:`dpnp.floor` : Return the floor of the input, element-wise.
:obj:`dpnp.trunc` : Return the truncated value of the input, element-wise.
Expand Down Expand Up @@ -3536,6 +3602,8 @@ def trapz(y1, x1=None, dx=1.0, axis=-1):
--------
:obj:`dpnp.floor` : Round a number to the nearest integer toward minus infinity.
:obj:`dpnp.ceil` : Round a number to the nearest integer toward infinity.
:obj:`dpnp.rint` : Round elements of the array to the nearest integer.
:obj:`dpnp.fix` : Round to nearest integer towards zero, element-wise.
Examples
--------
Expand Down
2 changes: 0 additions & 2 deletions tests/skipped_tests.tbl
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,6 @@ tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_interp_inf_to_nan
tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_heaviside
tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_heaviside_nan_inf

tests/third_party/cupy/math_tests/test_rounding.py::TestRounding::test_fix

tests/third_party/cupy/random_tests/test_distributions.py::TestDistributionsBeta_param_0_{a_shape=(), b_shape=(), shape=(4, 3, 2)}::test_beta
tests/third_party/cupy/random_tests/test_distributions.py::TestDistributionsBeta_param_1_{a_shape=(), b_shape=(), shape=(3, 2)}::test_beta
tests/third_party/cupy/random_tests/test_distributions.py::TestDistributionsBeta_param_2_{a_shape=(), b_shape=(3, 2), shape=(4, 3, 2)}::test_beta
Expand Down
2 changes: 0 additions & 2 deletions tests/skipped_tests_gpu.tbl
Original file line number Diff line number Diff line change
Expand Up @@ -282,8 +282,6 @@ tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_interp_inf_to_nan
tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_heaviside
tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_heaviside_nan_inf

tests/third_party/cupy/math_tests/test_rounding.py::TestRounding::test_fix

tests/third_party/cupy/random_tests/test_distributions.py::TestDistributionsBeta_param_0_{a_shape=(), b_shape=(), shape=(4, 3, 2)}::test_beta
tests/third_party/cupy/random_tests/test_distributions.py::TestDistributionsBeta_param_1_{a_shape=(), b_shape=(), shape=(3, 2)}::test_beta
tests/third_party/cupy/random_tests/test_distributions.py::TestDistributionsBeta_param_2_{a_shape=(), b_shape=(3, 2), shape=(4, 3, 2)}::test_beta
Expand Down
Loading

0 comments on commit e4acd3e

Please sign in to comment.