@@ -47,9 +47,33 @@ def __init__(self, name, result_type_resolver_fn, unary_dp_impl_fn, docs):
47
47
self .unary_fn_ = unary_dp_impl_fn
48
48
self .__doc__ = docs
49
49
50
- def __call__ (self , x , order = "K" ):
50
+ def __call__ (self , x , out = None , order = "K" ):
51
51
if not isinstance (x , dpt .usm_ndarray ):
52
52
raise TypeError (f"Expected dpctl.tensor.usm_ndarray, got { type (x )} " )
53
+
54
+ if out is not None :
55
+ if not isinstance (out , dpt .usm_ndarray ):
56
+ raise TypeError (
57
+ f"output array must be of usm_ndarray type, got { type (out )} "
58
+ )
59
+
60
+ if out .shape != x .shape :
61
+ raise TypeError (
62
+ "The shape of input and output arrays are inconsistent."
63
+ f"Expected output shape is { x .shape } , got { out .shape } "
64
+ )
65
+
66
+ if ti ._array_overlap (x , out ):
67
+ raise TypeError ("Input and output arrays have memory overlap" )
68
+
69
+ if (
70
+ dpctl .utils .get_execution_queue ((x .sycl_queue , out .sycl_queue ))
71
+ is None
72
+ ):
73
+ raise TypeError (
74
+ "Input and output allocation queues are not compatible"
75
+ )
76
+
53
77
if order not in ["C" , "F" , "K" , "A" ]:
54
78
order = "K"
55
79
buf_dt , res_dt = _find_buf_dtype (
@@ -59,17 +83,24 @@ def __call__(self, x, order="K"):
59
83
raise RuntimeError
60
84
exec_q = x .sycl_queue
61
85
if buf_dt is None :
62
- if order == "K" :
63
- r = _empty_like_orderK (x , res_dt )
86
+ if out is None :
87
+ if order == "K" :
88
+ out = _empty_like_orderK (x , res_dt )
89
+ else :
90
+ if order == "A" :
91
+ order = "F" if x .flags .f_contiguous else "C"
92
+ out = dpt .empty_like (x , dtype = res_dt , order = order )
64
93
else :
65
- if order == "A" :
66
- order = "F" if x .flags .f_contiguous else "C"
67
- r = dpt .empty_like (x , dtype = res_dt , order = order )
94
+ if res_dt != out .dtype :
95
+ raise TypeError (
96
+ f"Output array of type { res_dt } is needed,"
97
+ f" got { out .dtype } "
98
+ )
68
99
69
- ht , _ = self .unary_fn_ (x , r , sycl_queue = exec_q )
100
+ ht , _ = self .unary_fn_ (x , out , sycl_queue = exec_q )
70
101
ht .wait ()
71
102
72
- return r
103
+ return out
73
104
if order == "K" :
74
105
buf = _empty_like_orderK (x , buf_dt )
75
106
else :
@@ -80,16 +111,22 @@ def __call__(self, x, order="K"):
80
111
ht_copy_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
81
112
src = x , dst = buf , sycl_queue = exec_q
82
113
)
83
- if order == "K" :
84
- r = _empty_like_orderK (buf , res_dt )
114
+ if out is None :
115
+ if order == "K" :
116
+ out = _empty_like_orderK (buf , res_dt )
117
+ else :
118
+ out = dpt .empty_like (buf , dtype = res_dt , order = order )
85
119
else :
86
- r = dpt .empty_like (buf , dtype = res_dt , order = order )
120
+ if buf_dt != out .dtype :
121
+ raise TypeError (
122
+ f"Output array of type { buf_dt } is needed, got { out .dtype } "
123
+ )
87
124
88
- ht , _ = self .unary_fn_ (buf , r , sycl_queue = exec_q , depends = [copy_ev ])
125
+ ht , _ = self .unary_fn_ (buf , out , sycl_queue = exec_q , depends = [copy_ev ])
89
126
ht_copy_ev .wait ()
90
127
ht .wait ()
91
128
92
- return r
129
+ return out
93
130
94
131
95
132
def _get_queue_usm_type (o ):
@@ -281,7 +318,7 @@ def __str__(self):
281
318
def __repr__ (self ):
282
319
return f"<BinaryElementwiseFunc '{ self .name_ } '>"
283
320
284
- def __call__ (self , o1 , o2 , order = "K" ):
321
+ def __call__ (self , o1 , o2 , out = None , order = "K" ):
285
322
if order not in ["K" , "C" , "F" , "A" ]:
286
323
order = "K"
287
324
q1 , o1_usm_type = _get_queue_usm_type (o1 )
@@ -358,6 +395,31 @@ def __call__(self, o1, o2, order="K"):
358
395
"supported types according to the casting rule ''safe''."
359
396
)
360
397
398
+ if out is not None :
399
+ if not isinstance (out , dpt .usm_ndarray ):
400
+ raise TypeError (
401
+ f"output array must be of usm_ndarray type, got { type (out )} "
402
+ )
403
+
404
+ if out .shape != res_shape :
405
+ raise TypeError (
406
+ "The shape of input and output arrays are inconsistent."
407
+ f"Expected output shape is { o1_shape } , got { out .shape } "
408
+ )
409
+
410
+ if ti ._array_overlap (o1 , out ) or ti ._array_overlap (o2 , out ):
411
+ raise TypeError ("Input and output arrays have memory overlap" )
412
+
413
+ if (
414
+ dpctl .utils .get_execution_queue (
415
+ (o1 .sycl_queue , o2 .sycl_queue , out .sycl_queue )
416
+ )
417
+ is None
418
+ ):
419
+ raise TypeError (
420
+ "Input and output allocation queues are not compatible"
421
+ )
422
+
361
423
if isinstance (o1 , dpt .usm_ndarray ):
362
424
src1 = o1
363
425
else :
@@ -368,37 +430,45 @@ def __call__(self, o1, o2, order="K"):
368
430
src2 = dpt .asarray (o2 , dtype = o2_dtype , sycl_queue = exec_q )
369
431
370
432
if buf1_dt is None and buf2_dt is None :
371
- if order == "K" :
372
- r = _empty_like_pair_orderK (
373
- src1 , src2 , res_dt , res_usm_type , exec_q
374
- )
375
- else :
376
- if order == "A" :
377
- order = (
378
- "F"
379
- if all (
380
- arr .flags .f_contiguous
381
- for arr in (
382
- src1 ,
383
- src2 ,
433
+ if out is None :
434
+ if order == "K" :
435
+ out = _empty_like_pair_orderK (
436
+ src1 , src2 , res_dt , res_usm_type , exec_q
437
+ )
438
+ else :
439
+ if order == "A" :
440
+ order = (
441
+ "F"
442
+ if all (
443
+ arr .flags .f_contiguous
444
+ for arr in (
445
+ src1 ,
446
+ src2 ,
447
+ )
384
448
)
449
+ else "C"
385
450
)
386
- else "C"
451
+ out = dpt .empty (
452
+ res_shape ,
453
+ dtype = res_dt ,
454
+ usm_type = res_usm_type ,
455
+ sycl_queue = exec_q ,
456
+ order = order ,
387
457
)
388
- r = dpt . empty (
389
- res_shape ,
390
- dtype = res_dt ,
391
- usm_type = res_usm_type ,
392
- sycl_queue = exec_q ,
393
- order = order ,
394
- )
458
+ else :
459
+ if res_dt != out . dtype :
460
+ raise TypeError (
461
+ f"Output array of type { res_dt } is needed,"
462
+ f"got { out . dtype } "
463
+ )
464
+
395
465
src1 = dpt .broadcast_to (src1 , res_shape )
396
466
src2 = dpt .broadcast_to (src2 , res_shape )
397
467
ht_ , _ = self .binary_fn_ (
398
- src1 = src1 , src2 = src2 , dst = r , sycl_queue = exec_q
468
+ src1 = src1 , src2 = src2 , dst = out , sycl_queue = exec_q
399
469
)
400
470
ht_ .wait ()
401
- return r
471
+ return out
402
472
elif buf1_dt is None :
403
473
if order == "K" :
404
474
buf2 = _empty_like_orderK (src2 , buf2_dt )
@@ -409,30 +479,38 @@ def __call__(self, o1, o2, order="K"):
409
479
ht_copy_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
410
480
src = src2 , dst = buf2 , sycl_queue = exec_q
411
481
)
412
- if order == "K" :
413
- r = _empty_like_pair_orderK (
414
- src1 , buf2 , res_dt , res_usm_type , exec_q
415
- )
482
+ if out is None :
483
+ if order == "K" :
484
+ out = _empty_like_pair_orderK (
485
+ src1 , buf2 , res_dt , res_usm_type , exec_q
486
+ )
487
+ else :
488
+ out = dpt .empty (
489
+ res_shape ,
490
+ dtype = res_dt ,
491
+ usm_type = res_usm_type ,
492
+ sycl_queue = exec_q ,
493
+ order = order ,
494
+ )
416
495
else :
417
- r = dpt .empty (
418
- res_shape ,
419
- dtype = res_dt ,
420
- usm_type = res_usm_type ,
421
- sycl_queue = exec_q ,
422
- order = order ,
423
- )
496
+ if res_dt != out .dtype :
497
+ raise TypeError (
498
+ f"Output array of type { res_dt } is needed,"
499
+ f"got { out .dtype } "
500
+ )
501
+
424
502
src1 = dpt .broadcast_to (src1 , res_shape )
425
503
buf2 = dpt .broadcast_to (buf2 , res_shape )
426
504
ht_ , _ = self .binary_fn_ (
427
505
src1 = src1 ,
428
506
src2 = buf2 ,
429
- dst = r ,
507
+ dst = out ,
430
508
sycl_queue = exec_q ,
431
509
depends = [copy_ev ],
432
510
)
433
511
ht_copy_ev .wait ()
434
512
ht_ .wait ()
435
- return r
513
+ return out
436
514
elif buf2_dt is None :
437
515
if order == "K" :
438
516
buf1 = _empty_like_orderK (src1 , buf1_dt )
@@ -443,30 +521,38 @@ def __call__(self, o1, o2, order="K"):
443
521
ht_copy_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
444
522
src = src1 , dst = buf1 , sycl_queue = exec_q
445
523
)
446
- if order == "K" :
447
- r = _empty_like_pair_orderK (
448
- buf1 , src2 , res_dt , res_usm_type , exec_q
449
- )
524
+ if out is None :
525
+ if order == "K" :
526
+ out = _empty_like_pair_orderK (
527
+ buf1 , src2 , res_dt , res_usm_type , exec_q
528
+ )
529
+ else :
530
+ out = dpt .empty (
531
+ res_shape ,
532
+ dtype = res_dt ,
533
+ usm_type = res_usm_type ,
534
+ sycl_queue = exec_q ,
535
+ order = order ,
536
+ )
450
537
else :
451
- r = dpt .empty (
452
- res_shape ,
453
- dtype = res_dt ,
454
- usm_type = res_usm_type ,
455
- sycl_queue = exec_q ,
456
- order = order ,
457
- )
538
+ if res_dt != out .dtype :
539
+ raise TypeError (
540
+ f"Output array of type { res_dt } is needed,"
541
+ f"got { out .dtype } "
542
+ )
543
+
458
544
buf1 = dpt .broadcast_to (buf1 , res_shape )
459
545
src2 = dpt .broadcast_to (src2 , res_shape )
460
546
ht_ , _ = self .binary_fn_ (
461
547
src1 = buf1 ,
462
548
src2 = src2 ,
463
- dst = r ,
549
+ dst = out ,
464
550
sycl_queue = exec_q ,
465
551
depends = [copy_ev ],
466
552
)
467
553
ht_copy_ev .wait ()
468
554
ht_ .wait ()
469
- return r
555
+ return out
470
556
471
557
if order in ["K" , "A" ]:
472
558
if src1 .flags .f_contiguous and src2 .flags .f_contiguous :
@@ -489,26 +575,33 @@ def __call__(self, o1, o2, order="K"):
489
575
ht_copy2_ev , copy2_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
490
576
src = src2 , dst = buf2 , sycl_queue = exec_q
491
577
)
492
- if order == "K" :
493
- r = _empty_like_pair_orderK (
494
- buf1 , buf2 , res_dt , res_usm_type , exec_q
495
- )
578
+ if out is None :
579
+ if order == "K" :
580
+ out = _empty_like_pair_orderK (
581
+ buf1 , buf2 , res_dt , res_usm_type , exec_q
582
+ )
583
+ else :
584
+ out = dpt .empty (
585
+ res_shape ,
586
+ dtype = res_dt ,
587
+ usm_type = res_usm_type ,
588
+ sycl_queue = exec_q ,
589
+ order = order ,
590
+ )
496
591
else :
497
- r = dpt .empty (
498
- res_shape ,
499
- dtype = res_dt ,
500
- usm_type = res_usm_type ,
501
- sycl_queue = exec_q ,
502
- order = order ,
503
- )
592
+ if res_dt != out .dtype :
593
+ raise TypeError (
594
+ f"Output array of type { res_dt } is needed, got { out .dtype } "
595
+ )
596
+
504
597
buf1 = dpt .broadcast_to (buf1 , res_shape )
505
598
buf2 = dpt .broadcast_to (buf2 , res_shape )
506
599
ht_ , _ = self .binary_fn_ (
507
600
src1 = buf1 ,
508
601
src2 = buf2 ,
509
- dst = r ,
602
+ dst = out ,
510
603
sycl_queue = exec_q ,
511
604
depends = [copy1_ev , copy2_ev ],
512
605
)
513
606
dpctl .SyclEvent .wait_for ([ht_copy1_ev , ht_copy2_ev , ht_ ])
514
- return r
607
+ return out
0 commit comments