Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generalize eachslice #462

Merged
merged 22 commits into from
Feb 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ jobs:
- windows-latest
arch:
- x64
include:
- version: 'nightly'
os: ubuntu-latest
arch: x64
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DimensionalData"
uuid = "0703355e-b756-11e9-17c0-8b28908087d0"
authors = ["Rafael Schouten <rafaelschouten@gmail.com>"]
version = "0.24.3"
version = "0.24.4"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
1 change: 1 addition & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ reorder
Base.cat
Base.map
Base.copy!
Base.eachslice
```

Most base methods work as expected, using `Dimension` wherever a `dims`
Expand Down
54 changes: 44 additions & 10 deletions src/array/methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,17 +104,51 @@ function Base.mapslices(f, A::AbstractDimArray; dims=1, kw...)
rebuild(A, data)
end

# This is copied from base as we can't efficiently wrap this function
# through the kw with a rebuild in the generator. Doing it this way
# also makes it faster to use a dim than an integer.
function Base.eachslice(A::AbstractDimArray; dims=1, kw...)
sethaxen marked this conversation as resolved.
Show resolved Hide resolved
if dims isa Tuple && length(dims) != 1
throw(ArgumentError("only single dimensions are supported"))
@static if VERSION < v"1.9-alpha1"
"""
Base.eachslice(A::AbstractDimArray; dims)

Create a generator that iterates over dimensions `dims` of `A`, returning arrays that
select all the data from the other dimensions in `A` using views.

The generator has `size` and `axes` equivalent to those of the provided `dims`.
"""
function Base.eachslice(A::AbstractDimArray; dims)
dimtuple = _astuple(dims)
all(hasdim(A, dimtuple...)) || throw(DimensionMismatch("A doesn't have all dimensions $dims"))
_eachslice(A, dimtuple)
end
dim = first(dimnum(A, dims))
dim <= ndims(A) || throw(DimensionMismatch("A doesn't have $dim dimensions"))
idx1, idx2 = ntuple(d->(:), dim-1), ntuple(d->(:), ndims(A)-dim)
return (view(A, idx1..., i, idx2...) for i in axes(A, dim))
else
@inline function Base.eachslice(A::AbstractDimArray; dims, drop=true)
dimtuple = _astuple(dims)
all(hasdim(A, dimtuple...)) || throw(DimensionMismatch("A doesn't have all dimensions $dims"))
_eachslice(A, dimtuple, drop)
end
Base.@constprop :aggressive function _eachslice(A::AbstractDimArray{T,N}, dims, drop) where {T,N}
slicedims = Dimensions.dims(A, dims)
Adims = Dimensions.dims(A)
if drop
ax = map(dim -> axes(A, dim), slicedims)
slicemap = map(Adims) do dim
hasdim(slicedims, dim) ? dimnum(slicedims, dim) : (:)
end
return Slices(A, slicemap, ax)
else
ax = map(Adims) do dim
hasdim(slicedims, dim) ? axes(A, dim) : axes(reducedims(dim, dim), 1)
end
slicemap = map(Adims) do dim
hasdim(slicedims, dim) ? dimnum(A, dim) : (:)
end
return Slices(A, slicemap, ax)
end
end
end

# works for arrays and for stacks
function _eachslice(x, dims::Tuple)
slicedims = Dimensions.dims(x, dims)
return (view(x, d...) for d in DimIndices(slicedims))
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is so succinct 😍

end

# Duplicated dims
Expand Down
42 changes: 42 additions & 0 deletions src/stack/methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,48 @@ Other values will be returned in a `NamedTuple`.
Base.map(f, s::AbstractDimStack...) = _maybestack(s[1], map(f, map(NamedTuple, s)...))
Base.map(f, s::Union{AbstractDimStack,NamedTuple}...) = _maybestack(_firststack(s...), map(f, map(NamedTuple, s)...))

"""
Base.eachslice(stack::AbstractDimStack; dims)

Create a generator that iterates over dimensions `dims` of `stack`, returning stacks that
select all the data from the other dimensions in `stack` using views.

The generator has `size` and `axes` equivalent to those of the provided `dims`.

# Examples

```jldoctest; filter = r"┌ Warning:.*\\n.*"
julia> ds = DimStack((
x=DimArray(randn(2, 3, 4), (X([:x1, :x2]), Y(1:3), Z)),
y=DimArray(randn(2, 3, 5), (X([:x1, :x2]), Y(1:3), Ti))
));

julia> slices = eachslice(ds; dims=(Z, X));

