Skip to content

Commit 29fb95f

Browse files
committed
Implements dpctl.tensor.max and dpctl.tensor.min
1 parent 7d6a560 commit 29fb95f

File tree

10 files changed

+2525
-255
lines changed

10 files changed

+2525
-255
lines changed

dpctl/tensor/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ pybind11_add_module(${python_module_name} MODULE
4949
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/device_support_queries.cpp
5050
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions.cpp
5151
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sum_reductions.cpp
52+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reduction_over_axis.cpp
5253
)
5354
set(_clang_prefix "")
5455
if (WIN32)
@@ -58,6 +59,7 @@ set_source_files_properties(
5859
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/full_ctor.cpp
5960
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linear_sequences.cpp
6061
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions.cpp
62+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reduction_over_axis.cpp
6163
PROPERTIES COMPILE_OPTIONS "${_clang_prefix}-fno-fast-math")
6264
if (UNIX)
6365
set_source_files_properties(

dpctl/tensor/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@
158158
tanh,
159159
trunc,
160160
)
161-
from ._reduction import sum
161+
from ._reduction import max, min, sum
162162
from ._testing import allclose
163163

164164
__all__ = [
@@ -305,4 +305,6 @@
305305
"tanh",
306306
"trunc",
307307
"allclose",
308+
"max",
309+
"min",
308310
]

dpctl/tensor/_reduction.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,3 +171,62 @@ def sum(arr, axis=None, dtype=None, keepdims=False):
171171
dpctl.SyclEvent.wait_for(host_tasks_list)
172172

173173
return res
174+
175+
176+
def _same_dtype_reduction(x, axis, keepdims, func):
177+
if not isinstance(x, dpt.usm_ndarray):
178+
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
179+
180+
nd = x.ndim
181+
if axis is None:
182+
red_nd = nd
183+
# case of a scalar
184+
if red_nd == 0:
185+
return dpt.copy(x)
186+
x_tmp = x
187+
res_shape = tuple()
188+
perm = list(range(nd))
189+
else:
190+
if not isinstance(axis, (tuple, list)):
191+
axis = (axis,)
192+
axis = normalize_axis_tuple(axis, nd, "axis")
193+
194+
red_nd = len(axis)
195+
# check for axis=()
196+
if red_nd == 0:
197+
return dpt.copy(x)
198+
perm = [i for i in range(nd) if i not in axis] + list(axis)
199+
x_tmp = dpt.permute_dims(x, perm)
200+
res_shape = x_tmp.shape[: nd - red_nd]
201+
202+
exec_q = x.sycl_queue
203+
res_usm_type = x.usm_type
204+
res_dtype = x.dtype
205+
206+
res = dpt.empty(
207+
res_shape,
208+
dtype=res_dtype,
209+
usm_type=res_usm_type,
210+
sycl_queue=exec_q,
211+
)
212+
hev, _ = func(
213+
src=x_tmp,
214+
trailing_dims_to_reduce=red_nd,
215+
dst=res,
216+
sycl_queue=exec_q,
217+
)
218+
219+
if keepdims:
220+
res_shape = res_shape + (1,) * red_nd
221+
inv_perm = sorted(range(nd), key=lambda d: perm[d])
222+
res = dpt.permute_dims(dpt.reshape(res, res_shape), inv_perm)
223+
hev.wait()
224+
return res
225+
226+
227+
def max(x, axis=None, keepdims=False):
228+
return _same_dtype_reduction(x, axis, keepdims, ti._max_over_axis)
229+
230+
231+
def min(x, axis=None, keepdims=False):
232+
return _same_dtype_reduction(x, axis, keepdims, ti._min_over_axis)

0 commit comments

Comments
 (0)