Skip to content

Commit

Permalink
Add @combine for more manual reduction
Browse files Browse the repository at this point in the history
  • Loading branch information
tkf committed Jan 29, 2022
1 parent 047b179 commit f2f4f3c
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 15 deletions.
2 changes: 1 addition & 1 deletion src/FLoops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ module FLoops
doc
end FLoops

export @floop, @init, @reduce, DistributedEx, SequentialEx, ThreadedEx
export @floop, @init, @combine, @reduce, DistributedEx, SequentialEx, ThreadedEx

using BangBang.Extras: broadcast_inplace!!
using BangBang: materialize!!, push!!
Expand Down
111 changes: 97 additions & 14 deletions src/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,38 @@ macro reduce(args...)
end
# TODO: detect free variables in `do` blocks

struct ReduceOpSpec
abstract type OpSpec end

struct ReduceOpSpec <: OpSpec
args::Vector{Any}
visible::Vector{Symbol}
end

ReduceOpSpec(args::Vector{Any}) = ReduceOpSpec(args, Symbol[])

"""
@combine() do (acc₁ [= init₁]; x₁), ..., (accₙ [= initₙ]; xₙ)
...
end
@combine(acc₁ ⊗₁= x₁, ..., accₙ ⊗ₙ= xₙ)
@combine(acc₁ .⊗₁= x₁, ..., accₙ .⊗ₙ= xₙ)
@combine(acc₁ = ⊗₁(other₁, x₁), ..., accₙ = ⊗ₙ(otherₙ, xₙ))
@combine(acc₁ .= (⊗₁).(other₁, x₁), ..., accₙ = (⊗ₙ).(otherₙ, xₙ))
Declare how accumulators from two basecases are combined. Unlike `@reduce`, the
reduction for the basecase is not defined by this macro.
"""
macro combine(args...)
:(throw($(CombineOpSpec(collect(Any, args)))))
end

struct CombineOpSpec <: OpSpec
args::Vector{Any}
visible::Vector{Symbol}
end

CombineOpSpec(args::Vector{Any}) = CombineOpSpec(args, Symbol[])

"""
@init begin
pv₁ = init₁
Expand Down Expand Up @@ -148,7 +173,7 @@ end
copyexpr(expr::Expr) = Expr(expr.head, (copyexpr(a) for a in expr.args)...)
copyexpr(@nospecialize x) = x

function analyze_loop_local_variables!(spec::Union{ReduceOpSpec,InitSpec}, scopes)
function analyze_loop_local_variables!(spec::Union{OpSpec,InitSpec}, scopes)
@assert isempty(spec.visible)
append!(spec.visible, (var.name for sc in scopes for var in sc.bounds))
unique!(spec.visible)
Expand Down Expand Up @@ -217,16 +242,18 @@ function unpack_kwargs(;
otherwise = donothing,
on_expr = otherwise,
on_init = otherwise,
on_combine = otherwise,
kwargs...,
)
@assert isempty(kwargs)
return (otherwise, on_expr, on_init)
return (otherwise, on_expr, on_init, on_combine)
end

function on_reduce_op_spec(on_spec, ex; kwargs...)
(otherwise, on_expr, on_init) = unpack_kwargs(; kwargs...)
(otherwise, on_expr, on_init, on_combine) = unpack_kwargs(; kwargs...)
@match ex begin
Expr(:call, throw′, spec::ReduceOpSpec) => on_spec(spec)
Expr(:call, throw′, spec::CombineOpSpec) => on_combine(spec)
Expr(:call, throw′, spec::InitSpec) => on_init(spec)
Expr(head, args...) => begin
new_args = map(args) do x
Expand All @@ -238,8 +265,8 @@ function on_reduce_op_spec(on_spec, ex; kwargs...)
end
end

on_reduce_op_spec_reconstructing(on_spec, ex; otherwise = identity, on_init = otherwise) =
on_reduce_op_spec(on_spec, ex; on_expr = Expr, otherwise = otherwise, on_init = on_init)
on_reduce_op_spec_reconstructing(on_spec, ex; kwargs...) =
on_reduce_op_spec(on_spec, ex; otherwise = identity, on_expr = Expr, kwargs...)

function floop_parallel(ctx::MacroContext, ex::Expr, simd, executor = nothing)
if !isexpr(ex, :for, 2)
Expand Down Expand Up @@ -439,7 +466,7 @@ end
```
"""
function process_reduce_op_spec(
spec::ReduceOpSpec,
spec::OpSpec,
)::NamedTuple{(:inits, :accs, :inputs, :pre_updates, :initializers, :updaters)}
opspecs = spec.args::Vector{Any}

