diff --git a/src/FLoops.jl b/src/FLoops.jl index 4358b886c..3a12572c9 100644 --- a/src/FLoops.jl +++ b/src/FLoops.jl @@ -20,7 +20,7 @@ module FLoops doc end FLoops -export @floop, @init, @reduce, DistributedEx, SequentialEx, ThreadedEx +export @floop, @init, @combine, @reduce, DistributedEx, SequentialEx, ThreadedEx using BangBang.Extras: broadcast_inplace!! using BangBang: materialize!!, push!! diff --git a/src/reduce.jl b/src/reduce.jl index 807a7efe0..29ef514cb 100644 --- a/src/reduce.jl +++ b/src/reduce.jl @@ -86,13 +86,38 @@ macro reduce(args...) end # TODO: detect free variables in `do` blocks -struct ReduceOpSpec +abstract type OpSpec end + +struct ReduceOpSpec <: OpSpec args::Vector{Any} visible::Vector{Symbol} end ReduceOpSpec(args::Vector{Any}) = ReduceOpSpec(args, Symbol[]) +""" + @combine() do (acc₁ [= init₁]; x₁), ..., (accₙ [= initₙ]; xₙ) + ... + end + @combine(acc₁ ⊗₁= x₁, ..., accₙ ⊗ₙ= xₙ) + @combine(acc₁ .⊗₁= x₁, ..., accₙ .⊗ₙ= xₙ) + @combine(acc₁ = ⊗₁(other₁, x₁), ..., accₙ = ⊗ₙ(otherₙ, xₙ)) + @combine(acc₁ .= (⊗₁).(other₁, x₁), ..., accₙ .= (⊗ₙ).(otherₙ, xₙ)) + +Declare how accumulators from two basecases are combined. Unlike `@reduce`, the +reduction for the basecase is not defined by this macro. +""" +macro combine(args...) + :(throw($(CombineOpSpec(collect(Any, args))))) +end + +struct CombineOpSpec <: OpSpec + args::Vector{Any} + visible::Vector{Symbol} +end + +CombineOpSpec(args::Vector{Any}) = CombineOpSpec(args, Symbol[]) + """ @init begin pv₁ = init₁ @@ -148,7 +173,7 @@ end copyexpr(expr::Expr) = Expr(expr.head, (copyexpr(a) for a in expr.args)...) 
copyexpr(@nospecialize x) = x -function analyze_loop_local_variables!(spec::Union{ReduceOpSpec,InitSpec}, scopes) +function analyze_loop_local_variables!(spec::Union{OpSpec,InitSpec}, scopes) @assert isempty(spec.visible) append!(spec.visible, (var.name for sc in scopes for var in sc.bounds)) unique!(spec.visible) @@ -217,16 +242,18 @@ function unpack_kwargs(; otherwise = donothing, on_expr = otherwise, on_init = otherwise, + on_combine = otherwise, kwargs..., ) @assert isempty(kwargs) - return (otherwise, on_expr, on_init) + return (otherwise, on_expr, on_init, on_combine) end function on_reduce_op_spec(on_spec, ex; kwargs...) - (otherwise, on_expr, on_init) = unpack_kwargs(; kwargs...) + (otherwise, on_expr, on_init, on_combine) = unpack_kwargs(; kwargs...) @match ex begin Expr(:call, throw′, spec::ReduceOpSpec) => on_spec(spec) + Expr(:call, throw′, spec::CombineOpSpec) => on_combine(spec) Expr(:call, throw′, spec::InitSpec) => on_init(spec) Expr(head, args...) => begin new_args = map(args) do x @@ -238,8 +265,8 @@ function on_reduce_op_spec(on_spec, ex; kwargs...) end end -on_reduce_op_spec_reconstructing(on_spec, ex; otherwise = identity, on_init = otherwise) = - on_reduce_op_spec(on_spec, ex; on_expr = Expr, otherwise = otherwise, on_init = on_init) +on_reduce_op_spec_reconstructing(on_spec, ex; kwargs...) = + on_reduce_op_spec(on_spec, ex; otherwise = identity, on_expr = Expr, kwargs...) 
function floop_parallel(ctx::MacroContext, ex::Expr, simd, executor = nothing) if !isexpr(ex, :for, 2) @@ -439,7 +466,7 @@ end ``` """ function process_reduce_op_spec( - spec::ReduceOpSpec, + spec::OpSpec, )::NamedTuple{(:inits, :accs, :inputs, :pre_updates, :initializers, :updaters)} opspecs = spec.args::Vector{Any} @@ -662,14 +689,26 @@ function as_parallel_loop(ctx::MacroContext, rf_arg, coll, body0::Expr, simd, ex $ScratchSpace($init_allocator, ($(accs...),)) end + @gensym value_field + allocated_value = quote + # Just in case the value is deallocated: + $grouped_private_states = $allocate($grouped_private_states) + # Get the `.value` field but let the compiler know that field is allocated: + let $value_field = $grouped_private_states.value + if $value_field isa $Cleared + $unreachable_floop() + else + $value_field + end + end + end + if isempty(intersect(spec.visible, unbound_rhs(spec.expr))) # Hoisting out `@init`, since it is not accessing variables used # inside the loop body. push!(init_exprs, scratch) return quote - # Just in case it is demoted to `EmptyScratchSpace`: - $grouped_private_states = $allocate($grouped_private_states) - ($(accs...),) = $grouped_private_states.value + ($(accs...),) = $allocated_value end else push!(init_exprs, _FLoopInit()) @@ -679,16 +718,54 @@ function as_parallel_loop(ctx::MacroContext, rf_arg, coll, body0::Expr, simd, ex # Assigning to `grouped_private_states` for reusing it next time. 
$grouped_private_states = $scratch else - # Just in case it is demoted to `EmptyScratchSpace`: - $grouped_private_states = $allocate($grouped_private_states) # After the initialization, just carry it over to the next iteration: - ($(accs...),) = $grouped_private_states.value + ($(accs...),) = $allocated_value end end end end - body1 = on_reduce_op_spec_reconstructing(body0; on_init = on_init) do spec + function on_combine(spec::CombineOpSpec) + @gensym grouped_accs grouped_inputs + push!(accs_symbols, grouped_accs) + push!(inputs_symbols, grouped_inputs) + (_inits, accs, inputs, pre_updates, _initializers, updaters) = + process_reduce_op_spec(spec) + # TODO: check `init` + # TODO: check `initializers` + + push!(is_init, false) + push!(init_exprs, _FLoopInit()) + push!(all_rf_inits, nothing) + push!(all_rf_accs, accs) + push!(all_rf_inputs, inputs) + verify_unique_symbols(accs, "accumulator") + verify_unique_symbols(inputs, "input") + + combine_body = quote + if $grouped_inputs isa $_FLoopInit + else + if $grouped_accs isa $_FLoopInit + else + ($(inputs...),) = $grouped_inputs + ($(accs...),) = $grouped_accs + $(updaters...) + end + $grouped_accs = ($(accs...),) + end + end + push!(combine_bodies, combine_body) + return quote + $(pre_updates...) 
+ $grouped_accs = ($(inputs...),) + end + end + + body1 = on_reduce_op_spec_reconstructing( + body0; + on_init = on_init, + on_combine = on_combine, + ) do spec spec = spec::ReduceOpSpec @gensym grouped_accs grouped_inputs push!(accs_symbols, grouped_accs) @@ -858,6 +935,12 @@ function Base.showerror(io::IO, opspecs::ReduceOpSpec) print(io, ")` used outside `@floop`") end +function Base.showerror(io::IO, opspecs::CombineOpSpec) + print(io, "`@combine(") + join(io, opspecs.args, ", ") + print(io, ")` used outside `@floop`") +end + function Base.showerror(io::IO, spec::InitSpec) ex = spec.expr print(io, "`@init", ex, "` used outside `@floop`") diff --git a/test/FLoopsTests/src/test_combine.jl b/test/FLoopsTests/src/test_combine.jl new file mode 100644 index 000000000..9d1b8ac26 --- /dev/null +++ b/test/FLoopsTests/src/test_combine.jl @@ -0,0 +1,44 @@ +module TestCombine + +using FLoops +using MicroCollections +using Test + +function countmap_two_pass(indices, ex = nothing) + l, h = extrema(indices) + n = h - l + 1 + @floop ex for i in indices + @init b = zeros(Int, n) + b[i-l+1] += 1 + @combine h .+= b + end + return h +end + +function test_countmap_two_pass() + @testset "$(repr(ex))" for ex in [SequentialEx(), nothing, ThreadedEx(basesize = 1)] + @test countmap_two_pass(1:3, ex) == [1, 1, 1] + @test countmap_two_pass([1, 2, 4, 1], ex) == [2, 1, 0, 1] + end +end + +#= +using FillArrays +function countmap_one_pass(indices) + @floop for i in indices + @init l = nothing + @init b = [0] + if l === nothing + l = i + elseif i < l + splice!(b, 1:0, Zeros(l - i + 1)) + l = i + end + b[i - l + 1] += 1 + @combine() do (h; b), (l; l2) + end + end +end +=# + +end # module