Skip to content

Commit cc75cd2

Browse files
authored
Merge pull request #1209 from IntelPython/feature/elementwise-functions-out-keyword
Feature/elementwise functions out keyword
2 parents 0820add + f3d5519 commit cc75cd2

File tree

6 files changed

+420
-108
lines changed

6 files changed

+420
-108
lines changed

dpctl/tensor/_elementwise_common.py

Lines changed: 169 additions & 76 deletions
Original file line number | Diff line number | Diff line change
@@ -47,9 +47,33 @@ def __init__(self, name, result_type_resolver_fn, unary_dp_impl_fn, docs):
4747
self.unary_fn_ = unary_dp_impl_fn
4848
self.__doc__ = docs
4949

50-
def __call__(self, x, order="K"):
50+
def __call__(self, x, out=None, order="K"):
5151
if not isinstance(x, dpt.usm_ndarray):
5252
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
53+
54+
if out is not None:
55+
if not isinstance(out, dpt.usm_ndarray):
56+
raise TypeError(
57+
f"output array must be of usm_ndarray type, got {type(out)}"
58+
)
59+
60+
if out.shape != x.shape:
61+
raise TypeError(
62+
"The shape of input and output arrays are inconsistent."
63+
f"Expected output shape is {x.shape}, got {out.shape}"
64+
)
65+
66+
if ti._array_overlap(x, out):
67+
raise TypeError("Input and output arrays have memory overlap")
68+
69+
if (
70+
dpctl.utils.get_execution_queue((x.sycl_queue, out.sycl_queue))
71+
is None
72+
):
73+
raise TypeError(
74+
"Input and output allocation queues are not compatible"
75+
)
76+
5377
if order not in ["C", "F", "K", "A"]:
5478
order = "K"
5579
buf_dt, res_dt = _find_buf_dtype(
@@ -59,17 +83,24 @@ def __call__(self, x, order="K"):
5983
raise RuntimeError
6084
exec_q = x.sycl_queue
6185
if buf_dt is None:
62-
if order == "K":
63-
r = _empty_like_orderK(x, res_dt)
86+
if out is None:
87+
if order == "K":
88+
out = _empty_like_orderK(x, res_dt)
89+
else:
90+
if order == "A":
91+
order = "F" if x.flags.f_contiguous else "C"
92+
out = dpt.empty_like(x, dtype=res_dt, order=order)
6493
else:
65-
if order == "A":
66-
order = "F" if x.flags.f_contiguous else "C"
67-
r = dpt.empty_like(x, dtype=res_dt, order=order)
94+
if res_dt != out.dtype:
95+
raise TypeError(
96+
f"Output array of type {res_dt} is needed,"
97+
f" got {out.dtype}"
98+
)
6899

69-
ht, _ = self.unary_fn_(x, r, sycl_queue=exec_q)
100+
ht, _ = self.unary_fn_(x, out, sycl_queue=exec_q)
70101
ht.wait()
71102

72-
return r
103+
return out
73104
if order == "K":
74105
buf = _empty_like_orderK(x, buf_dt)
75106
else:
@@ -80,16 +111,22 @@ def __call__(self, x, order="K"):
80111
ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
81112
src=x, dst=buf, sycl_queue=exec_q
82113
)
83-
if order == "K":
84-
r = _empty_like_orderK(buf, res_dt)
114+
if out is None:
115+
if order == "K":
116+
out = _empty_like_orderK(buf, res_dt)
117+
else:
118+
out = dpt.empty_like(buf, dtype=res_dt, order=order)
85119
else:
86-
r = dpt.empty_like(buf, dtype=res_dt, order=order)
120+
if buf_dt != out.dtype:
121+
raise TypeError(
122+
f"Output array of type {buf_dt} is needed, got {out.dtype}"
123+
)
87124

88-
ht, _ = self.unary_fn_(buf, r, sycl_queue=exec_q, depends=[copy_ev])
125+
ht, _ = self.unary_fn_(buf, out, sycl_queue=exec_q, depends=[copy_ev])
89126
ht_copy_ev.wait()
90127
ht.wait()
91128

92-
return r
129+
return out
93130

94131

95132
def _get_queue_usm_type(o):
@@ -281,7 +318,7 @@ def __str__(self):
281318
def __repr__(self):
282319
return f"<BinaryElementwiseFunc '{self.name_}'>"
283320

284-
def __call__(self, o1, o2, order="K"):
321+
def __call__(self, o1, o2, out=None, order="K"):
285322
if order not in ["K", "C", "F", "A"]:
286323
order = "K"
287324
q1, o1_usm_type = _get_queue_usm_type(o1)
@@ -358,6 +395,31 @@ def __call__(self, o1, o2, order="K"):
358395
"supported types according to the casting rule ''safe''."
359396
)
360397

