Skip to content

Commit 6fc7666

Browse files
committed
Created a temporary copy in case of overlap for unary function
1 parent 79f4041 commit 6fc7666

File tree

9 files changed

+219
-44
lines changed

9 files changed

+219
-44
lines changed

dpctl/tensor/_elementwise_common.py

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,15 @@ def __call__(self, x, out=None, order="K"):
5252
if not isinstance(x, dpt.usm_ndarray):
5353
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
5454

55+
if order not in ["C", "F", "K", "A"]:
56+
order = "K"
57+
buf_dt, res_dt = _find_buf_dtype(
58+
x.dtype, self.result_type_resolver_fn_, x.sycl_device
59+
)
60+
if res_dt is None:
61+
raise RuntimeError
62+
63+
orig_out = out
5564
if out is not None:
5665
if not isinstance(out, dpt.usm_ndarray):
5766
raise TypeError(
@@ -64,8 +73,17 @@ def __call__(self, x, out=None, order="K"):
6473
f"Expected output shape is {x.shape}, got {out.shape}"
6574
)
6675

67-
if ti._array_overlap(x, out):
68-
raise TypeError("Input and output arrays have memory overlap")
76+
if res_dt != out.dtype:
77+
raise TypeError(
78+
f"Output array of type {res_dt} is needed,"
79+
f" got {out.dtype}"
80+
)
81+
82+
if buf_dt is None and ti._array_overlap(x, out):
83+
# Allocate a temporary buffer to avoid memory overlapping.
84+
# Note if `buf_dt` is not None, a temporary copy of `x` will be
85+
# created, so the array overlap check isn't needed.
86+
out = dpt.empty_like(out)
6987

7088
if (
7189
dpctl.utils.get_execution_queue((x.sycl_queue, out.sycl_queue))
@@ -75,13 +93,6 @@ def __call__(self, x, out=None, order="K"):
7593
"Input and output allocation queues are not compatible"
7694
)
7795

78-
if order not in ["C", "F", "K", "A"]:
79-
order = "K"
80-
buf_dt, res_dt = _find_buf_dtype(
81-
x.dtype, self.result_type_resolver_fn_, x.sycl_device
82-
)
83-
if res_dt is None:
84-
raise RuntimeError
8596
exec_q = x.sycl_queue
8697
if buf_dt is None:
8798
if out is None:
@@ -91,17 +102,20 @@ def __call__(self, x, out=None, order="K"):
91102
if order == "A":
92103
order = "F" if x.flags.f_contiguous else "C"
93104
out = dpt.empty_like(x, dtype=res_dt, order=order)
94-
else:
95-
if res_dt != out.dtype:
96-
raise TypeError(
97-
f"Output array of type {res_dt} is needed,"
98-
f" got {out.dtype}"
99-
)
100105

101-
ht, _ = self.unary_fn_(x, out, sycl_queue=exec_q)
102-
ht.wait()
106+
ht_unary_ev, unary_ev = self.unary_fn_(x, out, sycl_queue=exec_q)
107+
108+
if not (orig_out is None or orig_out is out):
109+
# Copy the out data from temporary buffer to original memory
110+
ht_copy_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
111+
src=out, dst=orig_out, sycl_queue=exec_q, depends=[unary_ev]
112+
)
113+
ht_copy_ev.wait()
114+
out = orig_out
103115

116+
ht_unary_ev.wait()
104117
return out
118+
105119
if order == "K":
106120
buf = _empty_like_orderK(x, buf_dt)
107121
else:
@@ -117,11 +131,6 @@ def __call__(self, x, out=None, order="K"):
117131
out = _empty_like_orderK(buf, res_dt)
118132
else:
119133
out = dpt.empty_like(buf, dtype=res_dt, order=order)
120-
else:
121-
if buf_dt != out.dtype:
122-
raise TypeError(
123-
f"Output array of type {buf_dt} is needed, got {out.dtype}"
124-
)
125134

126135
ht, _ = self.unary_fn_(buf, out, sycl_queue=exec_q, depends=[copy_ev])
127136
ht_copy_ev.wait()

