Skip to content

Commit 7ae6a1a

Browse files
committed
Implements dpt.where
1 parent 37516e5 commit 7ae6a1a

File tree

7 files changed

+719
-0
lines changed

7 files changed

+719
-0
lines changed

dpctl/tensor/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ pybind11_add_module(${python_module_name} MODULE
4343
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/eye_ctor.cpp
4444
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/full_ctor.cpp
4545
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/triul_ctor.cpp
46+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/where.cpp
4647
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/device_support_queries.cpp
4748
)
4849
target_link_options(${python_module_name} PRIVATE -fsycl-device-code-split=per_kernel)

dpctl/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@
8686
usm_ndarray_str,
8787
)
8888
from dpctl.tensor._reshape import reshape
89+
from dpctl.tensor._search_functions import where
8990
from dpctl.tensor._usmarray import usm_ndarray
9091

9192
from ._constants import e, inf, nan, newaxis, pi
@@ -128,6 +129,7 @@
128129
"from_dlpack",
129130
"tril",
130131
"triu",
132+
"where",
131133
"dtype",
132134
"isdtype",
133135
"bool",

dpctl/tensor/_search_functions.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2022 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import dpctl
18+
import dpctl.tensor as dpt
19+
import dpctl.tensor._tensor_impl as ti
20+
from dpctl.tensor._manipulation_functions import _broadcast_shapes
21+
22+
23+
def where(condition, x1, x2):
    """Combine `x1` and `x2` element-wise under control of `condition`.

    The three inputs are broadcast to a common shape, `x1` and `x2` are
    converted to their common result type (``dpt.result_type``) when their
    dtypes differ from it, and the element-wise selection is delegated to
    ``ti._where``.

    NOTE(review): the selection kernel is not part of this module;
    presumably elements are taken from `x1` where `condition` is true and
    from `x2` otherwise, as in the array API ``where`` -- confirm against
    the ``_where`` implementation.

    Args:
        condition (usm_ndarray): array steering the selection.
        x1 (usm_ndarray): first source array.
        x2 (usm_ndarray): second source array.

    Returns:
        usm_ndarray: newly allocated array having the broadcast shape of
        the three inputs, the result dtype of `x1` and `x2`, the coerced
        USM type of the inputs, and allocated on their common queue.

    Raises:
        TypeError: if any argument is not a ``dpctl.tensor.usm_ndarray``.
        dpctl.utils.ExecutionPlacementError: if no common execution queue
            can be determined from the queues of the three inputs.
    """
    for arg in (condition, x1, x2):
        if not isinstance(arg, dpt.usm_ndarray):
            raise TypeError(
                "Expecting dpctl.tensor.usm_ndarray type, " f"got {type(arg)}"
            )

    exec_q = dpctl.utils.get_execution_queue(
        (
            condition.sycl_queue,
            x1.sycl_queue,
            x2.sycl_queue,
        )
    )
    if exec_q is None:
        raise dpctl.utils.ExecutionPlacementError(
            "Execution placement can not be unambiguously inferred "
            "from input arguments."
        )
    dst_usm_type = dpctl.utils.get_coerced_usm_type(
        (
            condition.usm_type,
            x1.usm_type,
            x2.usm_type,
        )
    )

    dst_dtype = dpt.result_type(x1.dtype, x2.dtype)

    # Broadcast before any early return so that incompatible shapes are
    # always reported, even when one of the operands is empty.
    res_shape = _broadcast_shapes(condition, x1, x2)

    dst = dpt.empty(
        res_shape, dtype=dst_dtype, usm_type=dst_usm_type, sycl_queue=exec_q
    )
    # A zero-size result needs no kernel; returning ``dst`` keeps the
    # broadcast shape instead of collapsing it to a 1-D empty array.
    if dst.size == 0:
        return dst

    deps = []  # device events the selection kernel must wait for
    wait_list = []  # host-task events to wait on before returning
    operands = []
    for src in (x1, x2):
        # Compare dtypes by value: equal dtypes are not guaranteed to be
        # the very same object, and an identity test would then force a
        # needless copy.
        if src.dtype != dst_dtype:
            converted = dpt.empty_like(src, dtype=dst_dtype)
            ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
                src=src, dst=converted, sycl_queue=exec_q
            )
            deps.append(copy_ev)
            wait_list.append(ht_copy_ev)
            src = converted
        operands.append(dpt.broadcast_to(src, res_shape))
    x1, x2 = operands
    condition = dpt.broadcast_to(condition, res_shape)

    hev, _ = ti._where(
        condition=condition,
        x1=x1,
        x2=x2,
        dst=dst,
        sycl_queue=exec_q,
        depends=deps,
    )
    wait_list.append(hev)
    # Block until the copies and the selection kernel have completed, so
    # the temporaries can be released and ``dst`` is ready for the caller.
    dpctl.SyclEvent.wait_for(wait_list)

    return dst

0 commit comments

Comments
 (0)