Skip to content

Commit f3d5519

Browse files
committed
use keywrod argument
1 parent d691846 commit f3d5519

File tree

5 files changed

+33
-37
lines changed

5 files changed

+33
-37
lines changed

dpctl/tests/elementwise/test_abs.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,15 @@ def test_abs_out_type(dtype):
2323
}
2424
assert dpt.abs(X).dtype == type_map[arg_dt]
2525

26-
out = dpt.empty_like(X, dtype=type_map[arg_dt])
27-
dpt.abs(X, out)
28-
assert np.allclose(dpt.asnumpy(out), dpt.asnumpy(dpt.abs(X)))
26+
r = dpt.empty_like(X, dtype=type_map[arg_dt])
27+
dpt.abs(X, out=r)
28+
assert np.allclose(dpt.asnumpy(r), dpt.asnumpy(dpt.abs(X)))
2929
else:
3030
assert dpt.abs(X).dtype == arg_dt
3131

32-
out = dpt.empty_like(X, dtype=arg_dt)
33-
dpt.abs(X, out)
34-
assert np.allclose(dpt.asnumpy(out), dpt.asnumpy(dpt.abs(X)))
32+
r = dpt.empty_like(X, dtype=arg_dt)
33+
dpt.abs(X, out=r)
34+
assert np.allclose(dpt.asnumpy(r), dpt.asnumpy(dpt.abs(X)))
3535

3636

3737
@pytest.mark.parametrize("usm_type", _usm_types)

dpctl/tests/elementwise/test_add.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@ def test_add_dtype_matrix(op1_dtype, op2_dtype):
3333
assert (dpt.asnumpy(r) == np.full(r.shape, 2, dtype=r.dtype)).all()
3434
assert r.sycl_queue == ar1.sycl_queue
3535

36-
out = dpt.empty_like(ar1, dtype=r.dtype)
37-
dpt.add(ar1, ar2, out)
38-
assert (dpt.asnumpy(out) == np.full(out.shape, 2, dtype=out.dtype)).all()
36+
r2 = dpt.empty_like(ar1, dtype=r.dtype)
37+
dpt.add(ar1, ar2, out=r2)
38+
assert (dpt.asnumpy(r2) == np.full(r2.shape, 2, dtype=r2.dtype)).all()
3939

4040
ar3 = dpt.ones(sz, dtype=op1_dtype)
4141
ar4 = dpt.ones(2 * sz, dtype=op2_dtype)
@@ -49,9 +49,9 @@ def test_add_dtype_matrix(op1_dtype, op2_dtype):
4949
assert r.shape == ar3.shape
5050
assert (dpt.asnumpy(r) == np.full(r.shape, 2, dtype=r.dtype)).all()
5151

52-
out = dpt.empty_like(ar1, dtype=r.dtype)
53-
dpt.add(ar3[::-1], ar4[::2], out)
54-
assert (dpt.asnumpy(out) == np.full(out.shape, 2, dtype=out.dtype)).all()
52+
r2 = dpt.empty_like(ar1, dtype=r.dtype)
53+
dpt.add(ar3[::-1], ar4[::2], out=r2)
54+
assert (dpt.asnumpy(r2) == np.full(r2.shape, 2, dtype=r2.dtype)).all()
5555

5656

5757
@pytest.mark.parametrize("op1_usm_type", _usm_types)
@@ -131,17 +131,13 @@ def test_add_broadcasting():
131131
r2 = dpt.add(v, m)
132132
assert (dpt.asnumpy(r2) == np.arange(1, 6, dtype="i4")[np.newaxis, :]).all()
133133

134-
out = dpt.empty_like(m)
135-
dpt.add(m, v, out)
136-
assert (
137-
dpt.asnumpy(out) == np.arange(1, 6, dtype="i4")[np.newaxis, :]
138-
).all()
139-
140-
out2 = dpt.empty_like(m)
141-
dpt.add(v, m, out2)
142-
assert (
143-
dpt.asnumpy(out2) == np.arange(1, 6, dtype="i4")[np.newaxis, :]
144-
).all()
134+
r3 = dpt.empty_like(m)
135+
dpt.add(m, v, out=r3)
136+
assert (dpt.asnumpy(r3) == np.arange(1, 6, dtype="i4")[np.newaxis, :]).all()
137+
138+
r4 = dpt.empty_like(m)
139+
dpt.add(v, m, out=r4)
140+
assert (dpt.asnumpy(r4) == np.arange(1, 6, dtype="i4")[np.newaxis, :]).all()
145141

