Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add insertdims #830

Merged
merged 3 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ changes in `julia`.

## Supported features

* `insertdims(D; dims)` is the opposite of `dropdims` ([#45793]) (since Compat 4.16.0)

* `Compat.Fix{N}` which fixes an argument at the `N`th position ([#54653]) (since Compat 4.16.0)

* `chopprefix(s, prefix)` and `chopsuffix(s, suffix)` ([#40995]) (since Compat 4.15.0)
Expand Down Expand Up @@ -190,6 +192,7 @@ Note that you should specify the correct minimum version for `Compat` in the
[#43852]: https://github.com/JuliaLang/julia/issues/43852
[#45052]: https://github.com/JuliaLang/julia/issues/45052
[#45607]: https://github.com/JuliaLang/julia/issues/45607
[#45793]: https://github.com/JuliaLang/julia/issues/45793
[#47354]: https://github.com/JuliaLang/julia/issues/47354
[#47679]: https://github.com/JuliaLang/julia/pull/47679
[#48038]: https://github.com/JuliaLang/julia/issues/48038
Expand Down
31 changes: 31 additions & 0 deletions src/Compat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1122,6 +1122,37 @@ if VERSION < v"1.8.0-DEV.1016"
export chopprefix, chopsuffix
end

if VERSION < v"1.12.0-DEV.974" # contrib/commit-name.sh 2635dea

insertdims(A; dims) = _insertdims(A, dims)

function _insertdims(A::AbstractArray{T, N}, dims::NTuple{M, Int}) where {T, N, M}
for i in eachindex(dims)
1 ≤ dims[i] || throw(ArgumentError("the smallest entry in dims must be ≥ 1"))
dims[i] ≤ N+M || throw(ArgumentError("the largest entry in dims must be not larger than the dimension of the array and the length of dims added"))
for j = 1:i-1
dims[j] == dims[i] && throw(ArgumentError("inserted dims must be unique"))
end
end

# acc is a tuple, where the first entry is the final shape
# the second entry off acc is a counter for the axes of A
inds = Base._foldoneto((acc, i) ->
i ∈ dims
? ((acc[1]..., Base.OneTo(1)), acc[2])
: ((acc[1]..., axes(A, acc[2])), acc[2] + 1),
((), 1), Val(N+M))
new_shape = inds[1]
return reshape(A, new_shape)
end

_insertdims(A::AbstractArray, dim::Integer) = _insertdims(A, (Int(dim),))

export insertdims
else
using Base: insertdims, _insertdims
end

# https://github.com/JuliaLang/julia/pull/54653: add Fix
@static if !isdefined(Base, :Fix) # VERSION < v"1.12.0-DEV.981"
@static if !isdefined(Base, :_stable_typeof)
Expand Down
31 changes: 31 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -908,6 +908,37 @@ end
end
end

# https://github.com/JuliaLang/julia/pull/45793
@testset "insertdims" begin
a = rand(8, 7)
@test @inferred(insertdims(a, dims=1)) == @inferred(insertdims(a, dims=(1,))) == reshape(a, (1, 8, 7))
@test @inferred(insertdims(a, dims=3)) == @inferred(insertdims(a, dims=(3,))) == reshape(a, (8, 7, 1))
@test @inferred(insertdims(a, dims=(1, 3))) == reshape(a, (1, 8, 1, 7))
@test @inferred(insertdims(a, dims=(1, 2, 3))) == reshape(a, (1, 1, 1, 8, 7))
@test @inferred(insertdims(a, dims=(1, 4))) == reshape(a, (1, 8, 7, 1))
@test @inferred(insertdims(a, dims=(1, 3, 5))) == reshape(a, (1, 8, 1, 7, 1))
@test @inferred(insertdims(a, dims=(1, 2, 4, 6))) == reshape(a, (1, 1, 8, 1, 7, 1))
@test @inferred(insertdims(a, dims=(1, 3, 4, 6))) == reshape(a, (1, 8, 1, 1, 7, 1))
@test @inferred(insertdims(a, dims=(1, 4, 6, 3))) == reshape(a, (1, 8, 1, 1, 7, 1))
@test @inferred(insertdims(a, dims=(1, 3, 5, 6))) == reshape(a, (1, 8, 1, 7, 1, 1))

@test_throws ArgumentError insertdims(a, dims=(1, 1, 2, 3))
@test_throws ArgumentError insertdims(a, dims=(1, 2, 2, 3))
@test_throws ArgumentError insertdims(a, dims=(1, 2, 3, 3))
@test_throws UndefKeywordError insertdims(a)
@test_throws ArgumentError insertdims(a, dims=0)
@test_throws ArgumentError insertdims(a, dims=(1, 2, 1))
@test_throws ArgumentError insertdims(a, dims=4)
@test_throws ArgumentError insertdims(a, dims=6)

# insertdims and dropdims are inverses
b = rand(1,1,1,5,1,1,7)
for dims in [1, (1,), 2, (2,), 3, (3,), (1,3), (1,2,3), (1,2), (1,3,5), (1,2,5,6), (1,3,5,6), (1,3,5,6), (1,6,5,3)]
@test dropdims(insertdims(a; dims); dims) == a
@test insertdims(dropdims(b; dims); dims) == b
end
end

# https://github.com/JuliaLang/julia/pull/54653: add Fix
@testset "Fix" begin
function test_fix1(Fix1=Compat.Fix1)
Expand Down
Loading