Skip to content

Commit f48c3f0

Browse files
committed
out keyword for elementwise functions
1 parent 7a12565 commit f48c3f0

File tree

6 files changed

+395
-76
lines changed

6 files changed

+395
-76
lines changed

dpctl/tensor/_elementwise_common.py

Lines changed: 170 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,33 @@ def __init__(self, name, result_type_resolver_fn, unary_dp_impl_fn, docs):
4747
self.unary_fn_ = unary_dp_impl_fn
4848
self.__doc__ = docs
4949

50-
def __call__(self, x, order="K"):
50+
def __call__(self, x, out=None, order="K"):
5151
if not isinstance(x, dpt.usm_ndarray):
5252
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
53+
54+
if out is not None:
55+
if not isinstance(out, dpt.usm_ndarray):
56+
raise TypeError(
57+
f"output array must be of usm_ndarray type, got {type(out)}"
58+
)
59+
60+
if out.shape != x.shape:
61+
raise TypeError(
62+
"The shape of input and output arrays are inconsistent."
63+
f"Expected output shape is {x.shape}, got {out.shape}"
64+
)
65+
66+
if ti._array_overlap(x, out):
67+
raise TypeError("Input and output arrays have memory overlap")
68+
69+
if (
70+
dpctl.utils.get_execution_queue((x.sycl_queue, out.sycl_queue))
71+
is None
72+
):
73+
raise TypeError(
74+
"Input and output allocation queues are not compatible"
75+
)
76+
5377
if order not in ["C", "F", "K", "A"]:
5478
order = "K"
5579
buf_dt, res_dt = _find_buf_dtype(
@@ -59,17 +83,24 @@ def __call__(self, x, order="K"):
5983
raise RuntimeError
6084
exec_q = x.sycl_queue
6185
if buf_dt is None:
62-
if order == "K":
63-
r = _empty_like_orderK(x, res_dt)
86+
if out is None:
87+
if order == "K":
88+
out = _empty_like_orderK(x, res_dt)
89+
else:
90+
if order == "A":
91+
order = "F" if x.flags.f_contiguous else "C"
92+
out = dpt.empty_like(x, dtype=res_dt, order=order)
6493
else:
65-
if order == "A":
66-
order = "F" if x.flags.f_contiguous else "C"
67-
r = dpt.empty_like(x, dtype=res_dt, order=order)
94+
if res_dt != out.dtype:
95+
raise TypeError(
96+
f"Expected output array of type {res_dt} is supported"
97+
f", got {out.dtype}"
98+
)
6899

69-
ht, _ = self.unary_fn_(x, r, sycl_queue=exec_q)
100+
ht, _ = self.unary_fn_(x, out, sycl_queue=exec_q)
70101
ht.wait()
71102

72-
return r
103+
return out
73104
if order == "K":
74105
buf = _empty_like_orderK(x, buf_dt)
75106
else:
@@ -80,15 +111,22 @@ def __call__(self, x, order="K"):
80111
ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
81112
src=x, dst=buf, sycl_queue=exec_q
82113
)
83-
if order == "K":
84-
r = _empty_like_orderK(buf, res_dt)
114+
if out is None:
115+
if order == "K":
116+
out = _empty_like_orderK(buf, res_dt)
117+
else:
118+
out = dpt.empty_like(buf, dtype=res_dt, order=order)
85119
else:
86-
r = dpt.empty_like(buf, dtype=res_dt, order=order)
120+
if buf_dt != out.dtype:
121+
raise TypeError(
122+
f"Expected output array of type {buf_dt} is supported,"
123+
f"got {out.dtype}"
124+
)
87125

88-
ht, _ = self.unary_fn_(buf, r, sycl_queue=exec_q, depends=[copy_ev])
126+
ht, _ = self.unary_fn_(buf, out, sycl_queue=exec_q, depends=[copy_ev])
89127
ht.wait()
90128

91-
return r
129+
return out
92130

93131

94132
def _get_queue_usm_type(o):
@@ -280,7 +318,7 @@ def __str__(self):
280318
def __repr__(self):
281319
return f"<BinaryElementwiseFunc '{self.name_}'>"
282320

283-
def __call__(self, o1, o2, order="K"):
321+
def __call__(self, o1, o2, out=None, order="K"):
284322
if order not in ["K", "C", "F", "A"]:
285323
order = "K"
286324
q1, o1_usm_type = _get_queue_usm_type(o1)
@@ -357,6 +395,31 @@ def __call__(self, o1, o2, order="K"):
357395
"supported types according to the casting rule ''safe''."
358396
)
359397

