31
31
_empty_like_pair_orderK ,
32
32
_find_buf_dtype ,
33
33
_find_buf_dtype2 ,
34
+ _find_inplace_dtype ,
34
35
_to_device_supported_dtype ,
35
36
)
36
37
@@ -331,11 +332,19 @@ class BinaryElementwiseFunc:
331
332
Class that implements binary element-wise functions.
332
333
"""
333
334
334
- def __init__ (self , name , result_type_resolver_fn , binary_dp_impl_fn , docs ):
335
+ def __init__ (
336
+ self ,
337
+ name ,
338
+ result_type_resolver_fn ,
339
+ binary_dp_impl_fn ,
340
+ docs ,
341
+ binary_inplace_fn = None ,
342
+ ):
335
343
self .__name__ = "BinaryElementwiseFunc"
336
344
self .name_ = name
337
345
self .result_type_resolver_fn_ = result_type_resolver_fn
338
346
self .binary_fn_ = binary_dp_impl_fn
347
+ self .binary_inplace_fn_ = binary_inplace_fn
339
348
self .__doc__ = docs
340
349
341
350
def __str__ (self ):
@@ -345,6 +354,13 @@ def __repr__(self):
345
354
return f"<BinaryElementwiseFunc '{ self .name_ } '>"
346
355
347
356
def __call__ (self , o1 , o2 , out = None , order = "K" ):
357
+ # FIXME: replace with check against base array
358
+ # when views can be identified
359
+ if o1 is out :
360
+ return self ._inplace (o1 , o2 )
361
+ elif o2 is out :
362
+ return self ._inplace (o2 , o1 )
363
+
348
364
if order not in ["K" , "C" , "F" , "A" ]:
349
365
order = "K"
350
366
q1 , o1_usm_type = _get_queue_usm_type (o1 )
@@ -388,6 +404,7 @@ def __call__(self, o1, o2, out=None, order="K"):
388
404
raise TypeError (
389
405
"Shape of arguments can not be inferred. "
390
406
"Arguments are expected to be "
407
+ "lists, tuples, or both"
391
408
)
392
409
try :
393
410
res_shape = _broadcast_shape_impl (
@@ -415,7 +432,7 @@ def __call__(self, o1, o2, out=None, order="K"):
415
432
416
433
if res_dt is None :
417
434
raise TypeError (
418
- "function 'add ' does not support input types "
435
+ f "function '{ self . name_ } ' does not support input types "
419
436
f"({ o1_dtype } , { o2_dtype } ), "
420
437
"and the inputs could not be safely coerced to any "
421
438
"supported types according to the casting rule ''safe''."
@@ -631,3 +648,116 @@ def __call__(self, o1, o2, out=None, order="K"):
631
648
)
632
649
dpctl .SyclEvent .wait_for ([ht_copy1_ev , ht_copy2_ev , ht_ ])
633
650
return out
651
+
652
+ def _inplace (self , lhs , val ):
653
+ if self .binary_inplace_fn_ is None :
654
+ raise ValueError (
655
+ f"In-place operation not supported for ufunc '{ self .name_ } '"
656
+ )
657
+ if not isinstance (lhs , dpt .usm_ndarray ):
658
+ raise TypeError (
659
+ f"Expected dpctl.tensor.usm_ndarray, got { type (lhs )} "
660
+ )
661
+ q1 , lhs_usm_type = _get_queue_usm_type (lhs )
662
+ q2 , val_usm_type = _get_queue_usm_type (val )
663
+ if q2 is None :
664
+ exec_q = q1
665
+ usm_type = lhs_usm_type
666
+ else :
667
+ exec_q = dpctl .utils .get_execution_queue ((q1 , q2 ))
668
+ if exec_q is None :
669
+ raise ExecutionPlacementError (
670
+ "Execution placement can not be unambiguously inferred "
671
+ "from input arguments."
672
+ )
673
+ usm_type = dpctl .utils .get_coerced_usm_type (
674
+ (
675
+ lhs_usm_type ,
676
+ val_usm_type ,
677
+ )
678
+ )
679
+ dpctl .utils .validate_usm_type (usm_type , allow_none = False )
680
+ lhs_shape = _get_shape (lhs )
681
+ val_shape = _get_shape (val )
682
+ if not all (
683
+ isinstance (s , (tuple , list ))
684
+ for s in (
685
+ lhs_shape ,
686
+ val_shape ,
687
+ )
688
+ ):
689
+ raise TypeError (
690
+ "Shape of arguments can not be inferred. "
691
+ "Arguments are expected to be "
692
+ "lists, tuples, or both"
693
+ )
694
+ try :
695
+ res_shape = _broadcast_shape_impl (
696
+ [
697
+ lhs_shape ,
698
+ val_shape ,
699
+ ]
700
+ )
701
+ except ValueError :
702
+ raise ValueError (
703
+ "operands could not be broadcast together with shapes "
704
+ f"{ lhs_shape } and { val_shape } "
705
+ )
706
+ if res_shape != lhs_shape :
707
+ raise ValueError (
708
+ f"output shape { lhs_shape } does not match "
709
+ f"broadcast shape { res_shape } "
710
+ )
711
+ sycl_dev = exec_q .sycl_device
712
+ lhs_dtype = lhs .dtype
713
+ val_dtype = _get_dtype (val , sycl_dev )
714
+ if not _validate_dtype (val_dtype ):
715
+ raise ValueError ("Input operand of unsupported type" )
716
+
717
+ lhs_dtype , val_dtype = _resolve_weak_types (
718
+ lhs_dtype , val_dtype , sycl_dev
719
+ )
720
+
721
+ buf_dt = _find_inplace_dtype (
722
+ lhs_dtype , val_dtype , self .result_type_resolver_fn_ , sycl_dev
723
+ )
724
+
725
+ if buf_dt is None :
726
+ raise TypeError (
727
+ f"In-place '{ self .name_ } ' does not support input types "
728
+ f"({ lhs_dtype } , { val_dtype } ), "
729
+ "and the inputs could not be safely coerced to any "
730
+ "supported types according to the casting rule ''safe''."
731
+ )
732
+
733
+ if isinstance (val , dpt .usm_ndarray ):
734
+ rhs = val
735
+ overlap = ti ._array_overlap (lhs , rhs )
736
+ else :
737
+ rhs = dpt .asarray (val , dtype = val_dtype , sycl_queue = exec_q )
738
+ overlap = False
739
+
740
+ if buf_dt == val_dtype and overlap is False :
741
+ rhs = dpt .broadcast_to (rhs , res_shape )
742
+ ht_ , _ = self .binary_inplace_fn_ (
743
+ lhs = lhs , rhs = rhs , sycl_queue = exec_q
744
+ )
745
+ ht_ .wait ()
746
+
747
+ else :
748
+ buf = dpt .empty_like (rhs , dtype = buf_dt )
749
+ ht_copy_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
750
+ src = rhs , dst = buf , sycl_queue = exec_q
751
+ )
752
+
753
+ buf = dpt .broadcast_to (buf , res_shape )
754
+ ht_ , _ = self .binary_inplace_fn_ (
755
+ lhs = lhs ,
756
+ rhs = buf ,
757
+ sycl_queue = exec_q ,
758
+ depends = [copy_ev ],
759
+ )
760
+ ht_copy_ev .wait ()
761
+ ht_ .wait ()
762
+
763
+ return lhs
0 commit comments