Skip to content

Commit 1ad2396

Browse files
authored
make cat(As..., dims=Val((1,2,...)) work (#44211)
1 parent 4abb47f commit 1ad2396

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

base/abstractarray.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1712,13 +1712,15 @@ end
17121712
_cs(d, a, b) = (a == b ? a : throw(DimensionMismatch(
17131713
"mismatch in dimension $d (expected $a got $b)")))
17141714

1715-
function dims2cat(::Val{n}) where {n}
1716-
n <= 0 && throw(ArgumentError("cat dimension must be a positive integer, but got $n"))
1717-
ntuple(i -> (i == n), Val(n))
1715+
function dims2cat(::Val{dims}) where dims
1716+
if any((0), dims)
1717+
throw(ArgumentError("All cat dimensions must be positive integers, but got $dims"))
1718+
end
1719+
ntuple(in(dims), maximum(dims))
17181720
end
17191721

17201722
function dims2cat(dims)
1721-
if any(dims .<= 0)
1723+
if any((0), dims)
17221724
throw(ArgumentError("All cat dimensions must be positive integers, but got $dims"))
17231725
end
17241726
ntuple(in(dims), maximum(dims))

test/abstractarray.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -732,6 +732,7 @@ function test_cat(::Type{TestAbstractArray})
732732
@test @inferred(cat(As...; dims=Val(3))) == zeros(2, 2, 2)
733733
cat3v(As) = cat(As...; dims=Val(3))
734734
@test @inferred(cat3v(As)) == zeros(2, 2, 2)
735+
@test @inferred(cat(As...; dims=Val((1,2)))) == zeros(4, 4)
735736
end
736737

737738
function test_ind2sub(::Type{TestAbstractArray})

0 commit comments

Comments
 (0)