Skip to content

Commit b291c97

Browse files
committed
Test commit splitting up elementwise functions
1 parent 3ad6d8b commit b291c97

File tree

7 files changed

+3553
-3400
lines changed

7 files changed

+3553
-3400
lines changed

dpctl/tensor/CMakeLists.txt

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ set(_tensor_impl_sources
4747
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/where.cpp
4848
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/boolean_reductions.cpp
4949
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/device_support_queries.cpp
50-
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions.cpp
50+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions1.cpp
51+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions2.cpp
5152
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/repeat.cpp
5253
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reduction_over_axis.cpp
5354
)
@@ -63,7 +64,8 @@ endif()
6364
set(_no_fast_math_sources
6465
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/full_ctor.cpp
6566
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linear_sequences.cpp
66-
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions.cpp
67+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions1.cpp
68+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions2.cpp
6769
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reduction_over_axis.cpp
6870
)
6971
foreach(_src_fn ${_no_fast_math_sources})

dpctl/tensor/libtensor/source/elementwise_functions.hpp

Lines changed: 48 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
/// This file defines functions of dpctl.tensor._tensor_impl extensions,
2323
/// specifically functions for elementwise operations.
2424
//===----------------------------------------------------------------------===//
25-
2625
#pragma once
2726

2827
#include "dpctl4pybind11.hpp"
@@ -38,18 +37,61 @@
3837
#include "utils/offset_utils.hpp"
3938
#include "utils/type_dispatch.hpp"
4039

40+
namespace td_ns = dpctl::tensor::type_dispatch;
41+
42+
/// @brief Translate a type_dispatch type enumerator into a pybind11 dtype.
///
/// Each recognized enumerator is associated with a short type-code string
/// ("?", "i1", "u1", ..., "c8", "c16"), and the returned dtype is constructed
/// from that string.
///
/// @param dst_typenum_t Enumerator identifying the requested element type.
/// @return A py::dtype built from the type-code string of the enumerator.
/// @throws py::value_error if the enumerator is not one of the handled cases.
static inline py::dtype _dtype_from_typenum(td_ns::typenum_t dst_typenum_t)
{
    // The switch only selects the type code; the dtype itself is built once
    // at the end so there is a single construction and a single throw site.
    const char *type_code = nullptr;

    switch (dst_typenum_t) {
    case td_ns::typenum_t::BOOL:
        type_code = "?";
        break;
    case td_ns::typenum_t::INT8:
        type_code = "i1";
        break;
    case td_ns::typenum_t::UINT8:
        type_code = "u1";
        break;
    case td_ns::typenum_t::INT16:
        type_code = "i2";
        break;
    case td_ns::typenum_t::UINT16:
        type_code = "u2";
        break;
    case td_ns::typenum_t::INT32:
        type_code = "i4";
        break;
    case td_ns::typenum_t::UINT32:
        type_code = "u4";
        break;
    case td_ns::typenum_t::INT64:
        type_code = "i8";
        break;
    case td_ns::typenum_t::UINT64:
        type_code = "u8";
        break;
    case td_ns::typenum_t::HALF:
        type_code = "f2";
        break;
    case td_ns::typenum_t::FLOAT:
        type_code = "f4";
        break;
    case td_ns::typenum_t::DOUBLE:
        type_code = "f8";
        break;
    case td_ns::typenum_t::CFLOAT:
        type_code = "c8";
        break;
    case td_ns::typenum_t::CDOUBLE:
        type_code = "c16";
        break;
    default:
        // Leave type_code null; reported as an error below.
        break;
    }

    if (type_code == nullptr) {
        throw py::value_error("Unrecognized dst_typeid");
    }

    return py::dtype(type_code);
}
77+
78+
static inline int _result_typeid(int arg_typeid, const int *fn_output_id)
79+
{
80+
if (arg_typeid < 0 || arg_typeid >= td_ns::num_types) {
81+
throw py::value_error("Input typeid " + std::to_string(arg_typeid) +
82+
" is outside of expected bounds.");
83+
}
84+
85+
return fn_output_id[arg_typeid];
86+
}
87+
4188
namespace dpctl
4289
{
4390
namespace tensor
4491
{
4592
namespace py_internal
4693
{
4794

48-
namespace td_ns = dpctl::tensor::type_dispatch;
49-
50-
extern py::dtype _dtype_from_typenum(td_ns::typenum_t dst_typenum_t);
51-
extern int _result_typeid(int arg_typeid, const int *fn_output_id);
52-
5395
template <typename output_typesT,
5496
typename contig_dispatchT,
5597
typename strided_dispatchT>
@@ -825,8 +867,6 @@ py_binary_inplace_ufunc(const dpctl::tensor::usm_ndarray &lhs,
825867
strided_fn_ev);
826868
}
827869

828-
extern void init_elementwise_functions(py::module_ m);
829-
830870
} // namespace py_internal
831871
} // namespace tensor
832872
} // namespace dpctl

0 commit comments

Comments
 (0)