Expand Down Expand Up @@ -662,14 +689,26 @@ function as_parallel_loop(ctx::MacroContext, rf_arg, coll, body0::Expr, simd, ex
$ScratchSpace($init_allocator, ($(accs...),))
end

@gensym value_field
allocated_value = quote
# Just in case the value is deallocated:
$grouped_private_states = $allocate($grouped_private_states)
# Get the `.value` field but let the compiler know that field is allocated:
let $value_field = $grouped_private_states.value
if $value_field isa $Cleared
$unreachable_floop()
else
$value_field
end
end
end

if isempty(intersect(spec.visible, unbound_rhs(spec.expr)))
# Hoisting out `@init`, since it is not accessing variables used
# inside the loop body.
push!(init_exprs, scratch)
return quote
# Just in case it is demoted to `EmptyScratchSpace`:
$grouped_private_states = $allocate($grouped_private_states)
($(accs...),) = $grouped_private_states.value
($(accs...),) = $allocated_value
end
else
push!(init_exprs, _FLoopInit())
Expand All @@ -679,16 +718,54 @@ function as_parallel_loop(ctx::MacroContext, rf_arg, coll, body0::Expr, simd, ex
# Assigning to `grouped_private_states` for reusing it next time.
$grouped_private_states = $scratch
else
# Just in case it is demoted to `EmptyScratchSpace`:
$grouped_private_states = $allocate($grouped_private_states)
# After the initialization, just carry it over to the next iteration:
($(accs...),) = $grouped_private_states.value
($(accs...),) = $allocated_value
end
end
end
end

body1 = on_reduce_op_spec_reconstructing(body0; on_init = on_init) do spec
function on_combine(spec::CombineOpSpec)
@gensym grouped_accs grouped_inputs
push!(accs_symbols, grouped_accs)
push!(inputs_symbols, grouped_inputs)
(_inits, accs, inputs, pre_updates, _initializers, updaters) =
process_reduce_op_spec(spec)
# TODO: check `init`
# TODO: check `initializers`

push!(is_init, false)
push!(init_exprs, _FLoopInit())
push!(all_rf_inits, nothing)
push!(all_rf_accs, accs)
push!(all_rf_inputs, inputs)
verify_unique_symbols(accs, "accumulator")
verify_unique_symbols(inputs, "input")

combine_body = quote
if $grouped_inputs isa $_FLoopInit
else
if $grouped_accs isa $_FLoopInit
else
($(inputs...),) = $grouped_inputs
($(accs...),) = $grouped_accs
$(updaters...)
end
$grouped_accs = ($(accs...),)
end
end
push!(combine_bodies, combine_body)
return quote
$(pre_updates...)
$grouped_accs = ($(inputs...),)
end
end

body1 = on_reduce_op_spec_reconstructing(
body0;
on_init = on_init,
on_combine = on_combine,
) do spec
spec = spec::ReduceOpSpec
@gensym grouped_accs grouped_inputs
push!(accs_symbols, grouped_accs)
Expand Down Expand Up @@ -858,6 +935,12 @@ function Base.showerror(io::IO, opspecs::ReduceOpSpec)
print(io, ")` used outside `@floop`")
end

function Base.showerror(io::IO, opspecs::CombineOpSpec)
print(io, "`@combine(")
join(io, opspecs.args, ", ")
print(io, ")` used outside `@floop`")
end

function Base.showerror(io::IO, spec::InitSpec)
ex = spec.expr
print(io, "`@init", ex, "` used outside `@floop`")
Expand Down
44 changes: 44 additions & 0 deletions test/FLoopsTests/src/test_combine.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
module TestCombine

using FLoops
using MicroCollections
using Test

function countmap_two_pass(indices, ex = nothing)
l, h = extrema(indices)
n = h - l + 1
@floop ex for i in indices
@init b = zeros(Int, n)
b[i-l+1] += 1
@combine h .+= b
end
return h
end

function test_countmap_two_pass()
@testset "$(repr(ex))" for ex in [SequentialEx(), nothing, ThreadedEx(basesize = 1)]
@test countmap_two_pass(1:3, ex) == [1, 1, 1]
@test countmap_two_pass([1, 2, 4, 1], ex) == [2, 1, 0, 1]
end
end

#=
using FillArrays
function countmap_one_pass(indices)
@floop for i in indices
@init l = nothing
@init b = [0]
if l === nothing
l = i
elseif i < l
splice!(b, 1:0, Zeros(l - i + 1))
l = i
end
b[i - l + 1] += 1
@combine() do (h; b), (l; l2)
end
end
end
=#

end # module

0 comments on commit f2f4f3c

Please sign in to comment.