398+
if out is not None:
399+
if not isinstance(out, dpt.usm_ndarray):
400+
raise TypeError(
401+
f"output array must be of usm_ndarray type, got {type(out)}"
402+
)
403+
404+
if out.shape != res_shape:
405+
raise TypeError(
406+
"The shape of input and output arrays are inconsistent."
407+
f"Expected output shape is {o1_shape}, got {out.shape}"
408+
)
409+
410+
if ti._array_overlap(o1, out) or ti._array_overlap(o2, out):
411+
raise TypeError("Input and output arrays have memory overlap")
412+
413+
if (
414+
dpctl.utils.get_execution_queue(
415+
(o1.sycl_queue, o2.sycl_queue, out.sycl_queue)
416+
)
417+
is None
418+
):
419+
raise TypeError(
420+
"Input and output allocation queues are not compatible"
421+
)
422+
361423
if isinstance(o1, dpt.usm_ndarray):
362424
src1 = o1
363425
else:
@@ -368,37 +430,45 @@ def __call__(self, o1, o2, order="K"):
368430
src2 = dpt.asarray(o2, dtype=o2_dtype, sycl_queue=exec_q)
369431

370432
if buf1_dt is None and buf2_dt is None:
371-
if order == "K":
372-
r = _empty_like_pair_orderK(
373-
src1, src2, res_dt, res_usm_type, exec_q
374-
)
375-
else:
376-
if order == "A":
377-
order = (
378-
"F"
379-
if all(
380-
arr.flags.f_contiguous
381-
for arr in (
382-
src1,
383-
src2,
433+
if out is None:
434+
if order == "K":
435+
out = _empty_like_pair_orderK(
436+
src1, src2, res_dt, res_usm_type, exec_q
437+
)
438+
else:
439+
if order == "A":
440+
order = (
441+
"F"
442+
if all(
443+
arr.flags.f_contiguous
444+
for arr in (
445+
src1,
446+
src2,
447+
)
384448
)
449+
else "C"
385450
)
386-
else "C"
451+
out = dpt.empty(
452+
res_shape,
453+
dtype=res_dt,
454+
usm_type=res_usm_type,
455+
sycl_queue=exec_q,
456+
order=order,
387457
)
388-
r = dpt.empty(
389-
res_shape,
390-
dtype=res_dt,
391-
usm_type=res_usm_type,
392-
sycl_queue=exec_q,
393-
order=order,
394-
)
458+
else:
459+
if res_dt != out.dtype:
460+
raise TypeError(
461+
f"Output array of type {res_dt} is needed,"
462+
f"got {out.dtype}"
463+
)
464+
395465
src1 = dpt.broadcast_to(src1, res_shape)
396466
src2 = dpt.broadcast_to(src2, res_shape)
397467
ht_, _ = self.binary_fn_(
398-
src1=src1, src2=src2, dst=r, sycl_queue=exec_q
468+
src1=src1, src2=src2, dst=out, sycl_queue=exec_q
399469
)
400470
ht_.wait()
401-
return r
471+
return out
402472
elif buf1_dt is None:
403473
if order == "K":
404474
buf2 = _empty_like_orderK(src2, buf2_dt)
@@ -409,30 +479,38 @@ def __call__(self, o1, o2, order="K"):
409479
ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
410480
src=src2, dst=buf2, sycl_queue=exec_q
411481
)
412-
if order == "K":
413-
r = _empty_like_pair_orderK(
414-
src1, buf2, res_dt, res_usm_type, exec_q
415-
)
482+
if out is None:
483+
if order == "K":
484+
out = _empty_like_pair_orderK(
485+
src1, buf2, res_dt, res_usm_type, exec_q
486+
)
487+
else:
488+
out = dpt.empty(
489+
res_shape,
490+
dtype=res_dt,
491+
usm_type=res_usm_type,
492+
sycl_queue=exec_q,
493+
order=order,
494+
)
416495
else:
417-
r = dpt.empty(
418-
res_shape,
419-
dtype=res_dt,
420-
usm_type=res_usm_type,
421-
sycl_queue=exec_q,
422-
order=order,
423-
)
496+
if res_dt != out.dtype:
497+
raise TypeError(
498+
f"Output array of type {res_dt} is needed,"
499+
f"got {out.dtype}"
500+
)
501+
424502
src1 = dpt.broadcast_to(src1, res_shape)
425503
buf2 = dpt.broadcast_to(buf2, res_shape)
426504
ht_, _ = self.binary_fn_(
427505
src1=src1,
428506
src2=buf2,
429-
dst=r,
507+
dst=out,
430508
sycl_queue=exec_q,
431509
depends=[copy_ev],
432510
)
433511
ht_copy_ev.wait()
434512
ht_.wait()
435-
return r
513+
return out
436514
elif buf2_dt is None:
437515
if order == "K":
438516
buf1 = _empty_like_orderK(src1, buf1_dt)
@@ -443,30 +521,38 @@ def __call__(self, o1, o2, order="K"):
443521
ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
444522
src=src1, dst=buf1, sycl_queue=exec_q
445523
)
446-
if order == "K":
447-
r = _empty_like_pair_orderK(
448-
buf1, src2, res_dt, res_usm_type, exec_q
449-
)
524+
if out is None:
525+
if order == "K":
526+
out = _empty_like_pair_orderK(
527+
buf1, src2, res_dt, res_usm_type, exec_q
528+
)
529+
else:
530+
out = dpt.empty(
531+
res_shape,
532+
dtype=res_dt,
533+
usm_type=res_usm_type,
534+
sycl_queue=exec_q,
535+
order=order,
536+
)
450537
else:
451-
r = dpt.empty(
452-
res_shape,
453-
dtype=res_dt,
454-
usm_type=res_usm_type,
455-
sycl_queue=exec_q,
456-
order=order,
457-
)
538+
if res_dt != out.dtype:
539+
raise TypeError(
540+
f"Output array of type {res_dt} is needed,"
541+
f"got {out.dtype}"
542+
)
543+
458544
buf1 = dpt.broadcast_to(buf1, res_shape)
459545
src2 = dpt.broadcast_to(src2, res_shape)
460546
ht_, _ = self.binary_fn_(
461547
src1=buf1,
462548
src2=src2,
463-
dst=r,
549+
dst=out,
464550
sycl_queue=exec_q,
465551
depends=[copy_ev],
466552
)
467553
ht_copy_ev.wait()
468554
ht_.wait()
469-
return r
555+
return out
470556

471557
if order in ["K", "A"]:
472558
if src1.flags.f_contiguous and src2.flags.f_contiguous:
@@ -489,26 +575,33 @@ def __call__(self, o1, o2, order="K"):
489575
ht_copy2_ev, copy2_ev = ti._copy_usm_ndarray_into_usm_ndarray(
490576
src=src2, dst=buf2, sycl_queue=exec_q
491577
)
492-
if order == "K":
493-
r = _empty_like_pair_orderK(
494-
buf1, buf2, res_dt, res_usm_type, exec_q
495-
)
578+
if out is None:
579+
if order == "K":
580+
out = _empty_like_pair_orderK(
581+
buf1, buf2, res_dt, res_usm_type, exec_q
582+
)
583+
else:
584+
out = dpt.empty(
585+
res_shape,
586+
dtype=res_dt,
587+
usm_type=res_usm_type,
588+
sycl_queue=exec_q,
589+
order=order,
590+
)
496591
else:
497-
r = dpt.empty(
498-
res_shape,
499-
dtype=res_dt,
500-
usm_type=res_usm_type,
501-
sycl_queue=exec_q,
502-
order=order,
503-
)
592+
if res_dt != out.dtype:
593+
raise TypeError(
594+
f"Output array of type {res_dt} is needed, got {out.dtype}"
595+
)
596+
504597
buf1 = dpt.broadcast_to(buf1, res_shape)
505598
buf2 = dpt.broadcast_to(buf2, res_shape)
506599
ht_, _ = self.binary_fn_(
507600
src1=buf1,
508601
src2=buf2,
509-
dst=r,
602+
dst=out,
510603
sycl_queue=exec_q,
511604
depends=[copy1_ev, copy2_ev],
512605
)
513606
dpctl.SyclEvent.wait_for([ht_copy1_ev, ht_copy2_ev, ht_])
514-
return r
607+
return out

dpctl/tests/elementwise/test_abs.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,9 +22,17 @@ def test_abs_out_type(dtype):
2222
np.dtype("c16"): np.dtype("f8"),
2323
}
2424
assert dpt.abs(X).dtype == type_map[arg_dt]
25+
26+
r = dpt.empty_like(X, dtype=type_map[arg_dt])
27+
dpt.abs(X, out=r)
28+
assert np.allclose(dpt.asnumpy(r), dpt.asnumpy(dpt.abs(X)))
2529
else:
2630
assert dpt.abs(X).dtype == arg_dt
2731

32+
r = dpt.empty_like(X, dtype=arg_dt)
33+
dpt.abs(X, out=r)
34+
assert np.allclose(dpt.asnumpy(r), dpt.asnumpy(dpt.abs(X)))
35+
2836

2937
@pytest.mark.parametrize("usm_type", _usm_types)
3038
def test_abs_usm_type(usm_type):

0 commit comments

Comments
 (0)