Skip to content

Commit a01db12

Browse files
committed
Made changes as per PR review by @oleksandr-pavlyk
1 parent 4558394 commit a01db12

File tree

7 files changed

+43
-12
lines changed

7 files changed

+43
-12
lines changed

dpctl/tensor/_search_functions.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Data Parallel Control (dpctl)
22
#
3-
# Copyright 2020-2022 Intel Corporation
3+
# Copyright 2020-2023 Intel Corporation
44
#
55
# Licensed under the Apache License, Version 2.0 (the "License");
66
# you may not use this file except in compliance with the License.
@@ -40,6 +40,37 @@ def _where_result_type(dt1, dt2, dev):
4040

4141

4242
def where(condition, x1, x2):
43+
"""where(condition, x1, x2)
44+
45+
Returns :class:`dpctl.tensor.usm_ndarray` with elements chosen
46+
from `x1` or `x2` depending on `condition`.
47+
48+
Args:
49+
condition (usm_ndarray): When True yields from `x1`,
50+
and otherwise yields from `x2`.
51+
Must be compatible with `x1` and `x2` according
52+
to broadcasting rules.
53+
x1 (usm_ndarray): Array from which values are chosen when
54+
`condition` is True.
55+
Must be compatible with `condition` and `x2` according
56+
to broadcasting rules.
57+
x2 (usm_ndarray): Array from which values are chosen when
58+
`condition` is not True.
59+
Must be compatible with `condition` and `x2` according
60+
to broadcasting rules.
61+
62+
Returns:
63+
usm_ndarray:
64+
An array with elements from `x1` where `condition` is True,
65+
and elements from `x2` elsewhere.
66+
67+
The data type of the returned array is determined by applying
68+
the Type Promotion Rules to `x1` and `x2`.
69+
70+
The memory layout of the returned array is
71+
F-contiguous (column-major) when all inputs are F-contiguous,
72+
and C-contiguous (row-major) otherwise.
73+
"""
4374
if not isinstance(condition, dpt.usm_ndarray):
4475
raise TypeError(
4576
"Expecting dpctl.tensor.usm_ndarray type, " f"got {type(condition)}"

dpctl/tensor/_type_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Data Parallel Control (dpctl)
22
#
3-
# Copyright 2020-2022 Intel Corporation
3+
# Copyright 2020-2023 Intel Corporation
44
#
55
# Licensed under the Apache License, Version 2.0 (the "License");
66
# you may not use this file except in compliance with the License.

dpctl/tensor/libtensor/include/kernels/where.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
//
33
// Data Parallel Control (dpctl)
44
//
5-
// Copyright 2020-2022 Intel Corporation
5+
// Copyright 2020-2023 Intel Corporation
66
//
77
// Licensed under the Apache License, Version 2.0 (the "License");
88
// you may not use this file except in compliance with the License.

dpctl/tensor/libtensor/source/where.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
//
33
// Data Parallel Control (dpctl)
44
//
5-
// Copyright 2020-2022 Intel Corporation
5+
// Copyright 2020-2023 Intel Corporation
66
//
77
// Licensed under the Apache License, Version 2.0 (the "License");
88
// you may not use this file except in compliance with the License.
@@ -96,10 +96,10 @@ py_where(dpctl::tensor::usm_ndarray condition,
9696
bool shapes_equal(true);
9797
size_t nelems(1);
9898
for (int i = 0; i < nd; ++i) {
99-
nelems *= static_cast<size_t>(dst_shape[i]);
100-
shapes_equal = shapes_equal && (x1_shape[i] == dst_shape[i]) &&
101-
(x2_shape[i] == dst_shape[i]) &&
102-
(cond_shape[i] == dst_shape[i]);
99+
const auto &sh_i = dst_shape[i];
100+
nelems *= static_cast<size_t>(sh_i);
101+
shapes_equal = shapes_equal && (x1_shape[i] == sh_i) &&
102+
(x2_shape[i] == sh_i) && (cond_shape[i] == sh_i);
103103
}
104104

105105
if (!shapes_equal) {
@@ -127,7 +127,7 @@ py_where(dpctl::tensor::usm_ndarray condition,
127127
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
128128

129129
if (x1_typeid != x2_typeid || x1_typeid != dst_typeid) {
130-
throw py::value_error("Non-condition are not of same type.");
130+
throw py::value_error("Value arrays must have the same data type");
131131
}
132132

133133
// ensure that dst is sufficiently ample

dpctl/tensor/libtensor/source/where.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
//
33
// Data Parallel Control (dpctl)
44
//
5-
// Copyright 2020-2022 Intel Corporation
5+
// Copyright 2020-2023 Intel Corporation
66
//
77
// Licensed under the Apache License, Version 2.0 (the "License");
88
// you may not use this file except in compliance with the License.

dpctl/tests/test_type_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Data Parallel Control (dpctl)
22
#
3-
# Copyright 2020-2022 Intel Corporation
3+
# Copyright 2020-2023 Intel Corporation
44
#
55
# Licensed under the Apache License, Version 2.0 (the "License");
66
# you may not use this file except in compliance with the License.

dpctl/tests/test_usm_ndarray_search_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Data Parallel Control (dpctl)
22
#
3-
# Copyright 2020-2022 Intel Corporation
3+
# Copyright 2020-2023 Intel Corporation
44
#
55
# Licensed under the Apache License, Version 2.0 (the "License");
66
# you may not use this file except in compliance with the License.

0 commit comments

Comments
 (0)