398+
if out is not None:
399+
if not isinstance(out, dpt.usm_ndarray):
400+
raise TypeError(
401+
f"output array must be of usm_ndarray type, got {type(out)}"
402+
)
403+
404+
if out.shape != o1_shape or out.shape != o2_shape:
405+
raise TypeError(
406+
"The shape of input and output arrays are inconsistent."
407+
f"Expected output shape is {o1_shape}, got {out.shape}"
408+
)
409+
410+
if ti._array_overlap(o1, out) or ti._array_overlap(o2, out):
411+
raise TypeError("Input and output arrays have memory overlap")
412+
413+
if (
414+
dpctl.utils.get_execution_queue(
415+
(o1.sycl_queue, o2.sycl_queue, out.sycl_queue)
416+
)
417+
is None
418+
):
419+
raise TypeError(
420+
"Input and output allocation queues are not compatible"
421+
)
422+
360423
if isinstance(o1, dpt.usm_ndarray):
361424
src1 = o1
362425
else:
@@ -367,37 +430,45 @@ def __call__(self, o1, o2, order="K"):
367430
src2 = dpt.asarray(o2, dtype=o2_dtype, sycl_queue=exec_q)
368431

369432
if buf1_dt is None and buf2_dt is None:
370-
if order == "K":
371-
r = _empty_like_pair_orderK(
372-
src1, src2, res_dt, res_usm_type, exec_q
373-
)
374-
else:
375-
if order == "A":
376-
order = (
377-
"F"
378-
if all(
379-
arr.flags.f_contiguous
380-
for arr in (
381-
src1,
382-
src2,
433+
if out is None:
434+
if order == "K":
435+
out = _empty_like_pair_orderK(
436+
src1, src2, res_dt, res_usm_type, exec_q
437+
)
438+
else:
439+
if order == "A":
440+
order = (
441+
"F"
442+
if all(
443+
arr.flags.f_contiguous
444+
for arr in (
445+
src1,
446+
src2,
447+
)
383448
)
449+
else "C"
384450
)
385-
else "C"
451+
out = dpt.empty(
452+
res_shape,
453+
dtype=res_dt,
454+
usm_type=res_usm_type,
455+
sycl_queue=exec_q,
456+
order=order,
386457
)
387-
r = dpt.empty(
388-
res_shape,
389-
dtype=res_dt,
390-
usm_type=res_usm_type,
391-
sycl_queue=exec_q,
392-
order=order,
393-
)
458+
else:
459+
if res_dt != out.dtype:
460+
raise TypeError(
461+
f"Output array of type {res_dt} is needed,"
462+
f"got {out.dtype}"
463+
)
464+
394465
src1 = dpt.broadcast_to(src1, res_shape)
395466
src2 = dpt.broadcast_to(src2, res_shape)
396467
ht_, _ = self.binary_fn_(
397-
src1=src1, src2=src2, dst=r, sycl_queue=exec_q
468+
src1=src1, src2=src2, dst=out, sycl_queue=exec_q
398469
)
399470
ht_.wait()
400-
return r
471+
return out
401472
elif buf1_dt is None:
402473
if order == "K":
403474
buf2 = _empty_like_orderK(src2, buf2_dt)
@@ -408,30 +479,38 @@ def __call__(self, o1, o2, order="K"):
408479
ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
409480
src=src2, dst=buf2, sycl_queue=exec_q
410481
)
411-
if order == "K":
412-
r = _empty_like_pair_orderK(
413-
src1, buf2, res_dt, res_usm_type, exec_q
414-
)
482+
if out is None:
483+
if order == "K":
484+
out = _empty_like_pair_orderK(
485+
src1, buf2, res_dt, res_usm_type, exec_q
486+
)
487+
else:
488+
out = dpt.empty(
489+
res_shape,
490+
dtype=res_dt,
491+
usm_type=res_usm_type,
492+
sycl_queue=exec_q,
493+
order=order,
494+
)
415495
else:
416-
r = dpt.empty(
417-
res_shape,
418-
dtype=res_dt,
419-
usm_type=res_usm_type,
420-
sycl_queue=exec_q,
421-
order=order,
422-
)
496+
if res_dt != out.dtype:
497+
raise TypeError(
498+
f"Output array of type {res_dt} is needed,"
499+
f"got {out.dtype}"
500+
)
501+
423502
src1 = dpt.broadcast_to(src1, res_shape)
424503
buf2 = dpt.broadcast_to(buf2, res_shape)
425504
ht_, _ = self.binary_fn_(
426505
src1=src1,
427506
src2=buf2,
428-
dst=r,
507+
dst=out,
429508
sycl_queue=exec_q,
430509
depends=[copy_ev],
431510
)
432511
ht_copy_ev.wait()
433512
ht_.wait()
434-
return r
513+
return out
435514
elif buf2_dt is None:
436515
if order == "K":
437516
buf1 = _empty_like_orderK(src1, buf1_dt)
@@ -442,30 +521,38 @@ def __call__(self, o1, o2, order="K"):
442521
ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
443522
src=src1, dst=buf1, sycl_queue=exec_q
444523
)
445-
if order == "K":
446-
r = _empty_like_pair_orderK(
447-
buf1, src2, res_dt, res_usm_type, exec_q
448-
)
524+
if out is None:
525+
if order == "K":
526+
out = _empty_like_pair_orderK(
527+
buf1, src2, res_dt, res_usm_type, exec_q
528+
)
529+
else:
530+
out = dpt.empty(
531+
res_shape,
532+
dtype=res_dt,
533+
usm_type=res_usm_type,
534+
sycl_queue=exec_q,
535+
order=order,
536+
)
449537
else:
450-
r = dpt.empty(
451-
res_shape,
452-
dtype=res_dt,
453-
usm_type=res_usm_type,
454-
sycl_queue=exec_q,
455-
order=order,
456-
)
538+
if res_dt != out.dtype:
539+
raise TypeError(
540+
f"Output array of type {res_dt} is needed,"
541+
f"got {out.dtype}"
542+
)
543+
457544
buf1 = dpt.broadcast_to(buf1, res_shape)
458545
src2 = dpt.broadcast_to(src2, res_shape)
459546
ht_, _ = self.binary_fn_(
460547
src1=buf1,
461548
src2=src2,
462-
dst=r,
549+
dst=out,
463550
sycl_queue=exec_q,
464551
depends=[copy_ev],
465552
)
466553
ht_copy_ev.wait()
467554
ht_.wait()
468-
return r
555+
return out
469556

470557
if order in ["K", "A"]:
471558
if src1.flags.f_contiguous and src2.flags.f_contiguous:
@@ -488,26 +575,33 @@ def __call__(self, o1, o2, order="K"):
488575
ht_copy2_ev, copy2_ev = ti._copy_usm_ndarray_into_usm_ndarray(
489576
src=src2, dst=buf2, sycl_queue=exec_q
490577
)
491-
if order == "K":
492-
r = _empty_like_pair_orderK(
493-
buf1, buf2, res_dt, res_usm_type, exec_q
494-
)
578+
if out is None:
579+
if order == "K":
580+
out = _empty_like_pair_orderK(
581+
buf1, buf2, res_dt, res_usm_type, exec_q
582+
)
583+
else:
584+
out = dpt.empty(
585+
res_shape,
586+
dtype=res_dt,
587+
usm_type=res_usm_type,
588+
sycl_queue=exec_q,
589+
order=order,
590+
)
495591
else:
496-
r = dpt.empty(
497-
res_shape,
498-
dtype=res_dt,
499-
usm_type=res_usm_type,
500-
sycl_queue=exec_q,
501-
order=order,
502-
)
592+
if res_dt != out.dtype:
593+
raise TypeError(
594+
f"Output array of type {res_dt} is needed, got {out.dtype}"
595+
)
596+
503597
buf1 = dpt.broadcast_to(buf1, res_shape)
504598
buf2 = dpt.broadcast_to(buf2, res_shape)
505599
ht_, _ = self.binary_fn_(
506600
src1=buf1,
507601
src2=buf2,
508-
dst=r,
602+
dst=out,
509603
sycl_queue=exec_q,
510604
depends=[copy1_ev, copy2_ev],
511605
)
512606
dpctl.SyclEvent.wait_for([ht_copy1_ev, ht_copy2_ev, ht_])
513-
return r
607+
return out

dpctl/tests/elementwise/test_abs.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,3 +89,20 @@ def test_abs_complex(dtype):
8989
np.testing.assert_allclose(
9090
dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol
9191
)
92+
93+
94+
@pytest.mark.parametrize("dtype", _all_dtypes[:-2])
95+
def test_abs_out_keyword(dtype):
96+
q = get_queue_or_skip()
97+
skip_if_dtype_not_supported(dtype, q)
98+
99+
arg_dt = np.dtype(dtype)
100+
input_shape = (10, 10, 10, 10)
101+
X = dpt.empty(input_shape, dtype=arg_dt, sycl_queue=q)
102+
X[..., 0::2] = 1
103+
X[..., 1::2] = 0
104+
Y = dpt.empty_like(X, dtype=arg_dt)
105+
dpt.abs(X, Y)
106+
107+
expected_Y = dpt.asnumpy(X)
108+
assert np.allclose(dpt.asnumpy(Y), expected_Y)

0 commit comments

Comments
 (0)