julia> size(slices)
(4, 2)

julia> map(dims, axes(slices))
Z,
X Categorical{Symbol} Symbol[x1, x2] ForwardOrdered

julia> first(slices)
┌ Warning: (Z,) dims were not found in object
└ @ DimensionalData.Dimensions
DimStack with dimensions:
Y Sampled{Int64} 1:3 ForwardOrdered Regular Points,
Ti
and 2 layers:
:x Float64 dims: Y (3)
:y Float64 dims: Y, Ti (3×5)
```
"""
function Base.eachslice(s::AbstractDimStack; dims)
dimtuple = _astuple(dims)
all(hasdim(s, dimtuple...)) || throw(DimensionMismatch("s doesn't have all dimensions $dims"))
_eachslice(s, dimtuple)
end

_maybestack(s::AbstractDimStack, x::NamedTuple) = x
function _maybestack(
s::AbstractDimStack, das::NamedTuple{K,<:Tuple{Vararg{<:AbstractDimArray}}}
Expand Down
86 changes: 74 additions & 12 deletions test/methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,22 +145,84 @@ end
ti = Ti(1:4)
da = DimArray(a, (y, ti))
ys = (1, Y, Y(), :Y, y)
ys2 = (ys..., map(tuple, ys)...)
sethaxen marked this conversation as resolved.
Show resolved Hide resolved
tis = (2, Ti, Ti(), :Ti, ti)
for dims in tis
@test [mean(s) for s in eachslice(da; dims)] == [3.0, 4.0, 5.0, 6.0]
slices = [s .* 2 for s in eachslice(da; dims=Ti)]
@test slices[1] == [2, 6, 10]
@test DimensionalData.dims(slices[1]) == (Y(10.0:10.0:30.0),)
tis2 = (tis..., map(tuple, tis)...)

@testset "type-inferrable due to const-propagation" begin
f(x, dims) = eachslice(x; dims=dims)
f2(x, dims) = eachslice(x; dims=dims, drop=false)
@testset for dims in (y, ti, (y,), (ti,), (y, ti), (ti, y))
@inferred f(da, dims)
VERSION ≥ v"1.9-alpha1" && @inferred f2(da, dims)
end
end
for dims in ys
slices = [s .* 2 for s in eachslice(da; dims=Y)]
@test slices[1] == [2, 4, 6, 8]
@test slices[2] == [6, 8, 10, 12]
@test slices[3] == [10, 12, 14, 16]
@test DimensionalData.dims(slices[1]) == (Ti(1.0:1.0:4.0),)

@testset "error thrown if dimensions invalid" begin
@test_throws DimensionMismatch eachslice(da; dims=3)
@test_throws DimensionMismatch eachslice(da; dims=X)
@test_throws DimensionMismatch eachslice(da; dims=(4,))
@test_throws DimensionMismatch eachslice(da; dims=(Z,))
@test_throws DimensionMismatch eachslice(da; dims=(y, ti, Z))
end

@testset "slice over last dimension" begin
@testset for dims in tis2
da2 = map(mean, eachslice(da; dims=dims)) == DimArray([3.0, 4.0, 5.0, 6.0], ti)
slices = map(x -> x*2, eachslice(da; dims=dims))
@test slices isa DimArray
@test Dimensions.dims(slices) == (ti,)
@test slices[1] == DimArray([2, 6, 10], y)
if VERSION ≥ v"1.9-alpha1"
@test eachslice(da; dims=dims) isa Slices
slices = eachslice(da; dims=dims, drop=false)
@test slices isa Slices
@test slices == eachslice(parent(da); dims=dimnum(da, dims), drop=false)
@test axes(slices) == axes(sum(da; dims=otherdims(da, Dimensions.dims(da, dims))))
@test slices[1] == DimArray([1, 3, 5], y)
end
end
end

@testset "slice over first dimension" begin
@testset for dims in ys2
slices = map(x -> x*2, eachslice(da; dims=dims))
@test slices isa DimArray
@test Dimensions.dims(slices) == (y,)
@test slices[1] == DimArray([2, 4, 6, 8], ti)
@test slices[2] == DimArray([6, 8, 10, 12], ti)
@test slices[3] == DimArray([10, 12, 14, 16], ti)
if VERSION ≥ v"1.9-alpha1"
@test eachslice(da; dims=dims) isa Slices
slices = eachslice(da; dims=dims, drop=false)
@test slices isa Slices
@test slices == eachslice(parent(da); dims=dimnum(da, dims), drop=false)
@test axes(slices) == axes(sum(da; dims=otherdims(da, Dimensions.dims(da, dims))))
@test slices[1] == DimArray([1, 2, 3, 4], ti)
end
end
end

@test_throws ArgumentError [s .* 2 for s in eachslice(da; dims=(Y, Ti))]
@testset "slice over all permutations of both dimensions" begin
@testset for dims in Iterators.flatten((Iterators.product(ys, tis), Iterators.product(tis, ys)))
# mixtures of integers and dimensions are not supported
rem(sum(d -> isa(d, Int), dims), length(dims)) == 0 || continue
slices = map(x -> x*3, eachslice(da; dims=dims))
@test slices isa DimArray
@test Dimensions.dims(slices) == Dimensions.dims(da, dims)
@test size(slices) == map(x -> size(da, x), dims)
@test axes(slices) == map(x -> axes(da, x), dims)
@test eltype(slices) <: DimArray{Int, 0}
@test map(first, slices) == permutedims(da * 3, dims)
if VERSION ≥ v"1.9-alpha1"
@test eachslice(da; dims=dims) isa Slices
slices = eachslice(da; dims=dims, drop=false)
@test slices isa Slices
@test slices == eachslice(parent(da); dims=dimnum(da, dims), drop=false)
@test axes(slices) == axes(sum(da; dims=otherdims(da, Dimensions.dims(da, dims))))
end
end
end
end

@testset "simple dimension permuting methods" begin
Expand Down
67 changes: 67 additions & 0 deletions test/stack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,73 @@ end
@test s_cat[:one] == cat(parent(s[:one]), parent(s2[:one]); dims=1)
end

@testset "eachslice" begin
xs = (X, :X, x, 1)
xs2 = (xs..., map(tuple, xs)...)
ys = (Y, :Y, y, 2)
ys2 = (ys..., map(tuple, ys)...)
zs = (Z, :Z, z, 3)
zs2 = (zs..., map(tuple, zs)...)

@testset "type-inferrable due to const-propagation" begin
f(x, dims) = eachslice(x; dims=dims)
@testset for dims in (x, y, z, (x,), (y,), (z,), (x, y), (y, z), (x, y, z))
@inferred f(mixed, dims)
end
end

@testset "error thrown if dimensions invalid" begin
@test_throws DimensionMismatch eachslice(mixed; dims=4)
@test_throws DimensionMismatch eachslice(mixed; dims=Ti)
@test_throws DimensionMismatch eachslice(mixed; dims=Dim{:x})
end

@testset "slice over X dimension" begin
@testset for dims in xs2
@test eachslice(mixed; dims=dims) isa Base.Generator
slices = map(identity, eachslice(mixed; dims=dims))
@test slices isa DimArray{<:DimStack,1}
slices2 = map(l -> view(mixed, X(At(l))), lookup(Dimensions.dims(mixed, x)))
@test slices == slices2
end
end

@testset "slice over Y dimension" begin
@testset for dims in ys2
@test eachslice(mixed; dims=dims) isa Base.Generator
slices = map(identity, eachslice(mixed; dims=dims))
@test slices isa DimArray{<:DimStack,1}
slices2 = map(l -> view(mixed, Y(At(l))), lookup(y))
@test slices == slices2
end
end

@testset "slice over Z dimension" begin
@testset for dims in zs2
@test eachslice(mixed; dims=dims) isa Base.Generator
slices = map(identity, eachslice(mixed; dims=dims))
@test slices isa DimArray{<:DimStack,1}
slices2 = map(l -> view(mixed, Z(l)), axes(mixed, z))
@test slices == slices2
end
end

@testset "slice over combinations of Z and Y dimensions" begin
@testset for dims in Iterators.product(zs, ys)
# mixtures of integers and dimensions are not supported
rem(sum(d -> isa(d, Int), dims), length(dims)) == 0 || continue
@test eachslice(mixed; dims=dims) isa Base.Generator
slices = map(identity, eachslice(mixed; dims=dims))
@test slices isa DimArray{<:DimStack,2}
slices2 = map(
l -> view(mixed, Z(l[1]), Y(l[2])),
Iterators.product(axes(mixed, z), axes(mixed, y)),
)
@test slices == slices2
end
end
end

@testset "map" begin
@test values(map(a -> a .* 2, s)) == values(DimStack(2da1, 2da2, 2da3))
@test dims(map(a -> a .* 2, s)) == dims(DimStack(2da1, 2da2, 2da3))
Expand Down