@@ -411,10 +411,10 @@ def roll(X, shift, axis=None):
411
411
return res
412
412
413
413
414
- def _arrays_validation (arrays ):
414
+ def _arrays_validation (arrays , check_ndim = True ):
415
415
n = len (arrays )
416
416
if n == 0 :
417
- raise TypeError ("Missing 1 required positional argument: 'arrays'" )
417
+ raise TypeError ("Missing 1 required positional argument: 'arrays'. " )
418
418
419
419
if not isinstance (arrays , (list , tuple )):
420
420
raise TypeError (f"Expected tuple or list type, got { type (arrays )} ." )
@@ -425,11 +425,11 @@ def _arrays_validation(arrays):
425
425
426
426
exec_q = dputils .get_execution_queue ([X .sycl_queue for X in arrays ])
427
427
if exec_q is None :
428
- raise ValueError ("All the input arrays must have same sycl queue" )
428
+ raise ValueError ("All the input arrays must have same sycl queue. " )
429
429
430
430
res_usm_type = dputils .get_coerced_usm_type ([X .usm_type for X in arrays ])
431
431
if res_usm_type is None :
432
- raise ValueError ("All the input arrays must have usm_type" )
432
+ raise ValueError ("All the input arrays must have usm_type. " )
433
433
434
434
X0 = arrays [0 ]
435
435
_supported_dtype (Xi .dtype for Xi in arrays )
@@ -438,13 +438,14 @@ def _arrays_validation(arrays):
438
438
for i in range (1 , n ):
439
439
res_dtype = np .promote_types (res_dtype , arrays [i ])
440
440
441
- for i in range (1 , n ):
442
- if X0 .ndim != arrays [i ].ndim :
443
- raise ValueError (
444
- "All the input arrays must have same number of dimensions, "
445
- f"but the array at index 0 has { X0 .ndim } dimension(s) and the "
446
- f"array at index { i } has { arrays [i ].ndim } dimension(s)"
447
- )
441
+ if check_ndim :
442
+ for i in range (1 , n ):
443
+ if X0 .ndim != arrays [i ].ndim :
444
+ raise ValueError (
445
+ "All the input arrays must have same number of dimensions, "
446
+ f"but the array at index 0 has { X0 .ndim } dimension(s) and "
447
+ f"the array at index { i } has { arrays [i ].ndim } dimension(s)."
448
+ )
448
449
return res_dtype , res_usm_type , exec_q
449
450
450
451
@@ -457,10 +458,46 @@ def _check_same_shapes(X0_shape, axis, n, arrays):
457
458
"All the input array dimensions for the concatenation "
458
459
f"axis must match exactly, but along dimension { j } , the "
459
460
f"array at index 0 has size { X0j } and the array "
460
- f"at index { i } has size { Xi_shape [j ]} "
461
+ f"at index { i } has size { Xi_shape [j ]} . "
461
462
)
462
463
463
464
465
+ def _concat_axis_None (arrays ):
466
+ "Implementation of concat(arrays, axis=None)."
467
+ res_dtype , res_usm_type , exec_q = _arrays_validation (
468
+ arrays , check_ndim = False
469
+ )
470
+ res_shape = 0
471
+ for array in arrays :
472
+ res_shape += array .size
473
+ res = dpt .empty (
474
+ res_shape , dtype = res_dtype , usm_type = res_usm_type , sycl_queue = exec_q
475
+ )
476
+
477
+ hev_list = []
478
+ fill_start = 0
479
+ for array in arrays :
480
+ fill_end = fill_start + array .size
481
+ if array .flags .c_contiguous :
482
+ hev , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
483
+ src = dpt .reshape (array , - 1 ),
484
+ dst = res [fill_start :fill_end ],
485
+ sycl_queue = exec_q ,
486
+ )
487
+ else :
488
+ hev , _ = ti ._copy_usm_ndarray_for_reshape (
489
+ src = array ,
490
+ dst = res [fill_start :fill_end ],
491
+ shift = 0 ,
492
+ sycl_queue = exec_q ,
493
+ )
494
+ fill_start = fill_end
495
+ hev_list .append (hev )
496
+
497
+ dpctl .SyclEvent .wait_for (hev_list )
498
+ return res
499
+
500
+
464
501
def concat (arrays , axis = 0 ):
465
502
"""concat(arrays, axis)
466
503
@@ -486,8 +523,10 @@ def concat(arrays, axis=0):
486
523
of the output array is determined by USM allocation type promotion
487
524
rules.
488
525
"""
489
- res_dtype , res_usm_type , exec_q = _arrays_validation (arrays )
526
+ if axis is None :
527
+ return _concat_axis_None (arrays )
490
528
529
+ res_dtype , res_usm_type , exec_q = _arrays_validation (arrays )
491
530
n = len (arrays )
492
531
X0 = arrays [0 ]
493
532
0 commit comments