146142

147143
def test_add_broadcasting_error():

dpctl/tests/elementwise/test_cos.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def test_cos_out_type(dtype):
2525
expected_dtype = np.cos(np.array(0, dtype=dtype)).dtype
2626
expected_dtype = _map_to_device_dtype(expected_dtype, q.sycl_device)
2727
Y = dpt.empty_like(X, dtype=expected_dtype)
28-
dpt.cos(X, Y)
28+
dpt.cos(X, out=Y)
2929
np.testing.assert_allclose(dpt.asnumpy(dpt.cos(X)), dpt.asnumpy(Y))
3030

3131

@@ -48,7 +48,7 @@ def test_cos_output(dtype):
4848
)
4949

5050
Z = dpt.empty_like(X, dtype=dtype)
51-
dpt.cos(X, Z)
51+
dpt.cos(X, out=Z)
5252

5353
np.testing.assert_allclose(
5454
dpt.asnumpy(Z), np.repeat(np.cos(Xnp), n_rep), atol=tol, rtol=tol

dpctl/tests/elementwise/test_isfinite.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,9 @@ def test_isfinite_complex(dtype):
4141
Y = dpt.asarray(Ynp, sycl_queue=q)
4242
assert np.array_equal(dpt.asnumpy(dpt.isfinite(Y)), np.isfinite(Ynp))
4343

44-
out = dpt.empty_like(Y, dtype="bool")
45-
dpt.isfinite(Y, out)
46-
assert np.array_equal(dpt.asnumpy(out)[()], np.isfinite(Ynp))
44+
r = dpt.empty_like(Y, dtype="bool")
45+
dpt.isfinite(Y, out=r)
46+
assert np.array_equal(dpt.asnumpy(r)[()], np.isfinite(Ynp))
4747

4848

4949
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"])
@@ -60,9 +60,9 @@ def test_isfinite_floats(dtype):
6060
Y = dpt.asarray(Ynp, sycl_queue=q)
6161
assert np.array_equal(dpt.asnumpy(dpt.isfinite(Y)), np.isfinite(Ynp))
6262

63-
out = dpt.empty_like(Y, dtype="bool")
64-
dpt.isfinite(Y, out)
65-
assert np.array_equal(dpt.asnumpy(out)[()], np.isfinite(Ynp))
63+
r = dpt.empty_like(Y, dtype="bool")
64+
dpt.isfinite(Y, out=r)
65+
assert np.array_equal(dpt.asnumpy(r)[()], np.isfinite(Ynp))
6666

6767

6868
@pytest.mark.parametrize("dtype", _all_dtypes)

dpctl/tests/elementwise/test_isnan.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,9 @@ def test_isnan_complex(dtype):
4141
Y = dpt.asarray(Ynp, sycl_queue=q)
4242
assert np.array_equal(dpt.asnumpy(dpt.isnan(Y)), np.isnan(Ynp))
4343

44-
out = dpt.empty_like(Y, dtype="bool")
45-
dpt.isnan(Y, out)
46-
assert np.array_equal(dpt.asnumpy(out)[()], np.isnan(Ynp))
44+
r = dpt.empty_like(Y, dtype="bool")
45+
dpt.isnan(Y, out=r)
46+
assert np.array_equal(dpt.asnumpy(r)[()], np.isnan(Ynp))
4747

4848

4949
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"])
@@ -60,9 +60,9 @@ def test_isnan_floats(dtype):
6060
Y = dpt.asarray(Ynp, sycl_queue=q)
6161
assert np.array_equal(dpt.asnumpy(dpt.isnan(Y)), np.isnan(Ynp))
6262

63-
out = dpt.empty_like(Y, dtype="bool")
64-
dpt.isnan(Y, out)
65-
assert np.array_equal(dpt.asnumpy(out)[()], np.isnan(Ynp))
63+
r = dpt.empty_like(Y, dtype="bool")
64+
dpt.isnan(Y, out=r)
65+
assert np.array_equal(dpt.asnumpy(r)[()], np.isnan(Ynp))
6666

6767

6868
@pytest.mark.parametrize("dtype", _all_dtypes)

0 commit comments

Comments
 (0)