Skip to content

map and mapreduce #146

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jul 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/FillArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import Base: size, getindex, setindex!, IndexStyle, checkbounds, convert,
+, -, *, /, \, diff, sum, cumsum, maximum, minimum, sort, sort!,
any, all, axes, isone, iterate, unique, allunique, permutedims, inv,
copy, vec, setindex!, count, ==, reshape, _throw_dmrs, map, zero,
show, view, in
show, view, in, mapreduce

import LinearAlgebra: rank, svdvals!, tril, triu, tril!, triu!, diag, transpose, adjoint, fill!,
dot, norm2, norm1, normInf, normMinusInf, normp, lmul!, rmul!, diagzero, AbstractTriangular, AdjointAbsVec
Expand Down
80 changes: 79 additions & 1 deletion src/fillbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,84 @@

map(f::Function, r::AbstractFill) = Fill(f(getindex_value(r)), axes(r))

function map(f::Function, vs::AbstractFill{<:Any,1}...)
stop = mapreduce(length, min, vs)
val = f(map(getindex_value, vs)...)
Fill(val, stop)
end

function map(f::Function, rs::AbstractFill...)
if _maplinear(rs...)
map(f, map(vec, rs)...)
else
val = f(map(getindex_value, rs)...)
Fill(val, axes(first(rs)))
end
end

function _maplinear(rs...) # tries to match Base's behaviour, could perhaps hook in more deeply
if any(ndims(r)==1 for r in rs)
return true
else
r1 = axes(first(rs))
for r in rs
axes(r) == r1 || throw(DimensionMismatch(
"dimensions must match: a has dims $r1, b has dims $(axes(r))"))
end
return false
end
end

### mapreduce

if VERSION >= v"1.4"
# _InitialValue was introduced after 1.0, before 1.4, not sure exact version.
# Without these methods, some reductions will give an Array not a Fill.

function Base._mapreduce_dim(f, op, ::Base._InitialValue, A::AbstractFill, ::Colon)
fval = f(getindex_value(A))
out = fval
for _ in 2:length(A)
out = op(out, fval)
end
out
end

function Base._mapreduce_dim(f, op, ::Base._InitialValue, A::AbstractFill, dims)
fval = f(getindex_value(A))
red = *(ntuple(d -> d in dims ? size(A,d) : 1, ndims(A))...)
out = fval
for _ in 2:red
out = op(out, fval)
end
Fill(out, ntuple(d -> d in dims ? Base.OneTo(1) : axes(A,d), ndims(A)))
end

end
if VERSION >= v"1.2" # Vararg mapreduce was added in Julia 1.2

function mapreduce(f, op, A::AbstractFill, B::AbstractFill; kw...)
val(_...) = f(getindex_value(A), getindex_value(B))
reduce(op, map(val, A, B); kw...)
end

# These are particularly useful because mapreduce(*, +, A, B; dims) is slow in Base,
# but can be re-written as some mapreduce(g, +, C; dims) which is fast.

function mapreduce(f, op, A::AbstractFill, B::AbstractArray, Cs::AbstractArray...; kw...)
g(b, cs...) = f(getindex_value(A), b, cs...)
mapreduce(g, op, B, Cs...; kw...)
end
function mapreduce(f, op, A::AbstractArray, B::AbstractFill, Cs::AbstractArray...; kw...)
h(a, cs...) = f(a, getindex_value(B), cs...)
mapreduce(h, op, A, Cs...; kw...)
end
function mapreduce(f, op, A::AbstractFill, B::AbstractFill, Cs::AbstractArray...; kw...)
gh(cs...) = f(getindex_value(A), getindex_value(B), cs...)
mapreduce(gh, op, Cs...; kw...)
end

end

### Unary broadcasting

