@@ -380,6 +380,67 @@ def _empty_like_pair_orderK(X1, X2, dt, res_shape, usm_type, dev):
380
380
return dpt .permute_dims (R , inv_perm )
381
381
382
382
383
+ def _empty_like_triple_orderK (X1 , X2 , X3 , dt , res_shape , usm_type , dev ):
384
+ if not isinstance (X1 , dpt .usm_ndarray ):
385
+ raise TypeError (f"Expected usm_ndarray, got { type (X1 )} " )
386
+ if not isinstance (X2 , dpt .usm_ndarray ):
387
+ raise TypeError (f"Expected usm_ndarray, got { type (X2 )} " )
388
+ if not isinstance (X3 , dpt .usm_ndarray ):
389
+ raise TypeError (f"Expected usm_ndarray, got { type (X3 )} " )
390
+ nd1 = X1 .ndim
391
+ nd2 = X2 .ndim
392
+ nd3 = X3 .ndim
393
+ if nd1 > nd2 and nd1 > nd3 and X1 .shape == res_shape :
394
+ return _empty_like_orderK (X1 , dt , usm_type , dev )
395
+ elif nd1 < nd2 and nd3 < nd2 and X2 .shape == res_shape :
396
+ return _empty_like_orderK (X2 , dt , usm_type , dev )
397
+ elif nd1 < nd3 and nd2 < nd3 and X3 .shape == res_shape :
398
+ return _empty_like_orderK (X3 , dt , usm_type , dev )
399
+ fl1 = X1 .flags
400
+ fl2 = X2 .flags
401
+ fl3 = X3 .flags
402
+ if fl1 ["C" ] or fl2 ["C" ] or fl3 ["C" ]:
403
+ return dpt .empty (
404
+ res_shape , dtype = dt , usm_type = usm_type , device = dev , order = "C"
405
+ )
406
+ if fl1 ["F" ] and fl2 ["F" ] and fl3 ["F" ]:
407
+ return dpt .empty (
408
+ res_shape , dtype = dt , usm_type = usm_type , device = dev , order = "F"
409
+ )
410
+ st1 = list (X1 .strides )
411
+ st2 = list (X2 .strides )
412
+ st3 = list (X3 .strides )
413
+ max_ndim = max (nd1 , nd2 , nd3 )
414
+ st1 += [0 ] * (max_ndim - len (st1 ))
415
+ st2 += [0 ] * (max_ndim - len (st2 ))
416
+ st3 += [0 ] * (max_ndim - len (st3 ))
417
+ perm = sorted (
418
+ range (max_ndim ),
419
+ key = lambda d : (
420
+ builtins .abs (st1 [d ]),
421
+ builtins .abs (st2 [d ]),
422
+ builtins .abs (st3 [d ]),
423
+ ),
424
+ reverse = True ,
425
+ )
426
+ inv_perm = sorted (range (max_ndim ), key = lambda i : perm [i ])
427
+ st1_sorted = [st1 [i ] for i in perm ]
428
+ st2_sorted = [st2 [i ] for i in perm ]
429
+ st3_sorted = [st3 [i ] for i in perm ]
430
+ sh = res_shape
431
+ sh_sorted = tuple (sh [i ] for i in perm )
432
+ R = dpt .empty (sh_sorted , dtype = dt , usm_type = usm_type , device = dev , order = "C" )
433
+ if max (min (st1_sorted ), min (st2_sorted ), min (st3_sorted )) < 0 :
434
+ sl = tuple (
435
+ slice (None , None , - 1 )
436
+ if (st1_sorted [i ] < 0 and st2_sorted [i ] < 0 and st3_sorted [i ] < 0 )
437
+ else slice (None , None , None )
438
+ for i in range (nd1 )
439
+ )
440
+ R = R [sl ]
441
+ return dpt .permute_dims (R , inv_perm )
442
+
443
+
383
444
def copy (usm_ary , order = "K" ):
384
445
"""copy(ary, order="K")
385
446
0 commit comments