Skip to content

Commit

Permalink
Adds tests for where behavior with scalars
Browse files Browse the repository at this point in the history
  • Loading branch information
ndgrigorian committed Jun 26, 2024
1 parent 8ab41e9 commit 20929be
Showing 1 changed file with 54 additions and 0 deletions.
54 changes: 54 additions & 0 deletions dpctl/tests/test_usm_ndarray_search_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import ctypes
import itertools

import numpy as np
import pytest
from helper import get_queue_or_skip, skip_if_dtype_not_supported
Expand Down Expand Up @@ -522,3 +525,54 @@ def test_where_out_arg_validation():
dpt.where(condition, x1, x2, out=out_wrong_shape)
with pytest.raises(ValueError):
dpt.where(condition, x1, x2, out=out_not_writable)


@pytest.mark.parametrize("arr_dt", _all_dtypes)
def test_where_python_scalar(arr_dt):
q = get_queue_or_skip()
skip_if_dtype_not_supported(arr_dt, q)

n1, n2 = 10, 10
condition = dpt.tile(
dpt.reshape(
dpt.asarray([True, False], dtype="?", sycl_queue=q), (1, 2)
),
(n1, n2 // 2),
)
x = dpt.zeros((n1, n2), dtype=arr_dt, sycl_queue=q)
py_scalars = (
bool(0),
int(0),
float(0),
complex(0),
np.float32(0),
ctypes.c_int(0),
)
for sc in py_scalars:
r = dpt.where(condition, x, sc)
assert isinstance(r, dpt.usm_ndarray)
r = dpt.where(condition, sc, x)
assert isinstance(r, dpt.usm_ndarray)


def test_where_two_python_scalars():
get_queue_or_skip()

n1, n2 = 10, 10
condition = dpt.tile(
dpt.reshape(dpt.asarray([True, False], dtype="?"), (1, 2)),
(n1, n2 // 2),
)

py_scalars = [
bool(0),
int(0),
float(0),
complex(0),
np.float32(0),
ctypes.c_int(0),
]

for sc1, sc2 in itertools.product(py_scalars, repeat=2):
r = dpt.where(condition, sc1, sc2)
assert isinstance(r, dpt.usm_ndarray)

0 comments on commit 20929be

Please sign in to comment.