@@ -372,72 +372,86 @@ def _check_same_shapes(X0_shape, axis, n, arrays):
372
372
)
373
373
374
374
375
+ def _concat_axis_None (arrays ):
376
+ "Implementation of concat(arrays, axis=None)."
377
+ res_dtype , res_usm_type , exec_q = _arrays_validation (
378
+ arrays , check_ndim = False
379
+ )
380
+ res_shape = 0
381
+ for array in arrays :
382
+ res_shape += array .size
383
+ res = dpt .empty (
384
+ res_shape , dtype = res_dtype , usm_type = res_usm_type , sycl_queue = exec_q
385
+ )
386
+
387
+ hev_list = []
388
+ fill_start = 0
389
+ for array in arrays :
390
+ fill_end = fill_start + array .size
391
+ if array .flags .c_contiguous :
392
+ hev , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
393
+ src = dpt .reshape (array , - 1 ),
394
+ dst = res [fill_start :fill_end ],
395
+ sycl_queue = exec_q ,
396
+ )
397
+ else :
398
+ hev , _ = ti ._copy_usm_ndarray_for_reshape (
399
+ src = array ,
400
+ dst = res [fill_start :fill_end ],
401
+ shift = 0 ,
402
+ sycl_queue = exec_q ,
403
+ )
404
+ fill_start = fill_end
405
+ hev_list .append (hev )
406
+
407
+ dpctl .SyclEvent .wait_for (hev_list )
408
+ return res
409
+
410
+
375
411
def concat (arrays , axis = 0 ):
376
412
"""
377
413
concat(arrays: tuple or list of usm_ndarrays, axis: int) -> usm_ndarray
378
414
379
415
Joins a sequence of arrays along an existing axis.
380
416
"""
381
417
if axis is None :
382
- res_dtype , res_usm_type , exec_q = _arrays_validation (
383
- arrays , check_ndim = False
384
- )
385
- res_shape = 0
386
- for array in arrays :
387
- res_shape += array .size
388
- res = dpt .empty (
389
- res_shape , dtype = res_dtype , usm_type = res_usm_type , sycl_queue = exec_q
390
- )
418
+ return _concat_axis_None (arrays )
391
419
392
- hev_list = []
393
- fill_start = 0
394
- for array in arrays :
395
- fill_end = fill_start + array .size
396
- hev , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
397
- src = dpt .reshape (array , - 1 ),
398
- dst = res [fill_start :fill_end ],
399
- sycl_queue = exec_q ,
400
- )
401
- fill_start = fill_end
402
- hev_list .append (hev )
420
+ res_dtype , res_usm_type , exec_q = _arrays_validation (arrays )
421
+ n = len (arrays )
422
+ X0 = arrays [0 ]
403
423
404
- dpctl .SyclEvent .wait_for (hev_list )
405
- else :
406
- res_dtype , res_usm_type , exec_q = _arrays_validation (arrays )
407
- n = len (arrays )
408
- X0 = arrays [0 ]
424
+ axis = normalize_axis_index (axis , X0 .ndim )
425
+ X0_shape = X0 .shape
426
+ _check_same_shapes (X0_shape , axis , n , arrays )
409
427
410
- axis = normalize_axis_index ( axis , X0 . ndim )
411
- X0_shape = X0 . shape
412
- _check_same_shapes ( X0_shape , axis , n , arrays )
428
+ res_shape_axis = 0
429
+ for X in arrays :
430
+ res_shape_axis = res_shape_axis + X . shape [ axis ]
413
431
414
- res_shape_axis = 0
415
- for X in arrays :
416
- res_shape_axis = res_shape_axis + X . shape [ axis ]
432
+ res_shape = tuple (
433
+ X0_shape [ i ] if i != axis else res_shape_axis for i in range ( X0 . ndim )
434
+ )
417
435
418
- res_shape = tuple (
419
- X0_shape [ i ] if i != axis else res_shape_axis for i in range ( X0 . ndim )
420
- )
436
+ res = dpt . empty (
437
+ res_shape , dtype = res_dtype , usm_type = res_usm_type , sycl_queue = exec_q
438
+ )
421
439
422
- res = dpt .empty (
423
- res_shape , dtype = res_dtype , usm_type = res_usm_type , sycl_queue = exec_q
440
+ hev_list = []
441
+ fill_start = 0
442
+ for i in range (n ):
443
+ fill_end = fill_start + arrays [i ].shape [axis ]
444
+ c_shapes_copy = tuple (
445
+ np .s_ [fill_start :fill_end ] if j == axis else np .s_ [:]
446
+ for j in range (X0 .ndim )
424
447
)
448
+ hev , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
449
+ src = arrays [i ], dst = res [c_shapes_copy ], sycl_queue = exec_q
450
+ )
451
+ fill_start = fill_end
452
+ hev_list .append (hev )
425
453
426
- hev_list = []
427
- fill_start = 0
428
- for i in range (n ):
429
- fill_end = fill_start + arrays [i ].shape [axis ]
430
- c_shapes_copy = tuple (
431
- np .s_ [fill_start :fill_end ] if j == axis else np .s_ [:]
432
- for j in range (X0 .ndim )
433
- )
434
- hev , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
435
- src = arrays [i ], dst = res [c_shapes_copy ], sycl_queue = exec_q
436
- )
437
- fill_start = fill_end
438
- hev_list .append (hev )
439
-
440
- dpctl .SyclEvent .wait_for (hev_list )
454
+ dpctl .SyclEvent .wait_for (hev_list )
441
455
442
456
return res
443
457
0 commit comments