@@ -321,10 +321,10 @@ def roll(X, shift, axis=None):
321
321
return res
322
322
323
323
324
- def _arrays_validation (arrays ):
324
+ def _arrays_validation (arrays , check_ndim = True ):
325
325
n = len (arrays )
326
326
if n == 0 :
327
- raise TypeError ("Missing 1 required positional argument: 'arrays'" )
327
+ raise TypeError ("Missing 1 required positional argument: 'arrays'. " )
328
328
329
329
if not isinstance (arrays , (list , tuple )):
330
330
raise TypeError (f"Expected tuple or list type, got { type (arrays )} ." )
@@ -335,11 +335,11 @@ def _arrays_validation(arrays):
335
335
336
336
exec_q = dputils .get_execution_queue ([X .sycl_queue for X in arrays ])
337
337
if exec_q is None :
338
- raise ValueError ("All the input arrays must have same sycl queue" )
338
+ raise ValueError ("All the input arrays must have same sycl queue. " )
339
339
340
340
res_usm_type = dputils .get_coerced_usm_type ([X .usm_type for X in arrays ])
341
341
if res_usm_type is None :
342
- raise ValueError ("All the input arrays must have usm_type" )
342
+ raise ValueError ("All the input arrays must have usm_type. " )
343
343
344
344
X0 = arrays [0 ]
345
345
_supported_dtype (Xi .dtype for Xi in arrays )
@@ -348,13 +348,14 @@ def _arrays_validation(arrays):
348
348
for i in range (1 , n ):
349
349
res_dtype = np .promote_types (res_dtype , arrays [i ])
350
350
351
- for i in range (1 , n ):
352
- if X0 .ndim != arrays [i ].ndim :
353
- raise ValueError (
354
- "All the input arrays must have same number of dimensions, "
355
- f"but the array at index 0 has { X0 .ndim } dimension(s) and the "
356
- f"array at index { i } has { arrays [i ].ndim } dimension(s)"
357
- )
351
+ if check_ndim :
352
+ for i in range (1 , n ):
353
+ if X0 .ndim != arrays [i ].ndim :
354
+ raise ValueError (
355
+ "All the input arrays must have same number of dimensions, "
356
+ f"but the array at index 0 has { X0 .ndim } dimension(s) and "
357
+ f"the array at index { i } has { arrays [i ].ndim } dimension(s)."
358
+ )
358
359
return res_dtype , res_usm_type , exec_q
359
360
360
361
@@ -367,7 +368,7 @@ def _check_same_shapes(X0_shape, axis, n, arrays):
367
368
"All the input array dimensions for the concatenation "
368
369
f"axis must match exactly, but along dimension { j } , the "
369
370
f"array at index 0 has size { X0j } and the array "
370
- f"at index { i } has size { Xi_shape [j ]} "
371
+ f"at index { i } has size { Xi_shape [j ]} . "
371
372
)
372
373
373
374
@@ -377,42 +378,66 @@ def concat(arrays, axis=0):
377
378
378
379
Joins a sequence of arrays along an existing axis.
379
380
"""
380
- res_dtype , res_usm_type , exec_q = _arrays_validation (arrays )
381
-
382
- n = len (arrays )
383
- X0 = arrays [0 ]
381
+ if axis is None :
382
+ res_dtype , res_usm_type , exec_q = _arrays_validation (
383
+ arrays , check_ndim = False
384
+ )
385
+ res_shape = 0
386
+ for array in arrays :
387
+ res_shape += array .size
388
+ res = dpt .empty (
389
+ res_shape , dtype = res_dtype , usm_type = res_usm_type , sycl_queue = exec_q
390
+ )
384
391
385
- axis = normalize_axis_index (axis , X0 .ndim )
386
- X0_shape = X0 .shape
387
- _check_same_shapes (X0_shape , axis , n , arrays )
392
+ hev_list = []
393
+ fill_start = 0
394
+ for array in arrays :
395
+ fill_end = fill_start + array .size
396
+ hev , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
397
+ src = dpt .reshape (array , - 1 ),
398
+ dst = res [fill_start :fill_end ],
399
+ sycl_queue = exec_q ,
400
+ )
401
+ fill_start = fill_end
402
+ hev_list .append (hev )
388
403
389
- res_shape_axis = 0
390
- for X in arrays :
391
- res_shape_axis = res_shape_axis + X .shape [axis ]
404
+ dpctl .SyclEvent .wait_for (hev_list )
405
+ else :
406
+ res_dtype , res_usm_type , exec_q = _arrays_validation (arrays )
407
+ n = len (arrays )
408
+ X0 = arrays [0 ]
392
409
393
- res_shape = tuple (
394
- X0_shape [ i ] if i != axis else res_shape_axis for i in range ( X0 .ndim )
395
- )
410
+ axis = normalize_axis_index ( axis , X0 . ndim )
411
+ X0_shape = X0 .shape
412
+ _check_same_shapes ( X0_shape , axis , n , arrays )
396
413
397
- res = dpt . empty (
398
- res_shape , dtype = res_dtype , usm_type = res_usm_type , sycl_queue = exec_q
399
- )
414
+ res_shape_axis = 0
415
+ for X in arrays :
416
+ res_shape_axis = res_shape_axis + X . shape [ axis ]
400
417
401
- hev_list = []
402
- fill_start = 0
403
- for i in range (n ):
404
- fill_end = fill_start + arrays [i ].shape [axis ]
405
- c_shapes_copy = tuple (
406
- np .s_ [fill_start :fill_end ] if j == axis else np .s_ [:]
407
- for j in range (X0 .ndim )
418
+ res_shape = tuple (
419
+ X0_shape [i ] if i != axis else res_shape_axis for i in range (X0 .ndim )
408
420
)
409
- hev , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
410
- src = arrays [i ], dst = res [c_shapes_copy ], sycl_queue = exec_q
421
+
422
+ res = dpt .empty (
423
+ res_shape , dtype = res_dtype , usm_type = res_usm_type , sycl_queue = exec_q
411
424
)
412
- fill_start = fill_end
413
- hev_list .append (hev )
414
425
415
- dpctl .SyclEvent .wait_for (hev_list )
426
+ hev_list = []
427
+ fill_start = 0
428
+ for i in range (n ):
429
+ fill_end = fill_start + arrays [i ].shape [axis ]
430
+ c_shapes_copy = tuple (
431
+ np .s_ [fill_start :fill_end ] if j == axis else np .s_ [:]
432
+ for j in range (X0 .ndim )
433
+ )
434
+ hev , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
435
+ src = arrays [i ], dst = res [c_shapes_copy ], sycl_queue = exec_q
436
+ )
437
+ fill_start = fill_end
438
+ hev_list .append (hev )
439
+
440
+ dpctl .SyclEvent .wait_for (hev_list )
416
441
417
442
return res
418
443
0 commit comments