dpctl/tests/_numpy_warnings.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2023 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import numpy
18+
import pytest
19+
20+
21+
@pytest.fixture
22+
def suppress_invalid_numpy_warnings():
23+
# invalid: treatment for invalid floating-point operation
24+
# (result is not an expressible number, typically indicates
25+
# that a NaN was produced)
26+
old_settings = numpy.seterr(invalid="ignore")
27+
yield
28+
numpy.seterr(**old_settings) # reset to default

dpctl/tests/conftest.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,15 @@
2626
invalid_filter,
2727
valid_filter,
2828
)
29+
from _numpy_warnings import suppress_invalid_numpy_warnings
2930

3031
sys.path.append(os.path.join(os.path.dirname(__file__), "helper"))
3132

3233
# common fixtures
33-
__all__ = ["check", "device_selector", "invalid_filter", "valid_filter"]
34+
__all__ = [
35+
"check",
36+
"device_selector",
37+
"invalid_filter",
38+
"suppress_invalid_numpy_warnings",
39+
"valid_filter",
40+
]

dpctl/tests/elementwise/test_abs.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import dpctl.tensor as dpt
2323
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
2424

25-
from .utils import _all_dtypes, _usm_types
25+
from .utils import _all_dtypes, _no_complex_dtypes, _usm_types
2626

2727

2828
@pytest.mark.parametrize("dtype", _all_dtypes)
@@ -113,3 +113,25 @@ def test_abs_complex(dtype):
113113
np.testing.assert_allclose(
114114
dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol
115115
)
116+
117+
118+
@pytest.mark.parametrize("dtype", _no_complex_dtypes)
119+
def test_abs_out_overlap(dtype):
120+
q = get_queue_or_skip()
121+
skip_if_dtype_not_supported(dtype, q)
122+
123+
X = dpt.linspace(0, 35, 60, dtype=dtype, sycl_queue=q)
124+
X = dpt.reshape(X, (3, 5, 4))
125+
126+
Xnp = dpt.asnumpy(X)
127+
Ynp = np.abs(Xnp, out=Xnp)
128+
129+
Y = dpt.abs(X, out=X)
130+
assert Y is X
131+
assert np.allclose(dpt.asnumpy(X), Xnp)
132+
133+
Ynp = np.abs(Xnp, out=Xnp[::-1])
134+
Y = dpt.abs(X, out=X[::-1])
135+
assert Y is not X
136+
assert np.allclose(dpt.asnumpy(X), Xnp)
137+
assert np.allclose(dpt.asnumpy(Y), Ynp)

dpctl/tests/elementwise/test_exp.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,3 +145,26 @@ def test_exp_strided(dtype):
145145
atol=tol,
146146
rtol=tol,
147147
)
148+
149+
150+
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8", "c8", "c16"])
151+
def test_exp_out_overlap(dtype):
152+
q = get_queue_or_skip()
153+
skip_if_dtype_not_supported(dtype, q)
154+
155+
X = dpt.linspace(0, 1, 15, dtype=dtype, sycl_queue=q)
156+
X = dpt.reshape(X, (3, 5))
157+
158+
Xnp = dpt.asnumpy(X)
159+
Ynp = np.exp(Xnp, out=Xnp)
160+
161+
Y = dpt.exp(X, out=X)
162+
tol = 8 * dpt.finfo(Y.dtype).resolution
163+
assert Y is X
164+
assert_allclose(dpt.asnumpy(X), Xnp, atol=tol, rtol=tol)
165+
166+
Ynp = np.exp(Xnp, out=Xnp[::-1])
167+
Y = dpt.exp(X, out=X[::-1])
168+
assert Y is not X
169+
assert_allclose(dpt.asnumpy(X), Xnp, atol=tol, rtol=tol)
170+
assert_allclose(dpt.asnumpy(Y), Ynp, atol=tol, rtol=tol)

dpctl/tests/elementwise/test_log.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import numpy as np
2020
import pytest
21-
from numpy.testing import assert_equal
21+
from numpy.testing import assert_allclose, assert_equal
2222

2323
import dpctl.tensor as dpt
2424
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
@@ -50,7 +50,7 @@ def test_log_output_contig(dtype):
5050
Y = dpt.log(X)
5151
tol = 8 * dpt.finfo(Y.dtype).resolution
5252

