13
13
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
14
# See the License for the specific language governing permissions and
15
15
# limitations under the License.
16
+ import builtins
16
17
import operator
17
18
18
19
import numpy as np
@@ -245,6 +246,23 @@ def _broadcast_shapes(sh1, sh2):
245
246
).shape
246
247
247
248
249
+ def _broadcast_strides (X_shape , X_strides , res_ndim ):
250
+ """
251
+ Broadcasts strides to match the given dimensions;
252
+ returns tuple type strides.
253
+ """
254
+ out_strides = [0 ] * res_ndim
255
+ X_shape_len = len (X_shape )
256
+ str_dim = - X_shape_len
257
+ for i in range (X_shape_len ):
258
+ shape_value = X_shape [i ]
259
+ if not shape_value == 1 :
260
+ out_strides [str_dim ] = X_strides [i ]
261
+ str_dim += 1
262
+
263
+ return tuple (out_strides )
264
+
265
+
248
266
def _copy_from_usm_ndarray_to_usm_ndarray (dst , src ):
249
267
if any (
250
268
not isinstance (arg , dpt .usm_ndarray )
@@ -267,7 +285,7 @@ def _copy_from_usm_ndarray_to_usm_ndarray(dst, src):
267
285
except ValueError as exc :
268
286
raise ValueError ("Shapes of two arrays are not compatible" ) from exc
269
287
270
- if dst .size < src .size :
288
+ if dst .size < src .size and dst . size < np . prod ( common_shape ) :
271
289
raise ValueError ("Destination is smaller " )
272
290
273
291
if len (common_shape ) > dst .ndim :
@@ -278,17 +296,127 @@ def _copy_from_usm_ndarray_to_usm_ndarray(dst, src):
278
296
common_shape = common_shape [ones_count :]
279
297
280
298
if src .ndim < len (common_shape ):
281
- new_src_strides = (0 ,) * (len (common_shape ) - src .ndim ) + src .strides
299
+ new_src_strides = _broadcast_strides (
300
+ src .shape , src .strides , len (common_shape )
301
+ )
302
+ src_same_shape = dpt .usm_ndarray (
303
+ common_shape , dtype = src .dtype , buffer = src , strides = new_src_strides
304
+ )
305
+ elif src .ndim == len (common_shape ):
306
+ new_src_strides = _broadcast_strides (
307
+ src .shape , src .strides , len (common_shape )
308
+ )
282
309
src_same_shape = dpt .usm_ndarray (
283
310
common_shape , dtype = src .dtype , buffer = src , strides = new_src_strides
284
311
)
285
312
else :
286
- src_same_shape = src
287
- src_same_shape .shape = common_shape
313
+ # since broadcasting succeeded, src.ndim is greater because of
314
+ # leading sequence of ones, so we trim it
315
+ n = len (common_shape )
316
+ new_src_strides = _broadcast_strides (
317
+ src .shape [- n :], src .strides [- n :], n
318
+ )
319
+ src_same_shape = dpt .usm_ndarray (
320
+ common_shape ,
321
+ dtype = src .dtype ,
322
+ buffer = src .usm_data ,
323
+ strides = new_src_strides ,
324
+ offset = src ._element_offset ,
325
+ )
288
326
289
327
_copy_same_shape (dst , src_same_shape )
290
328
291
329
330
+ def _empty_like_orderK (X , dt , usm_type = None , dev = None ):
331
+ """Returns empty array like `x`, using order='K'
332
+
333
+ For an array `x` that was obtained by permutation of a contiguous
334
+ array the returned array will have the same shape and the same
335
+ strides as `x`.
336
+ """
337
+ if not isinstance (X , dpt .usm_ndarray ):
338
+ raise TypeError (f"Expected usm_ndarray, got { type (X )} " )
339
+ if usm_type is None :
340
+ usm_type = X .usm_type
341
+ if dev is None :
342
+ dev = X .device
343
+ fl = X .flags
344
+ if fl ["C" ] or X .size <= 1 :
345
+ return dpt .empty_like (
346
+ X , dtype = dt , usm_type = usm_type , device = dev , order = "C"
347
+ )
348
+ elif fl ["F" ]:
349
+ return dpt .empty_like (
350
+ X , dtype = dt , usm_type = usm_type , device = dev , order = "F"
351
+ )
352
+ st = list (X .strides )
353
+ perm = sorted (
354
+ range (X .ndim ), key = lambda d : builtins .abs (st [d ]), reverse = True
355
+ )
356
+ inv_perm = sorted (range (X .ndim ), key = lambda i : perm [i ])
357
+ st_sorted = [st [i ] for i in perm ]
358
+ sh = X .shape
359
+ sh_sorted = tuple (sh [i ] for i in perm )
360
+ R = dpt .empty (sh_sorted , dtype = dt , usm_type = usm_type , device = dev , order = "C" )
361
+ if min (st_sorted ) < 0 :
362
+ sl = tuple (
363
+ slice (None , None , - 1 )
364
+ if st_sorted [i ] < 0
365
+ else slice (None , None , None )
366
+ for i in range (X .ndim )
367
+ )
368
+ R = R [sl ]
369
+ return dpt .permute_dims (R , inv_perm )
370
+
371
+
372
+ def _empty_like_pair_orderK (X1 , X2 , dt , res_shape , usm_type , dev ):
373
+ if not isinstance (X1 , dpt .usm_ndarray ):
374
+ raise TypeError (f"Expected usm_ndarray, got { type (X1 )} " )
375
+ if not isinstance (X2 , dpt .usm_ndarray ):
376
+ raise TypeError (f"Expected usm_ndarray, got { type (X2 )} " )
377
+ nd1 = X1 .ndim
378
+ nd2 = X2 .ndim
379
+ if nd1 > nd2 and X1 .shape == res_shape :
380
+ return _empty_like_orderK (X1 , dt , usm_type , dev )
381
+ elif nd1 < nd2 and X2 .shape == res_shape :
382
+ return _empty_like_orderK (X2 , dt , usm_type , dev )
383
+ fl1 = X1 .flags
384
+ fl2 = X2 .flags
385
+ if fl1 ["C" ] or fl2 ["C" ]:
386
+ return dpt .empty (
387
+ res_shape , dtype = dt , usm_type = usm_type , device = dev , order = "C"
388
+ )
389
+ if fl1 ["F" ] and fl2 ["F" ]:
390
+ return dpt .empty (
391
+ res_shape , dtype = dt , usm_type = usm_type , device = dev , order = "F"
392
+ )
393
+ st1 = list (X1 .strides )
394
+ st2 = list (X2 .strides )
395
+ max_ndim = max (nd1 , nd2 )
396
+ st1 += [0 ] * (max_ndim - len (st1 ))
397
+ st2 += [0 ] * (max_ndim - len (st2 ))
398
+ perm = sorted (
399
+ range (max_ndim ),
400
+ key = lambda d : (builtins .abs (st1 [d ]), builtins .abs (st2 [d ])),
401
+ reverse = True ,
402
+ )
403
+ inv_perm = sorted (range (max_ndim ), key = lambda i : perm [i ])
404
+ st1_sorted = [st1 [i ] for i in perm ]
405
+ st2_sorted = [st2 [i ] for i in perm ]
406
+ sh = res_shape
407
+ sh_sorted = tuple (sh [i ] for i in perm )
408
+ R = dpt .empty (sh_sorted , dtype = dt , usm_type = usm_type , device = dev , order = "C" )
409
+ if max (min (st1_sorted ), min (st2_sorted )) < 0 :
410
+ sl = tuple (
411
+ slice (None , None , - 1 )
412
+ if (st1_sorted [i ] < 0 and st2_sorted [i ] < 0 )
413
+ else slice (None , None , None )
414
+ for i in range (nd1 )
415
+ )
416
+ R = R [sl ]
417
+ return dpt .permute_dims (R , inv_perm )
418
+
419
+
292
420
def copy (usm_ary , order = "K" ):
293
421
"""copy(ary, order="K")
294
422
@@ -334,28 +462,15 @@ def copy(usm_ary, order="K"):
334
462
"Unrecognized value of the order keyword. "
335
463
"Recognized values are 'A', 'C', 'F', or 'K'"
336
464
)
337
- c_contig = usm_ary .flags .c_contiguous
338
- f_contig = usm_ary .flags .f_contiguous
339
- R = dpt .usm_ndarray (
340
- usm_ary .shape ,
341
- dtype = usm_ary .dtype ,
342
- buffer = usm_ary .usm_type ,
343
- order = copy_order ,
344
- buffer_ctor_kwargs = {"queue" : usm_ary .sycl_queue },
345
- )
346
- if order == "K" and (not c_contig and not f_contig ):
347
- original_strides = usm_ary .strides
348
- ind = sorted (
349
- range (usm_ary .ndim ),
350
- key = lambda i : abs (original_strides [i ]),
351
- reverse = True ,
352
- )
353
- new_strides = tuple (R .strides [ind [i ]] for i in ind )
465
+ if order == "K" :
466
+ R = _empty_like_orderK (usm_ary , usm_ary .dtype )
467
+ else :
354
468
R = dpt .usm_ndarray (
355
469
usm_ary .shape ,
356
470
dtype = usm_ary .dtype ,
357
- buffer = R .usm_data ,
358
- strides = new_strides ,
471
+ buffer = usm_ary .usm_type ,
472
+ order = copy_order ,
473
+ buffer_ctor_kwargs = {"queue" : usm_ary .sycl_queue },
359
474
)
360
475
_copy_same_shape (R , usm_ary )
361
476
return R
@@ -432,26 +547,15 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
432
547
"Unrecognized value of the order keyword. "
433
548
"Recognized values are 'A', 'C', 'F', or 'K'"
434
549
)
435
- R = dpt .usm_ndarray (
436
- usm_ary .shape ,
437
- dtype = target_dtype ,
438
- buffer = usm_ary .usm_type ,
439
- order = copy_order ,
440
- buffer_ctor_kwargs = {"queue" : usm_ary .sycl_queue },
441
- )
442
- if order == "K" and (not c_contig and not f_contig ):
443
- original_strides = usm_ary .strides
444
- ind = sorted (
445
- range (usm_ary .ndim ),
446
- key = lambda i : abs (original_strides [i ]),
447
- reverse = True ,
448
- )
449
- new_strides = tuple (R .strides [ind [i ]] for i in ind )
550
+ if order == "K" :
551
+ R = _empty_like_orderK (usm_ary , target_dtype )
552
+ else :
450
553
R = dpt .usm_ndarray (
451
554
usm_ary .shape ,
452
555
dtype = target_dtype ,
453
- buffer = R .usm_data ,
454
- strides = new_strides ,
556
+ buffer = usm_ary .usm_type ,
557
+ order = copy_order ,
558
+ buffer_ctor_kwargs = {"queue" : usm_ary .sycl_queue },
455
559
)
456
560
_copy_from_usm_ndarray_to_usm_ndarray (R , usm_ary )
457
561
return R
@@ -492,6 +596,8 @@ def _extract_impl(ary, ary_mask, axis=0):
492
596
dst = dpt .empty (
493
597
dst_shape , dtype = ary .dtype , usm_type = ary .usm_type , device = ary .device
494
598
)
599
+ if dst .size == 0 :
600
+ return dst
495
601
hev , _ = ti ._extract (
496
602
src = ary ,
497
603
cumsum = cumsum ,
@@ -517,7 +623,7 @@ def _nonzero_impl(ary):
517
623
mask_nelems , dtype = cumsum_dt , sycl_queue = exec_q , order = "C"
518
624
)
519
625
mask_count = ti .mask_positions (ary , cumsum , sycl_queue = exec_q )
520
- indexes_dt = ti .default_device_int_type (exec_q .sycl_device )
626
+ indexes_dt = ti .default_device_index_type (exec_q .sycl_device )
521
627
indexes = dpt .empty (
522
628
(ary .ndim , mask_count ),
523
629
dtype = indexes_dt ,
0 commit comments