Expand Down Expand Up @@ -165,4 +243,4 @@ end
broadcasted(::DefaultArrayStyle{N}, op, r::AbstractFill{T,N}, x::Number) where {T,N} = Fill(op(getindex_value(r),x), axes(r))
broadcasted(::DefaultArrayStyle{N}, op, x::Number, r::AbstractFill{T,N}) where {T,N} = Fill(op(x, getindex_value(r)), axes(r))
broadcasted(::DefaultArrayStyle{N}, op, r::AbstractFill{T,N}, x::Ref) where {T,N} = Fill(op(getindex_value(r),x[]), axes(r))
broadcasted(::DefaultArrayStyle{N}, op, x::Ref, r::AbstractFill{T,N}) where {T,N} = Fill(op(x[], getindex_value(r)), axes(r))
broadcasted(::DefaultArrayStyle{N}, op, x::Ref, r::AbstractFill{T,N}) where {T,N} = Fill(op(x[], getindex_value(r)), axes(r))
63 changes: 56 additions & 7 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,7 @@ end

@testset "Cumsum and diff" begin
@test sum(Fill(3,10)) ≡ 30
@test reduce(+, Fill(3,10)) ≡ 30
@test sum(x -> x + 1, Fill(3,10)) ≡ 40
@test cumsum(Fill(3,10)) ≡ 3:3:30

Expand Down Expand Up @@ -758,15 +759,63 @@ end
end

@testset "map" begin
x = Ones(5)
@test map(exp,x) === Fill(exp(1.0),5)
@test map(isone,x) === Fill(true,5)
x1 = Ones(5)
@test map(exp,x1) === Fill(exp(1.0),5)
@test map(isone,x1) === Fill(true,5)

x = Zeros(5)
@test map(exp,x) === exp.(x)
x0 = Zeros(5)
@test map(exp,x0) === exp.(x0)

x = Fill(2,5,3)
@test map(exp,x) === Fill(exp(2),5,3)
x2 = Fill(2,5,3)
@test map(exp,x2) === Fill(exp(2),5,3)

@test map(+, x1, x2) === Fill(3.0, 5)
@test map(+, x2, x2) === x2 .+ x2
@test_throws DimensionMismatch map(+, x2', x2)
end

@testset "mapreduce" begin
x = rand(3, 4)
y = fill(1.0, 3, 4)
Y = Fill(1.0, 3, 4)
O = Ones(3, 4)

@test mapreduce(exp, +, Y) == mapreduce(exp, +, y)
@test mapreduce(exp, +, Y; dims=2) == mapreduce(exp, +, y; dims=2)
@test mapreduce(identity, +, Y) == sum(y) == sum(Y)
@test mapreduce(identity, +, Y, dims=1) == sum(y, dims=1) == sum(Y, dims=1)

if VERSION >= v"1.4"
@test mapreduce(exp, +, Y; dims=(1,), init=5.0) == mapreduce(exp, +, y; dims=(1,), init=5.0)
end

if VERSION >= v"1.2" # Vararg mapreduce was added in Julia 1.2

# Two arrays
@test mapreduce(*, +, x, Y) == mapreduce(*, +, x, y)
@test mapreduce(*, +, Y, x) == mapreduce(*, +, y, x)
@test mapreduce(*, +, x, O) == mapreduce(*, +, x, y)
@test mapreduce(*, +, Y, O) == mapreduce(*, +, y, y)

f2(x,y) = 1 + x/y
op2(x,y) = x^2 + 3y
@test mapreduce(f2, op2, x, Y) == mapreduce(f2, op2, x, y)

if VERSION >= v"1.4"
@test mapreduce(f2, op2, x, Y, dims=1, init=5.0) == mapreduce(f2, op2, x, y, dims=1, init=5.0)
@test mapreduce(f2, op2, Y, x, dims=1, init=5.0) == mapreduce(f2, op2, y, x, dims=1, init=5.0)
@test mapreduce(f2, op2, x, O, dims=1, init=5.0) == mapreduce(f2, op2, x, y, dims=1, init=5.0)
@test mapreduce(f2, op2, Y, O, dims=1, init=5.0) == mapreduce(f2, op2, y, y, dims=1, init=5.0)
end

# More than two
@test mapreduce(+, +, x, Y, x) == mapreduce(+, +, x, y, x)
@test mapreduce(+, +, Y, x, x) == mapreduce(+, +, y, x, x)
@test mapreduce(+, +, x, O, Y) == mapreduce(+, +, x, y, y)
@test mapreduce(+, +, Y, O, Y) == mapreduce(+, +, y, y, y)
@test mapreduce(+, +, Y, O, Y, x) == mapreduce(+, +, y, y, y, x)

end
end

@testset "Offset indexing" begin
Expand Down