53-
np.testing.assert_allclose(dpt.asnumpy(Y), np.log(Xnp), atol=tol, rtol=tol)
53+
assert_allclose(dpt.asnumpy(Y), np.log(Xnp), atol=tol, rtol=tol)
5454

5555

5656
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8", "c8", "c16"])
@@ -66,7 +66,7 @@ def test_log_output_strided(dtype):
6666
Y = dpt.log(X)
6767
tol = 8 * dpt.finfo(Y.dtype).resolution
6868

69-
np.testing.assert_allclose(dpt.asnumpy(Y), np.log(Xnp), atol=tol, rtol=tol)
69+
assert_allclose(dpt.asnumpy(Y), np.log(Xnp), atol=tol, rtol=tol)
7070

7171

7272
@pytest.mark.parametrize("usm_type", _usm_types)
@@ -89,7 +89,7 @@ def test_log_usm_type(usm_type):
8989
expected_Y[..., 1::2] = np.log(np.float32(10 * dpt.e))
9090
tol = 8 * dpt.finfo(Y.dtype).resolution
9191

92-
np.testing.assert_allclose(dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol)
92+
assert_allclose(dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol)
9393

9494

9595
@pytest.mark.parametrize("dtype", _all_dtypes)
@@ -112,9 +112,7 @@ def test_log_order(dtype):
112112
dpt.finfo(Y.dtype).resolution,
113113
np.finfo(expected_Y.dtype).resolution,
114114
)
115-
np.testing.assert_allclose(
116-
dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol
117-
)
115+
assert_allclose(dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol)
118116

119117

120118
def test_log_special_cases():
@@ -126,3 +124,27 @@ def test_log_special_cases():
126124
Xnp = dpt.asnumpy(X)
127125

128126
assert_equal(dpt.asnumpy(dpt.log(X)), np.log(Xnp))
127+
128+
129+
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8", "c8", "c16"])
130+
def test_log_out_overlap(dtype):
131+
q = get_queue_or_skip()
132+
skip_if_dtype_not_supported(dtype, q)
133+
134+
X = dpt.linspace(5, 35, 60, dtype=dtype, sycl_queue=q)
135+
X = dpt.reshape(X, (3, 5, 4))
136+
137+
Xnp = dpt.asnumpy(X)
138+
Ynp = np.log(Xnp, out=Xnp)
139+
140+
Y = dpt.log(X, out=X)
141+
assert Y is X
142+
143+
tol = 8 * dpt.finfo(Y.dtype).resolution
144+
assert_allclose(dpt.asnumpy(X), Xnp, atol=tol, rtol=tol)
145+
146+
Ynp = np.log(Xnp, out=Xnp[::-1])
147+
Y = dpt.log(X, out=X[::-1])
148+
assert Y is not X
149+
assert_allclose(dpt.asnumpy(X), Xnp, atol=tol, rtol=tol)
150+
assert_allclose(dpt.asnumpy(Y), Ynp, atol=tol, rtol=tol)

dpctl/tests/elementwise/test_sincos.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -161,12 +161,6 @@ def test_sincos_errors(callable):
161161
y,
162162
)
163163

164-
x = dpt.zeros(2)
165-
y = x
166-
assert_raises_regex(
167-
TypeError, "Input and output arrays have memory overlap", callable, x, y
168-
)
169-
170164
x = dpt.zeros(2, dtype="float32")
171165
y = np.empty_like(x)
172166
assert_raises_regex(
@@ -230,3 +224,28 @@ def test_sincos_strided(dtype):
230224
atol=tol,
231225
rtol=tol,
232226
)
227+
228+
229+
@pytest.mark.parametrize(
230+
"np_call, dpt_call", [(np.sin, dpt.sin), (np.cos, dpt.cos)]
231+
)
232+
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8", "c8", "c16"])
233+
def test_sincos_out_overlap(np_call, dpt_call, dtype):
234+
q = get_queue_or_skip()
235+
skip_if_dtype_not_supported(dtype, q)
236+
237+
X = dpt.linspace(-np.pi / 2, np.pi / 2, 60, dtype=dtype, sycl_queue=q)
238+
X = dpt.reshape(X, (3, 5, 4))
239+
240+
Xnp = dpt.asnumpy(X)
241+
Ynp = np_call(Xnp, out=Xnp)
242+
243+
Y = dpt_call(X, out=X)
244+
assert Y is X
245+
assert np.allclose(dpt.asnumpy(X), Xnp)
246+
247+
Ynp = np_call(Xnp, out=Xnp[::-1])
248+
Y = dpt_call(X, out=X[::-1])
249+
assert Y is not X
250+
assert np.allclose(dpt.asnumpy(X), Xnp)
251+
assert np.allclose(dpt.asnumpy(Y), Ynp)

