Skip to content

Commit

Permalink
Propagate AxisArray copy / view down to taking copies / views of its …
Browse files Browse the repository at this point in the history
…axes as well.
  • Loading branch information
TechnophobicLampshade committed Sep 27, 2018
1 parent 5b1fd0e commit ac13645
Showing 1 changed file with 35 additions and 20 deletions.
55 changes: 35 additions & 20 deletions src/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,16 @@ Base.eachindex(A::AxisArray) = eachindex(A.data)
This internal function determines the new set of axes that are constructed upon
indexing with I.
"""
reaxis(A::AxisArray, I::Idx...) = _reaxis(make_axes_match(axes(A), I), I)
reaxis(A::AxisArray, copy::Val, I::Idx...) = _reaxis(make_axes_match(axes(A), I), copy, I)
# Linear indexing
reaxis(A::AxisArray{<:Any,1}, I::AbstractArray{Int}) = _new_axes(A.axes[1], I)
reaxis(A::AxisArray, I::AbstractArray{Int}) = default_axes(I)
reaxis(A::AxisArray{<:Any,1}, I::Real) = ()
reaxis(A::AxisArray, I::Real) = ()
reaxis(A::AxisArray{<:Any,1}, I::Colon) = _new_axes(A.axes[1], Base.axes(A, 1))
reaxis(A::AxisArray, I::Colon) = default_axes(Base.OneTo(length(A)))
reaxis(A::AxisArray{<:Any,1}, I::AbstractArray{Bool}) = _new_axes(A.axes[1], findall(I))
reaxis(A::AxisArray, I::AbstractArray{Bool}) = default_axes(findall(I))
reaxis(A::AxisArray{<:Any,1}, copy::Val, I::AbstractArray{Int}) = _new_axes(A.axes[1], copy, I)
reaxis(A::AxisArray, copy::Val, I::AbstractArray{Int}) = default_axes(I)
reaxis(A::AxisArray{<:Any,1}, copy::Val, I::Real) = ()
reaxis(A::AxisArray, copy::Val, I::Real) = ()
reaxis(A::AxisArray{<:Any,1}, copy::Val, I::Colon) = _new_axes(A.axes[1], copy, Base.axes(A, 1))
reaxis(A::AxisArray, copy::Val, I::Colon) = default_axes(Base.OneTo(length(A)))
reaxis(A::AxisArray{<:Any,1}, copy::Val, I::AbstractArray{Bool}) = _new_axes(A.axes[1], copy, findall(I))
reaxis(A::AxisArray, copy::Val, I::AbstractArray{Bool}) = default_axes(findall(I))

# Ensure the number of axes matches the number of indexing dimensions
@inline function make_axes_match(axs, idxs)
Expand All @@ -66,28 +66,43 @@ reaxis(A::AxisArray, I::AbstractArray{Bool}) = default_axes(findall(I))
end

# Now we can reaxis without worrying about mismatched axes/indices
@inline _reaxis(axs::Tuple{}, idxs::Tuple{}) = ()
@inline _reaxis(axs::Tuple{}, copy::Val, idxs::Tuple{}) = ()
# Scalars are dropped
const ScalarIndex = Union{Real, AbstractArray{<:Any, 0}}
@inline _reaxis(axs::Tuple, idxs::Tuple{ScalarIndex, Vararg{Any}}) = _reaxis(tail(axs), tail(idxs))
@inline _reaxis(axs::Tuple, copy::Val, idxs::Tuple{ScalarIndex, Vararg{Any}}) = _reaxis(tail(axs), copy, tail(idxs))
# Colon passes straight through
@inline _reaxis(axs::Tuple, idxs::Tuple{Colon, Vararg{Any}}) = (axs[1], _reaxis(tail(axs), tail(idxs))...)
@inline _reaxis(axs::Tuple, copy::Val, idxs::Tuple{Colon, Vararg{Any}}) = (axs[1], _reaxis(tail(axs), copy, tail(idxs))...)
# But arrays can add or change dimensions and accompanying axis names
@inline _reaxis(axs::Tuple, idxs::Tuple{AbstractArray, Vararg{Any}}) =
(_new_axes(axs[1], idxs[1])..., _reaxis(tail(axs), tail(idxs))...)
@inline _reaxis(axs::Tuple, copy::Val, idxs::Tuple{AbstractArray, Vararg{Any}}) =
(_new_axes(axs[1], copy, idxs[1])..., _reaxis(tail(axs), copy, tail(idxs))...)

# Vectors simply create new axes with the same name; just subsetted by their value
@inline _new_axes(ax::Axis{name}, idx::AbstractVector) where {name} = (Axis{name}(ax.val[idx]),)
@inline _new_axes(ax::Axis{name}, copy::Val{true}, idx::AbstractVector) where {name} = (Axis{name}(ax.val[idx]),)
@inline _new_axes(ax::Axis{name}, copy::Val{false}, idx::AbstractVector) where {name} = (Axis{name}(view(ax.val, idx)),)

# @inline _new_axes(ax::Axis{name}, copy::Val{false}, idx::AxisArray{T,1,D,Ax}) where {Ax, D, T, name} = _new_axes(ax, copy, idx)

# Arrays create multiple axes with _N appended to the axis name containing their indices
@generated function _new_axes(ax::Axis{name}, idx::AbstractArray{<:Any,N}) where {name,N}
@generated function _new_axes(ax::Axis{name}, copy::Val, idx::AbstractArray{<:Any,N}) where {name, N}
newaxes = Expr(:tuple)
for i=1:N
push!(newaxes.args, :($(Axis{Symbol(name, "_", i)})(Base.axes(idx, $i))))
end
newaxes
end

# And indexing with an AxisArray joins the name and overrides the values
@generated function _new_axes(ax::Axis{name}, idx::AxisArray{<:Any, N}) where {name,N}
@generated function _new_axes(ax::Axis{name}, copy::Val{true}, idx::AxisArray{<:Any, N}) where {name,N}
newaxes = Expr(:tuple)
idxnames = axisnames(idx)
for i=1:N
push!(newaxes.args, :($(Axis{Symbol(name, "_", idxnames[i])})(idx.axes[$i].val)))
end
newaxes
end

# TODO: this is duplicated from the above
@generated function _new_axes(ax::Axis{name}, copy::Val{false}, idx::AxisArray{<:Any, N}) where {name,N}
newaxes = Expr(:tuple)
idxnames = axisnames(idx)
for i=1:N
Expand All @@ -97,19 +112,19 @@ end
end

@propagate_inbounds function Base.getindex(A::AxisArray, idxs::Idx...)
AxisArray(A.data[idxs...], reaxis(A, idxs...))
AxisArray(A.data[idxs...], reaxis(A, Val(true), idxs...))
end

# To resolve ambiguities, we need several definitions
using Base: AbstractCartesianIndex
@propagate_inbounds Base.view(A::AxisArray, idxs::Idx...) = AxisArray(view(A.data, idxs...), reaxis(A, idxs...))
@propagate_inbounds Base.view(A::AxisArray, idxs::Idx...) = AxisArray(view(A.data, idxs...), reaxis(A, Val(false), idxs...))

# Setindex is so much simpler. Just assign it to the data:
@propagate_inbounds Base.setindex!(A::AxisArray, v, idxs::Idx...) = (A.data[idxs...] = v)

# Logical indexing
@propagate_inbounds function Base.getindex(A::AxisArray, idx::AbstractArray{Bool})
AxisArray(A.data[idx], reaxis(A, idx))
AxisArray(A.data[idx], reaxis(A, Val(true), idx))
end
@propagate_inbounds Base.setindex!(A::AxisArray, v, idx::AbstractArray{Bool}) = (A.data[idx] = v)

Expand Down

0 comments on commit ac13645

Please sign in to comment.