|
| 1 | +import ctypes |
1 | 2 | import itertools
|
2 | 3 |
|
3 | 4 | import numpy as np
|
4 | 5 | import pytest
|
5 | 6 |
|
6 | 7 | import dpctl
|
7 | 8 | import dpctl.tensor as dpt
|
| 9 | +import dpctl.tensor._type_utils as tu |
| 10 | +import dpctl.utils |
8 | 11 | from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
|
9 | 12 |
|
10 | 13 | _all_dtypes = [
|
|
26 | 29 | _usm_types = ["device", "shared", "host"]
|
27 | 30 |
|
28 | 31 |
|
| 32 | +class MockDevice: |
| 33 | + def __init__(self, fp16: bool, fp64: bool): |
| 34 | + self.has_aspect_fp16 = fp16 |
| 35 | + self.has_aspect_fp64 = fp64 |
| 36 | + |
| 37 | + |
| 38 | +def _map_to_device_dtype(dt, dev): |
| 39 | + return tu._to_device_supported_dtype(dt, dev) |
| 40 | + |
| 41 | + |
| 42 | +@pytest.mark.parametrize("dtype", _all_dtypes) |
| 43 | +def test_type_utils_map_to_device_type(dtype): |
| 44 | + for fp64 in [ |
| 45 | + True, |
| 46 | + False, |
| 47 | + ]: |
| 48 | + for fp16 in [True, False]: |
| 49 | + dev = MockDevice(fp16, fp64) |
| 50 | + dt_in = dpt.dtype(dtype) |
| 51 | + dt_out = _map_to_device_dtype(dt_in, dev) |
| 52 | + assert isinstance(dt_out, dpt.dtype) |
| 53 | + |
| 54 | + |
| 55 | +def test_type_util_all_data_types(): |
| 56 | + for fp64 in [ |
| 57 | + True, |
| 58 | + False, |
| 59 | + ]: |
| 60 | + for fp16 in [True, False]: |
| 61 | + r = tu._all_data_types(fp16, fp64) |
| 62 | + assert isinstance(r, list) |
| 63 | + # 11: bool + 4 signed + 4 unsigned inegral + float32 + complex64 |
| 64 | + assert len(r) == 11 + int(fp16) + 2 * int(fp64) |
| 65 | + |
| 66 | + |
| 67 | +def test_type_util_can_cast(): |
| 68 | + for fp64 in [ |
| 69 | + True, |
| 70 | + False, |
| 71 | + ]: |
| 72 | + for fp16 in [True, False]: |
| 73 | + for from_ in _all_dtypes: |
| 74 | + for to_ in _all_dtypes: |
| 75 | + r = tu._can_cast( |
| 76 | + dpt.dtype(from_), dpt.dtype(to_), fp16, fp64 |
| 77 | + ) |
| 78 | + assert isinstance(r, bool) |
| 79 | + |
| 80 | + |
| 81 | +def test_type_utils_empty_like_orderK(): |
| 82 | + try: |
| 83 | + a = dpt.empty((10, 10), dtype=dpt.int32, order="F") |
| 84 | + except dpctl.SyclDeviceCreationError: |
| 85 | + pytest.skip("No SYCL devices available") |
| 86 | + X = tu._empty_like_orderK(a, dpt.int32, a.usm_type, a.device) |
| 87 | + assert X.flags["F"] |
| 88 | + |
| 89 | + |
| 90 | +def test_type_utils_empty_like_orderK_invalid_args(): |
| 91 | + with pytest.raises(TypeError): |
| 92 | + tu._empty_like_orderK([1, 2, 3], dpt.int32, "device", None) |
| 93 | + with pytest.raises(TypeError): |
| 94 | + tu._empty_like_pair_orderK( |
| 95 | + [1, 2, 3], |
| 96 | + ( |
| 97 | + 1, |
| 98 | + 2, |
| 99 | + 3, |
| 100 | + ), |
| 101 | + dpt.int32, |
| 102 | + "device", |
| 103 | + None, |
| 104 | + ) |
| 105 | + try: |
| 106 | + a = dpt.empty(10, dtype=dpt.int32) |
| 107 | + except dpctl.SyclDeviceCreationError: |
| 108 | + pytest.skip("No SYCL devices available") |
| 109 | + with pytest.raises(TypeError): |
| 110 | + tu._empty_like_pair_orderK( |
| 111 | + a, |
| 112 | + ( |
| 113 | + 1, |
| 114 | + 2, |
| 115 | + 3, |
| 116 | + ), |
| 117 | + dpt.int32, |
| 118 | + "device", |
| 119 | + None, |
| 120 | + ) |
| 121 | + |
| 122 | + |
| 123 | +def test_type_utils_find_buf_dtype(): |
| 124 | + def _denier_fn(dt): |
| 125 | + return False |
| 126 | + |
| 127 | + for fp64 in [ |
| 128 | + True, |
| 129 | + False, |
| 130 | + ]: |
| 131 | + for fp16 in [True, False]: |
| 132 | + dev = MockDevice(fp16, fp64) |
| 133 | + arg_dt = dpt.float64 |
| 134 | + r = tu._find_buf_dtype(arg_dt, _denier_fn, dev) |
| 135 | + assert r == ( |
| 136 | + None, |
| 137 | + None, |
| 138 | + ) |
| 139 | + |
| 140 | + |
| 141 | +def test_type_utils_find_buf_dtype2(): |
| 142 | + def _denier_fn(dt1, dt2): |
| 143 | + return False |
| 144 | + |
| 145 | + for fp64 in [ |
| 146 | + True, |
| 147 | + False, |
| 148 | + ]: |
| 149 | + for fp16 in [True, False]: |
| 150 | + dev = MockDevice(fp16, fp64) |
| 151 | + arg1_dt = dpt.float64 |
| 152 | + arg2_dt = dpt.complex64 |
| 153 | + r = tu._find_buf_dtype2(arg1_dt, arg2_dt, _denier_fn, dev) |
| 154 | + assert r == ( |
| 155 | + None, |
| 156 | + None, |
| 157 | + None, |
| 158 | + ) |
| 159 | + |
| 160 | + |
| 161 | +def test_unary_func_arg_validation(): |
| 162 | + with pytest.raises(TypeError): |
| 163 | + dpt.abs([1, 2, 3]) |
| 164 | + try: |
| 165 | + a = dpt.arange(8) |
| 166 | + except dpctl.SyclDeviceCreationError: |
| 167 | + pytest.skip("No SYCL devices available") |
| 168 | + dpt.abs(a, order="invalid") |
| 169 | + |
| 170 | + |
| 171 | +def test_binary_func_arg_vaidation(): |
| 172 | + with pytest.raises(dpctl.utils.ExecutionPlacementError): |
| 173 | + dpt.add([1, 2, 3], 1) |
| 174 | + try: |
| 175 | + a = dpt.arange(8) |
| 176 | + except dpctl.SyclDeviceCreationError: |
| 177 | + pytest.skip("No SYCL devices available") |
| 178 | + with pytest.raises(ValueError): |
| 179 | + dpt.add(a, Ellipsis) |
| 180 | + dpt.add(a, a, order="invalid") |
| 181 | + |
| 182 | + |
29 | 183 | @pytest.mark.parametrize("dtype", _all_dtypes)
|
30 | 184 | def test_abs_out_type(dtype):
|
31 | 185 | q = get_queue_or_skip()
|
@@ -111,15 +265,7 @@ def test_abs_complex(dtype):
|
111 | 265 | def _compare_dtypes(dt, ref_dt, sycl_queue=None):
|
112 | 266 | assert isinstance(sycl_queue, dpctl.SyclQueue)
|
113 | 267 | dev = sycl_queue.sycl_device
|
114 |
| - expected_dt = ref_dt |
115 |
| - if not dev.has_aspect_fp64: |
116 |
| - if expected_dt == dpt.float64: |
117 |
| - expected_dt = dpt.float32 |
118 |
| - elif expected_dt == dpt.complex128: |
119 |
| - expected_dt = dpt.complex64 |
120 |
| - if not dev.has_aspect_fp16: |
121 |
| - if expected_dt == dpt.float16: |
122 |
| - expected_dt = dpt.float32 |
| 268 | + expected_dt = _map_to_device_dtype(ref_dt, dev) |
123 | 269 | return dt == expected_dt
|
124 | 270 |
|
125 | 271 |
|
@@ -224,22 +370,60 @@ def test_add_broadcasting():
|
224 | 370 | assert (dpt.asnumpy(r2) == np.arange(1, 6, dtype="i4")[np.newaxis, :]).all()
|
225 | 371 |
|
226 | 372 |
|
227 |
| -def _map_to_device_dtype(dt, dev): |
228 |
| - if np.issubdtype(dt, np.integer): |
229 |
| - return dt |
230 |
| - if np.issubdtype(dt, np.floating): |
231 |
| - dtc = np.dtype(dt).char |
232 |
| - if dtc == "d": |
233 |
| - return dt if dev.has_aspect_fp64 else dpt.float32 |
234 |
| - elif dtc == "e": |
235 |
| - return dt if dev.has_aspect_fp16 else dpt.float32 |
236 |
| - return dt |
237 |
| - if np.issubdtype(dt, np.complexfloating): |
238 |
| - dtc = np.dtype(dt).char |
239 |
| - if dtc == "D": |
240 |
| - return dt if dev.has_aspect_fp64 else dpt.complex64 |
241 |
| - return dt |
242 |
| - return dt |
| 373 | +@pytest.mark.parametrize("arr_dt", _all_dtypes) |
| 374 | +def test_add_python_scalar(arr_dt): |
| 375 | + q = get_queue_or_skip() |
| 376 | + skip_if_dtype_not_supported(arr_dt, q) |
| 377 | + |
| 378 | + X = dpt.zeros((10, 10), dtype=arr_dt, sycl_queue=q) |
| 379 | + py_zeros = ( |
| 380 | + bool(0), |
| 381 | + int(0), |
| 382 | + float(0), |
| 383 | + complex(0), |
| 384 | + np.float32(0), |
| 385 | + ctypes.c_int(0), |
| 386 | + ) |
| 387 | + for sc in py_zeros: |
| 388 | + R = dpt.add(X, sc) |
| 389 | + assert isinstance(R, dpt.usm_ndarray) |
| 390 | + R = dpt.add(sc, X) |
| 391 | + assert isinstance(R, dpt.usm_ndarray) |
| 392 | + |
| 393 | + |
| 394 | +class MockArray: |
| 395 | + def __init__(self, arr): |
| 396 | + self.data_ = arr |
| 397 | + |
| 398 | + @property |
| 399 | + def __sycl_usm_array_interface__(self): |
| 400 | + return self.data_.__sycl_usm_array_interface__ |
| 401 | + |
| 402 | + |
| 403 | +def test_add_mock_array(): |
| 404 | + get_queue_or_skip() |
| 405 | + a = dpt.arange(10) |
| 406 | + b = dpt.ones(10) |
| 407 | + c = MockArray(b) |
| 408 | + r = dpt.add(a, c) |
| 409 | + assert isinstance(r, dpt.usm_ndarray) |
| 410 | + |
| 411 | + |
| 412 | +def test_add_canary_mock_array(): |
| 413 | + get_queue_or_skip() |
| 414 | + a = dpt.arange(10) |
| 415 | + |
| 416 | + class Canary: |
| 417 | + def __init__(self): |
| 418 | + pass |
| 419 | + |
| 420 | + @property |
| 421 | + def __sycl_usm_array_interface__(self): |
| 422 | + return None |
| 423 | + |
| 424 | + c = Canary() |
| 425 | + with pytest.raises(ValueError): |
| 426 | + dpt.add(a, c) |
243 | 427 |
|
244 | 428 |
|
245 | 429 | @pytest.mark.parametrize("dtype", _all_dtypes)
|
|
0 commit comments