@@ -47,9 +47,33 @@ def __init__(self, name, result_type_resolver_fn, unary_dp_impl_fn, docs):
47
47
self .unary_fn_ = unary_dp_impl_fn
48
48
self .__doc__ = docs
49
49
50
- def __call__ (self , x , order = "K" ):
50
+ def __call__ (self , x , out = None , order = "K" ):
51
51
if not isinstance (x , dpt .usm_ndarray ):
52
52
raise TypeError (f"Expected dpctl.tensor.usm_ndarray, got { type (x )} " )
53
+
54
+ if out is not None :
55
+ if not isinstance (out , dpt .usm_ndarray ):
56
+ raise TypeError (
57
+ f"output array must be of usm_ndarray type, got { type (out )} "
58
+ )
59
+
60
+ if out .shape != x .shape :
61
+ raise TypeError (
62
+ "The shape of input and output arrays are inconsistent."
63
+ f"Expected output shape is { x .shape } , got { out .shape } "
64
+ )
65
+
66
+ if ti ._array_overlap (x , out ):
67
+ raise TypeError ("Input and output arrays have memory overlap" )
68
+
69
+ if (
70
+ dpctl .utils .get_execution_queue ((x .sycl_queue , out .sycl_queue ))
71
+ is None
72
+ ):
73
+ raise TypeError (
74
+ "Input and output allocation queues are not compatible"
75
+ )
76
+
53
77
if order not in ["C" , "F" , "K" , "A" ]:
54
78
order = "K"
55
79
buf_dt , res_dt = _find_buf_dtype (
@@ -59,17 +83,24 @@ def __call__(self, x, order="K"):
59
83
raise RuntimeError
60
84
exec_q = x .sycl_queue
61
85
if buf_dt is None :
62
- if order == "K" :
63
- r = _empty_like_orderK (x , res_dt )
86
+ if out is None :
87
+ if order == "K" :
88
+ out = _empty_like_orderK (x , res_dt )
89
+ else :
90
+ if order == "A" :
91
+ order = "F" if x .flags .f_contiguous else "C"
92
+ out = dpt .empty_like (x , dtype = res_dt , order = order )
64
93
else :
65
- if order == "A" :
66
- order = "F" if x .flags .f_contiguous else "C"
67
- r = dpt .empty_like (x , dtype = res_dt , order = order )
94
+ if res_dt != out .dtype :
95
+ raise TypeError (
96
+ f"Expected output array of type { res_dt } is supported"
97
+ f", got { out .dtype } "
98
+ )
68
99
69
- ht , _ = self .unary_fn_ (x , r , sycl_queue = exec_q )
100
+ ht , _ = self .unary_fn_ (x , out , sycl_queue = exec_q )
70
101
ht .wait ()
71
102
72
- return r
103
+ return out
73
104
if order == "K" :
74
105
buf = _empty_like_orderK (x , buf_dt )
75
106
else :
@@ -80,15 +111,22 @@ def __call__(self, x, order="K"):
80
111
ht_copy_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
81
112
src = x , dst = buf , sycl_queue = exec_q
82
113
)
83
- if order == "K" :
84
- r = _empty_like_orderK (buf , res_dt )
114
+ if out is None :
115
+ if order == "K" :
116
+ out = _empty_like_orderK (buf , res_dt )
117
+ else :
118
+ out = dpt .empty_like (buf , dtype = res_dt , order = order )
85
119
else :
86
- r = dpt .empty_like (buf , dtype = res_dt , order = order )
120
+ if buf_dt != out .dtype :
121
+ raise TypeError (
122
+ f"Expected output array of type { buf_dt } is supported,"
123
+ f"got { out .dtype } "
124
+ )
87
125
88
- ht , _ = self .unary_fn_ (buf , r , sycl_queue = exec_q , depends = [copy_ev ])
126
+ ht , _ = self .unary_fn_ (buf , out , sycl_queue = exec_q , depends = [copy_ev ])
89
127
ht .wait ()
90
128
91
- return r
129
+ return out
92
130
93
131
94
132
def _get_queue_usm_type (o ):
@@ -280,7 +318,7 @@ def __str__(self):
280
318
def __repr__ (self ):
281
319
return f"<BinaryElementwiseFunc '{ self .name_ } '>"
282
320
283
- def __call__ (self , o1 , o2 , order = "K" ):
321
+ def __call__ (self , o1 , o2 , out = None , order = "K" ):
284
322
if order not in ["K" , "C" , "F" , "A" ]:
285
323
order = "K"
286
324
q1 , o1_usm_type = _get_queue_usm_type (o1 )
@@ -357,6 +395,31 @@ def __call__(self, o1, o2, order="K"):
357
395
"supported types according to the casting rule ''safe''."
358
396
)
359
397
398
+ if out is not None :
399
+ if not isinstance (out , dpt .usm_ndarray ):
400
+ raise TypeError (
401
+ f"output array must be of usm_ndarray type, got { type (out )} "
402
+ )
403
+
404
+ if out .shape != o1_shape or out .shape != o2_shape :
405
+ raise TypeError (
406
+ "The shape of input and output arrays are inconsistent."
407
+ f"Expected output shape is { o1_shape } , got { out .shape } "
408
+ )
409
+
410
+ if ti ._array_overlap (o1 , out ) or ti ._array_overlap (o2 , out ):
411
+ raise TypeError ("Input and output arrays have memory overlap" )
412
+
413
+ if (
414
+ dpctl .utils .get_execution_queue (
415
+ (o1 .sycl_queue , o2 .sycl_queue , out .sycl_queue )
416
+ )
417
+ is None
418
+ ):
419
+ raise TypeError (
420
+ "Input and output allocation queues are not compatible"
421
+ )
422
+
360
423
if isinstance (o1 , dpt .usm_ndarray ):
361
424
src1 = o1
362
425
else :
@@ -367,37 +430,45 @@ def __call__(self, o1, o2, order="K"):
367
430
src2 = dpt .asarray (o2 , dtype = o2_dtype , sycl_queue = exec_q )
368
431
369
432
if buf1_dt is None and buf2_dt is None :
370
- if order == "K" :
371
- r = _empty_like_pair_orderK (
372
- src1 , src2 , res_dt , res_usm_type , exec_q
373
- )
374
- else :
375
- if order == "A" :
376
- order = (
377
- "F"
378
- if all (
379
- arr .flags .f_contiguous
380
- for arr in (
381
- src1 ,
382
- src2 ,
433
+ if out is None :
434
+ if order == "K" :
435
+ out = _empty_like_pair_orderK (
436
+ src1 , src2 , res_dt , res_usm_type , exec_q
437
+ )
438
+ else :
439
+ if order == "A" :
440
+ order = (
441
+ "F"
442
+ if all (
443
+ arr .flags .f_contiguous
444
+ for arr in (
445
+ src1 ,
446
+ src2 ,
447
+ )
383
448
)
449
+ else "C"
384
450
)
385
- else "C"
451
+ out = dpt .empty (
452
+ res_shape ,
453
+ dtype = res_dt ,
454
+ usm_type = res_usm_type ,
455
+ sycl_queue = exec_q ,
456
+ order = order ,
386
457
)
387
- r = dpt . empty (
388
- res_shape ,
389
- dtype = res_dt ,
390
- usm_type = res_usm_type ,
391
- sycl_queue = exec_q ,
392
- order = order ,
393
- )
458
+ else :
459
+ if res_dt != out . dtype :
460
+ raise TypeError (
461
+ f"Output array of type { res_dt } is needed,"
462
+ f"got { out . dtype } "
463
+ )
464
+
394
465
src1 = dpt .broadcast_to (src1 , res_shape )
395
466
src2 = dpt .broadcast_to (src2 , res_shape )
396
467
ht_ , _ = self .binary_fn_ (
397
- src1 = src1 , src2 = src2 , dst = r , sycl_queue = exec_q
468
+ src1 = src1 , src2 = src2 , dst = out , sycl_queue = exec_q
398
469
)
399
470
ht_ .wait ()
400
- return r
471
+ return out
401
472
elif buf1_dt is None :
402
473
if order == "K" :
403
474
buf2 = _empty_like_orderK (src2 , buf2_dt )
@@ -408,30 +479,38 @@ def __call__(self, o1, o2, order="K"):
408
479
ht_copy_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
409
480
src = src2 , dst = buf2 , sycl_queue = exec_q
410
481
)
411
- if order == "K" :
412
- r = _empty_like_pair_orderK (
413
- src1 , buf2 , res_dt , res_usm_type , exec_q
414
- )
482
+ if out is None :
483
+ if order == "K" :
484
+ out = _empty_like_pair_orderK (
485
+ src1 , buf2 , res_dt , res_usm_type , exec_q
486
+ )
487
+ else :
488
+ out = dpt .empty (
489
+ res_shape ,
490
+ dtype = res_dt ,
491
+ usm_type = res_usm_type ,
492
+ sycl_queue = exec_q ,
493
+ order = order ,
494
+ )
415
495
else :
416
- r = dpt .empty (
417
- res_shape ,
418
- dtype = res_dt ,
419
- usm_type = res_usm_type ,
420
- sycl_queue = exec_q ,
421
- order = order ,
422
- )
496
+ if res_dt != out .dtype :
497
+ raise TypeError (
498
+ f"Output array of type { res_dt } is needed,"
499
+ f"got { out .dtype } "
500
+ )
501
+
423
502
src1 = dpt .broadcast_to (src1 , res_shape )
424
503
buf2 = dpt .broadcast_to (buf2 , res_shape )
425
504
ht_ , _ = self .binary_fn_ (
426
505
src1 = src1 ,
427
506
src2 = buf2 ,
428
- dst = r ,
507
+ dst = out ,
429
508
sycl_queue = exec_q ,
430
509
depends = [copy_ev ],
431
510
)
432
511
ht_copy_ev .wait ()
433
512
ht_ .wait ()
434
- return r
513
+ return out
435
514
elif buf2_dt is None :
436
515
if order == "K" :
437
516
buf1 = _empty_like_orderK (src1 , buf1_dt )
@@ -442,30 +521,38 @@ def __call__(self, o1, o2, order="K"):
442
521
ht_copy_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
443
522
src = src1 , dst = buf1 , sycl_queue = exec_q
444
523
)
445
- if order == "K" :
446
- r = _empty_like_pair_orderK (
447
- buf1 , src2 , res_dt , res_usm_type , exec_q
448
- )
524
+ if out is None :
525
+ if order == "K" :
526
+ out = _empty_like_pair_orderK (
527
+ buf1 , src2 , res_dt , res_usm_type , exec_q
528
+ )
529
+ else :
530
+ out = dpt .empty (
531
+ res_shape ,
532
+ dtype = res_dt ,
533
+ usm_type = res_usm_type ,
534
+ sycl_queue = exec_q ,
535
+ order = order ,
536
+ )
449
537
else :
450
- r = dpt .empty (
451
- res_shape ,
452
- dtype = res_dt ,
453
- usm_type = res_usm_type ,
454
- sycl_queue = exec_q ,
455
- order = order ,
456
- )
538
+ if res_dt != out .dtype :
539
+ raise TypeError (
540
+ f"Output array of type { res_dt } is needed,"
541
+ f"got { out .dtype } "
542
+ )
543
+
457
544
buf1 = dpt .broadcast_to (buf1 , res_shape )
458
545
src2 = dpt .broadcast_to (src2 , res_shape )
459
546
ht_ , _ = self .binary_fn_ (
460
547
src1 = buf1 ,
461
548
src2 = src2 ,
462
- dst = r ,
549
+ dst = out ,
463
550
sycl_queue = exec_q ,
464
551
depends = [copy_ev ],
465
552
)
466
553
ht_copy_ev .wait ()
467
554
ht_ .wait ()
468
- return r
555
+ return out
469
556
470
557
if order in ["K" , "A" ]:
471
558
if src1 .flags .f_contiguous and src2 .flags .f_contiguous :
@@ -488,26 +575,33 @@ def __call__(self, o1, o2, order="K"):
488
575
ht_copy2_ev , copy2_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
489
576
src = src2 , dst = buf2 , sycl_queue = exec_q
490
577
)
491
- if order == "K" :
492
- r = _empty_like_pair_orderK (
493
- buf1 , buf2 , res_dt , res_usm_type , exec_q
494
- )
578
+ if out is None :
579
+ if order == "K" :
580
+ out = _empty_like_pair_orderK (
581
+ buf1 , buf2 , res_dt , res_usm_type , exec_q
582
+ )
583
+ else :
584
+ out = dpt .empty (
585
+ res_shape ,
586
+ dtype = res_dt ,
587
+ usm_type = res_usm_type ,
588
+ sycl_queue = exec_q ,
589
+ order = order ,
590
+ )
495
591
else :
496
- r = dpt .empty (
497
- res_shape ,
498
- dtype = res_dt ,
499
- usm_type = res_usm_type ,
500
- sycl_queue = exec_q ,
501
- order = order ,
502
- )
592
+ if res_dt != out .dtype :
593
+ raise TypeError (
594
+ f"Output array of type { res_dt } is needed, got { out .dtype } "
595
+ )
596
+
503
597
buf1 = dpt .broadcast_to (buf1 , res_shape )
504
598
buf2 = dpt .broadcast_to (buf2 , res_shape )
505
599
ht_ , _ = self .binary_fn_ (
506
600
src1 = buf1 ,
507
601
src2 = buf2 ,
508
- dst = r ,
602
+ dst = out ,
509
603
sycl_queue = exec_q ,
510
604
depends = [copy1_ev , copy2_ev ],
511
605
)
512
606
dpctl .SyclEvent .wait_for ([ht_copy1_ev , ht_copy2_ev , ht_ ])
513
- return r
607
+ return out
0 commit comments