Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Try using an advanced indices object that wraps axes #81

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
WIP: use IndexAxis for reductions
This _almost_ works, but it's not type stable. The trouble is that the
`IndexAxis` type _sometimes_ wants to know the exact Axis information (like
when used in `similar`), but other times it only cares about the Axis name
(like when choosing the dimensions for a reduction).  Specifically, this branch
runs into trouble with this base method:

```
    function reduced_indices(inds::Indices{N}, d::Int, rd::AbstractUnitRange) where N
        d < 1 && throw(ArgumentError("dimension must be ≥ 1, got $d"))
        if d == 1
            return (oftype(inds[1], rd), tail(inds)...)
        elseif 1 < d <= N
            return tuple(inds[1:d-1]..., oftype(inds[d], rd), inds[d+1:N]...)::typeof(inds)
```

I've broken the contract that `oftype`—that is, `convert`—returns an object
of exactly the requested type.  But here I simply want a new IndexAxis object
that has the same name and wraps the given `rd` range.
  • Loading branch information
mbauman committed Apr 30, 2017
commit 564f4b7c9fbb313f0a66c0837fbb52613f8bb985
65 changes: 23 additions & 42 deletions src/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ immutable IndexAxis{I,A} <: AbstractUnitRange{Int}
index::I
axis::A
end
@inline Base.convert{I,name,T}(::Type{IndexAxis{I,Axis{name,T}}}, index::AbstractUnitRange) = IndexAxis(index, Axis{name}(index))
@inline Base.indices(I::IndexAxis) = indices(I.index)
@inline Base.unsafe_indices(I::IndexAxis) = Base.unsafe_indices(I.index)
@inline Base.indices1(I::IndexAxis) = Base.indices1(I.index)
Expand Down Expand Up @@ -342,48 +343,6 @@ function Base.similar(A::AbstractArray{T}, shape::AxisDims) where T
AxisArray(similar(A, T, map(_ensure_index, axs)), axs)
end

# These methods allow us to preserve the AxisArray under reductions
# Note that we only extend the following two methods, and then have it
# dispatch to package-local `reduced_indices` and `reduced_indices0`
# methods. This avoids a whole slew of ambiguities.
if VERSION == v"0.5.0"
Base.reduced_dims(A::AxisArray, region) = reduced_indices(axes(A), region)
Base.reduced_dims0(A::AxisArray, region) = reduced_indices0(axes(A), region)
else
Base.reduced_indices(A::AxisArray, region) = reduced_indices(axes(A), region)
Base.reduced_indices0(A::AxisArray, region) = reduced_indices0(axes(A), region)
end

reduced_indices{N}(axs::Tuple{Vararg{Axis,N}}, ::Tuple{}) = axs
reduced_indices0{N}(axs::Tuple{Vararg{Axis,N}}, ::Tuple{}) = axs
reduced_indices{N}(axs::Tuple{Vararg{Axis,N}}, region::Integer) =
reduced_indices(axs, (region,))
reduced_indices0{N}(axs::Tuple{Vararg{Axis,N}}, region::Integer) =
reduced_indices0(axs, (region,))

reduced_indices{N}(axs::Tuple{Vararg{Axis,N}}, region::Dims) =
map((ax,d)->d∈region ? reduced_axis(ax) : ax, axs, ntuple(identity, Val{N}))
reduced_indices0{N}(axs::Tuple{Vararg{Axis,N}}, region::Dims) =
map((ax,d)->d∈region ? reduced_axis0(ax) : ax, axs, ntuple(identity, Val{N}))

@inline reduced_indices{Ax<:Axis}(axs::Tuple{Vararg{Axis}}, region::Type{Ax}) =
_reduced_indices(reduced_axis, (), region, axs...)
@inline reduced_indices0{Ax<:Axis}(axs::Tuple{Vararg{Axis}}, region::Type{Ax}) =
_reduced_indices(reduced_axis0, (), region, axs...)
@inline reduced_indices(axs::Tuple{Vararg{Axis}}, region::Axis) =
_reduced_indices(reduced_axis, (), region, axs...)
@inline reduced_indices0(axs::Tuple{Vararg{Axis}}, region::Axis) =
_reduced_indices(reduced_axis0, (), region, axs...)

reduced_indices(axs::Tuple{Vararg{Axis}}, region::Tuple) =
reduced_indices(reduced_indices(axs, region[1]), tail(region))
reduced_indices(axs::Tuple{Vararg{Axis}}, region::Tuple{Vararg{Axis}}) =
reduced_indices(reduced_indices(axs, region[1]), tail(region))
reduced_indices0(axs::Tuple{Vararg{Axis}}, region::Tuple) =
reduced_indices0(reduced_indices0(axs, region[1]), tail(region))
reduced_indices0(axs::Tuple{Vararg{Axis}}, region::Tuple{Vararg{Axis}}) =
reduced_indices0(reduced_indices0(axs, region[1]), tail(region))

@pure samesym{n1,n2}(::Type{Axis{n1}}, ::Type{Axis{n2}}) = Val{n1==n2}()
samesym{n1,n2,T1,T2}(::Type{Axis{n1,T1}}, ::Type{Axis{n2,T2}}) = samesym(Axis{n1},Axis{n2})
samesym{n1,n2}(::Type{Axis{n1}}, ::Axis{n2}) = samesym(Axis{n1}, Axis{n2})
Expand Down Expand Up @@ -485,6 +444,28 @@ end
@inline dropax{name,T}(ax::Type{Axis{name,T}}, ax1::Axis{name}, axs...) = dropax(ax, axs...)
dropax(ax) = ()

# Reductions: Support specifying the reduction in terms of Axis{:name} or Axis{:name}()
const _AxTyp = Union{@compat(Type{<:Axis}), @compat(Axis{<:Any, Tuple{}})}
if VERSION == v"0.5.0"
Base.reduced_dims{N}(A::AxisArray, region::Union{_AxTyp, Tuple{_AxTyp, Vararg{_AxTyp}}}) =
Base.reduced_dims(A, findax(indices(A), region))
Base.reduced_dims0{N}(A::AxisArray, region::Union{_AxTyp, Tuple{_AxTyp, Vararg{_AxTyp}}}) =
Base.reduced_dims(A, findax(indices(A), region))
else
Base.reduced_indices{N}(inds::Base.Indices{N}, region::Union{_AxTyp, Tuple{_AxTyp, Vararg{_AxTyp}}}) =
Base.reduced_indices(inds, findax(inds, region))
Base.reduced_indices0{N}(inds::Base.Indices{N}, region::Union{_AxTyp, Tuple{_AxTyp, Vararg{_AxTyp}}}) =
Base.reduced_indices0(inds, findax(inds, region))
end
findax(inds, region) = _findax(1, inds, region)
findax(inds, region::Tuple) = map(x->findax(inds, x), region)
_findax(dim, inds::Tuple{IndexAxis, Vararg{Any}}, region) =
axisname(inds[1].axis) == axisname(region) ? dim : _findax(dim+1, tail(inds), region)
_findax(dim, inds::Tuple{Axis, Vararg{Any}}, region) =
axisname(inds[1]) == axisname(region) ? dim : _findax(dim+1, tail(inds), region)
_findax(dim, inds::Tuple{Any, Vararg{Any}}, region) =
_defaultdimname(dim) == axisname(region) ? dim : _findax(dim+1, tail(inds), region)
_findax(dim, ::Tuple{}, region) = throw(ArgumentError("Axis $region not found"))

# A simple display method to include axis information. It might be nice to
# eventually display the axis labels alongside the data array, but that is
Expand Down