|
21 | 21 | import pytest
|
22 | 22 |
|
23 | 23 | import dpctl
|
24 |
| - |
25 |
| -# import dpctl.memory as dpmem |
| 24 | +import dpctl.memory as dpm |
26 | 25 | import dpctl.tensor as dpt
|
27 | 26 | from dpctl.tensor._usmarray import Device
|
28 | 27 |
|
@@ -115,6 +114,64 @@ def test_properties():
|
115 | 114 | assert isinstance(X.ndim, numbers.Integral)
|
116 | 115 |
|
117 | 116 |
|
| 117 | +@pytest.mark.parametrize("func", [bool, float, int, complex]) |
| 118 | +@pytest.mark.parametrize("shape", [tuple(), (1,), (1, 1), (1, 1, 1)]) |
| 119 | +@pytest.mark.parametrize("dtype", ["|b1", "|u2", "|f4", "|i8"]) |
| 120 | +def test_copy_scalar_with_func(func, shape, dtype): |
| 121 | + X = dpt.usm_ndarray(shape, dtype=dtype) |
| 122 | + Y = np.arange(1, X.size + 1, dtype=dtype).reshape(shape) |
| 123 | + X.usm_data.copy_from_host(Y.reshape(-1).view("|u1")) |
| 124 | + assert func(X) == func(Y) |
| 125 | + |
| 126 | + |
| 127 | +@pytest.mark.parametrize( |
| 128 | + "method", ["__bool__", "__float__", "__int__", "__complex__"] |
| 129 | +) |
| 130 | +@pytest.mark.parametrize("shape", [tuple(), (1,), (1, 1), (1, 1, 1)]) |
| 131 | +@pytest.mark.parametrize("dtype", ["|b1", "|u2", "|f4", "|i8"]) |
| 132 | +def test_copy_scalar_with_method(method, shape, dtype): |
| 133 | + X = dpt.usm_ndarray(shape, dtype=dtype) |
| 134 | + Y = np.arange(1, X.size + 1, dtype=dtype).reshape(shape) |
| 135 | + X.usm_data.copy_from_host(Y.reshape(-1).view("|u1")) |
| 136 | + assert getattr(X, method)() == getattr(Y, method)() |
| 137 | + |
| 138 | + |
| 139 | +@pytest.mark.parametrize("func", [bool, float, int, complex]) |
| 140 | +@pytest.mark.parametrize("shape", [(2,), (1, 2), (3, 4, 5), (0,)]) |
| 141 | +def test_copy_scalar_invalid_shape(func, shape): |
| 142 | + X = dpt.usm_ndarray(shape) |
| 143 | + with pytest.raises(ValueError): |
| 144 | + func(X) |
| 145 | + |
| 146 | + |
| 147 | +@pytest.mark.parametrize("shape", [(1,), (1, 1), (1, 1, 1)]) |
| 148 | +@pytest.mark.parametrize("index_dtype", ["|i8"]) |
| 149 | +def test_usm_ndarray_as_index(shape, index_dtype): |
| 150 | + X = dpt.usm_ndarray(shape, dtype=index_dtype) |
| 151 | + Xnp = np.arange(1, X.size + 1, dtype=index_dtype).reshape(shape) |
| 152 | + X.usm_data.copy_from_host(Xnp.reshape(-1).view("|u1")) |
| 153 | + Y = np.arange(X.size + 1) |
| 154 | + assert Y[X] == Y[1] |
| 155 | + |
| 156 | + |
| 157 | +@pytest.mark.parametrize("shape", [(2,), (1, 2), (3, 4, 5), (0,)]) |
| 158 | +@pytest.mark.parametrize("index_dtype", ["|i8"]) |
| 159 | +def test_usm_ndarray_as_index_invalid_shape(shape, index_dtype): |
| 160 | + X = dpt.usm_ndarray(shape, dtype=index_dtype) |
| 161 | + Y = np.arange(X.size + 1) |
| 162 | + with pytest.raises(IndexError): |
| 163 | + Y[X] |
| 164 | + |
| 165 | + |
| 166 | +@pytest.mark.parametrize("shape", [(1,), (1, 1), (1, 1, 1)]) |
| 167 | +@pytest.mark.parametrize("index_dtype", ["|f8"]) |
| 168 | +def test_usm_ndarray_as_index_invalid_dtype(shape, index_dtype): |
| 169 | + X = dpt.usm_ndarray(shape, dtype=index_dtype) |
| 170 | + Y = np.arange(X.size + 1) |
| 171 | + with pytest.raises(IndexError): |
| 172 | + Y[X] |
| 173 | + |
| 174 | + |
118 | 175 | @pytest.mark.parametrize(
|
119 | 176 | "ind",
|
120 | 177 | [
|
@@ -224,3 +281,95 @@ def test_slice_constructor_3d():
|
224 | 281 | assert np.array_equal(
|
225 | 282 | _to_numpy(Xusm[ind]), Xh[ind]
|
226 | 283 | ), "Failed for {}".format(ind)
|
| 284 | + |
| 285 | + |
| 286 | +@pytest.mark.parametrize("usm_type", ["device", "shared", "host"]) |
| 287 | +def test_slice_suai(usm_type): |
| 288 | + Xh = np.arange(0, 10, dtype="u1") |
| 289 | + default_device = dpctl.select_default_device() |
| 290 | + Xusm = _from_numpy(Xh, device=default_device, usm_type=usm_type) |
| 291 | + for ind in [slice(2, 3, None), slice(5, 7, None), slice(3, 9, None)]: |
| 292 | + assert np.array_equal( |
| 293 | + dpm.as_usm_memory(Xusm[ind]).copy_to_host(), Xh[ind] |
| 294 | + ), "Failed for {}".format(ind) |
| 295 | + |
| 296 | + |
| 297 | +def test_slicing_basic(): |
| 298 | + Xusm = dpt.usm_ndarray((10, 5), dtype="c16") |
| 299 | + Xusm[None] |
| 300 | + Xusm[...] |
| 301 | + Xusm[8] |
| 302 | + Xusm[-3] |
| 303 | + with pytest.raises(IndexError): |
| 304 | + Xusm[..., ...] |
| 305 | + with pytest.raises(IndexError): |
| 306 | + Xusm[1, 1, :, 1] |
| 307 | + Xusm[:, -4] |
| 308 | + with pytest.raises(IndexError): |
| 309 | + Xusm[:, -128] |
| 310 | + with pytest.raises(TypeError): |
| 311 | + Xusm[{1, 2, 3, 4, 5, 6, 7}] |
| 312 | + X = dpt.usm_ndarray(10, "u1") |
| 313 | + X.usm_data.copy_from_host(b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09") |
| 314 | + int( |
| 315 | + X[X[2]] |
| 316 | + ) # check that objects with __index__ method can be used as indices |
| 317 | + Xh = dpm.as_usm_memory(X[X[2] : X[5]]).copy_to_host() |
| 318 | + Xnp = np.arange(0, 10, dtype="u1") |
| 319 | + assert np.array_equal(Xh, Xnp[Xnp[2] : Xnp[5]]) |
| 320 | + |
| 321 | + |
| 322 | +def test_ctor_invalid_shape(): |
| 323 | + with pytest.raises(TypeError): |
| 324 | + dpt.usm_ndarray(dict()) |
| 325 | + |
| 326 | + |
| 327 | +def test_ctor_invalid_order(): |
| 328 | + with pytest.raises(ValueError): |
| 329 | + dpt.usm_ndarray((5, 5, 3), order="Z") |
| 330 | + |
| 331 | + |
| 332 | +def test_ctor_buffer_kwarg(): |
| 333 | + dpt.usm_ndarray(10, buffer=b"device") |
| 334 | + with pytest.raises(ValueError): |
| 335 | + dpt.usm_ndarray(10, buffer="invalid_param") |
| 336 | + Xusm = dpt.usm_ndarray((10, 5), dtype="c16") |
| 337 | + X2 = dpt.usm_ndarray(Xusm.shape, buffer=Xusm, dtype=Xusm.dtype) |
| 338 | + assert np.array_equal( |
| 339 | + Xusm.usm_data.copy_to_host(), X2.usm_data.copy_to_host() |
| 340 | + ) |
| 341 | + with pytest.raises(ValueError): |
| 342 | + dpt.usm_ndarray(10, buffer=dict()) |
| 343 | + |
| 344 | + |
| 345 | +def test_usm_ndarray_props(): |
| 346 | + Xusm = dpt.usm_ndarray((10, 5), dtype="c16", order="F") |
| 347 | + Xusm.ndim |
| 348 | + repr(Xusm) |
| 349 | + Xusm.flags |
| 350 | + Xusm.__sycl_usm_array_interface__ |
| 351 | + Xusm.device |
| 352 | + Xusm.strides |
| 353 | + Xusm.real |
| 354 | + Xusm.imag |
| 355 | + try: |
| 356 | + dpctl.SyclQueue("cpu") |
| 357 | + except dpctl.SyclQueueCreationError: |
| 358 | + pytest.skip("Sycl device CPU was not detected") |
| 359 | + Xusm.to_device("cpu") |
| 360 | + |
| 361 | + |
| 362 | +def test_datapi_device(): |
| 363 | + X = dpt.usm_ndarray(1) |
| 364 | + dev_t = type(X.device) |
| 365 | + with pytest.raises(TypeError): |
| 366 | + dev_t() |
| 367 | + dev_t.create_device(X.device) |
| 368 | + dev_t.create_device(X.sycl_queue) |
| 369 | + dev_t.create_device(X.sycl_device) |
| 370 | + dev_t.create_device(X.sycl_device.filter_string) |
| 371 | + dev_t.create_device(None) |
| 372 | + X.device.sycl_context |
| 373 | + X.device.sycl_queue |
| 374 | + X.device.sycl_device |
| 375 | + repr(X.device) |
0 commit comments