dpctl/tests/elementwise/test_sqrt.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import numpy as np
2020
import pytest
21-
from numpy.testing import assert_equal
21+
from numpy.testing import assert_allclose, assert_equal
2222

2323
import dpctl.tensor as dpt
2424
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
@@ -50,7 +50,7 @@ def test_sqrt_output_contig(dtype):
5050
Y = dpt.sqrt(X)
5151
tol = 8 * dpt.finfo(Y.dtype).resolution
5252

53-
np.testing.assert_allclose(dpt.asnumpy(Y), np.sqrt(Xnp), atol=tol, rtol=tol)
53+
assert_allclose(dpt.asnumpy(Y), np.sqrt(Xnp), atol=tol, rtol=tol)
5454

5555

5656
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8", "c8", "c16"])
@@ -66,7 +66,7 @@ def test_sqrt_output_strided(dtype):
6666
Y = dpt.sqrt(X)
6767
tol = 8 * dpt.finfo(Y.dtype).resolution
6868

69-
np.testing.assert_allclose(dpt.asnumpy(Y), np.sqrt(Xnp), atol=tol, rtol=tol)
69+
assert_allclose(dpt.asnumpy(Y), np.sqrt(Xnp), atol=tol, rtol=tol)
7070

7171

7272
@pytest.mark.parametrize("usm_type", _usm_types)
@@ -89,7 +89,7 @@ def test_sqrt_usm_type(usm_type):
8989
expected_Y[..., 1::2] = np.sqrt(np.float32(23.0))
9090
tol = 8 * dpt.finfo(Y.dtype).resolution
9191

92-
np.testing.assert_allclose(dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol)
92+
assert_allclose(dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol)
9393

9494

9595
@pytest.mark.parametrize("dtype", _all_dtypes)
@@ -112,11 +112,10 @@ def test_sqrt_order(dtype):
112112
dpt.finfo(Y.dtype).resolution,
113113
np.finfo(expected_Y.dtype).resolution,
114114
)
115-
np.testing.assert_allclose(
116-
dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol
117-
)
115+
assert_allclose(dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol)
118116

119117

118+
@pytest.mark.usefixtures("suppress_invalid_numpy_warnings")
120119
def test_sqrt_special_cases():
121120
q = get_queue_or_skip()
122121

@@ -126,3 +125,27 @@ def test_sqrt_special_cases():
126125
Xnp = dpt.asnumpy(X)
127126

128127
assert_equal(dpt.asnumpy(dpt.sqrt(X)), np.sqrt(Xnp))
128+
129+
130+
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8", "c8", "c16"])
131+
def test_sqrt_out_overlap(dtype):
132+
q = get_queue_or_skip()
133+
skip_if_dtype_not_supported(dtype, q)
134+
135+
X = dpt.linspace(0, 35, 60, dtype=dtype, sycl_queue=q)
136+
X = dpt.reshape(X, (3, 5, 4))
137+
138+
Xnp = dpt.asnumpy(X)
139+
Ynp = np.sqrt(Xnp, out=Xnp)
140+
141+
Y = dpt.sqrt(X, out=X)
142+
assert Y is X
143+
144+
tol = 8 * dpt.finfo(Y.dtype).resolution
145+
assert_allclose(dpt.asnumpy(X), Xnp, atol=tol, rtol=tol)
146+
147+
Ynp = np.sqrt(Xnp, out=Xnp[::-1])
148+
Y = dpt.sqrt(X, out=X[::-1])
149+
assert Y is not X
150+
assert_allclose(dpt.asnumpy(X), Xnp, atol=tol, rtol=tol)
151+
assert_allclose(dpt.asnumpy(Y), Ynp, atol=tol, rtol=tol)

0 commit comments

Comments
 (0)