diff --git a/docs/src/reference/api.md b/docs/src/reference/api.md
index 998db2f4d..6acbfc63d 100644
--- a/docs/src/reference/api.md
+++ b/docs/src/reference/api.md
@@ -24,6 +24,12 @@ FLoops.@combine
 FLoops.@init
 ```
 
+## `@completebasecase`
+
+```@docs
+FLoops.@completebasecase
+```
+
 ## [`SequentialEx`, `ThreadedEx` and `DistributedEx` executors](@id executor)
 
 An *executor* controls how a given `@floop` is executed. FLoops.jl re-exports
diff --git a/src/FLoops.jl b/src/FLoops.jl
index 845538afd..048326a39 100644
--- a/src/FLoops.jl
+++ b/src/FLoops.jl
@@ -20,7 +20,16 @@ module FLoops
     doc
 end FLoops
 
-export @floop, @init, @combine, @reduce, DistributedEx, SequentialEx, ThreadedEx
+#! format: off
+export @floop,
+    @init,
+    @combine,
+    @reduce,
+    @completebasecase,
+    DistributedEx,
+    SequentialEx,
+    ThreadedEx
+#! format: on
 
 using BangBang.Extras: broadcast_inplace!!
 using BangBang: materialize!!, push!!
diff --git a/src/combine.jl b/src/combine.jl
index e09830b3d..6f8a6b490 100644
--- a/src/combine.jl
+++ b/src/combine.jl
@@ -33,9 +33,104 @@ struct CombineOpSpec <: OpSpec
 end
 
 CombineOpSpec(args::Vector{Any}) = CombineOpSpec(args, Symbol[])
+macroname(::CombineOpSpec) = Symbol("@combine")
+
+# Without a macro like `@completebasecase`, it'd be confusing to have an
+# expression such as
+#
+#     @floop begin
+#         ...
+#         for x in xs
+#             ...  # executed in parallel loop body
+#         end
+#         for y in ys  # executed in completebasecase hook
+#             ...
+#         end
+#         ...
+#     end
+#
+# i.e., two similar loops have drastically different semantics. The difference
+# can be clarified by using the syntax:
+#
+#     @floop begin
+#         ...
+#         for x in xs
+#             ...  # executed in parallel loop body
+#         end
+#         @completebasecase begin
+#             for y in ys  # executed in completebasecase hook
+#                 ...
+#             end
+#         end
+#         ...
+#     end
+"""
+    @completebasecase ex
+
+Evaluate expression `ex` at the end of each basecase. The expression `ex` can
+only refer to the variables declared by `@init`.
+
+`@completebasecase` can be omitted if `ex` does not contain a `for` loop.
+
+# Examples
+```jldoctest
+julia> using FLoops
+
+julia> pidigits = string(BigFloat(π; precision = 2^20))[3:end];
+
+julia> @floop begin
+           @init hist = zeros(Int, 10)
+           for c in pidigits
+               i = c - '0' + 1
+               hist[i] += 1
+           end
+           @completebasecase begin
+               j = 0
+               y = 0
+               for (i, x) in pairs(hist)  # pretending we don't have `argmax`
+                   if x > y
+                       j = i
+                       y = x
+                   end
+               end
+               peaks = [j]
+               nchunks = [sum(hist)]
+           end
+           @combine hist .+= _
+           @combine peaks = append!(_, _)
+           @combine nchunks = append!(_, _)
+       end
+```
+"""
+macro completebasecase(ex)
+    ex = Expr(:block, __source__, ex)
+    :(throw($(CompleteBasecaseOp(ex))))
+end
+
+struct CompleteBasecaseOp
+    ex::Expr
+end
+
+function extract_spec(ex)
+    @match ex begin
+        Expr(:call, throw′, spec::ReduceOpSpec) => spec
+        Expr(:call, throw′, spec::CombineOpSpec) => spec
+        Expr(:call, throw′, spec::InitSpec) => spec
+        Expr(:call, throw′, spec::CompleteBasecaseOp) => spec
+        _ => nothing
+    end
+end
+
+isa_spec(::Type{T}) where {T} = x -> extract_spec(x) isa T
 
 function combine_parallel_loop(ctx::MacroContext, ex::Expr, simd, executor = nothing)
-    iterspec, body, ansvar, pre, post = destructure_loop_pre_post(ex)
+    iterspec, body, ansvar, pre, post = destructure_loop_pre_post(
+        ex;
+        multiple_loop_note = string(
+            " Wrap the expressions after the first loop (parallel loop) with",
+            " `@completebasecase`.",
+        ),
+    )
     @assert ansvar == :_
 
     parallel_loop_ex = @match iterspec begin
@@ -50,15 +145,6 @@ function combine_parallel_loop(ctx::MacroContext, ex::Expr, simd, executor = not
     return parallel_loop_ex
 end
 
-function extract_spec(ex)
-    @match ex begin
-        Expr(:call, throw′, spec::ReduceOpSpec) => spec
-        Expr(:call, throw′, spec::CombineOpSpec) => spec
-        Expr(:call, throw′, spec::InitSpec) => spec
-        _ => nothing
-    end
-end
-
 function as_parallel_combine_loop(
     ctx::MacroContext,
     pre::Vector,
@@ -70,6 +156,7 @@ function as_parallel_combine_loop(
     executor,
 )
     @assert simd in (false, true, :ivdep)
+    foreach(disallow_raw_for_loop_without_completebasecase, post)
 
     init_exprs = []
     all_rf_accs = []
@@ -89,11 +176,27 @@ function as_parallel_combine_loop(
     # `next` reducing step function:
     base_accs = mapcat(identity, all_rf_accs)
 
-    firstcombine = something(
-        findfirst(x -> extract_spec(x) isa CombineOpSpec, post),
-        lastindex(post) + 1,
-    )
+    firstcombine = something(findfirst(isa_spec(CombineOpSpec), post), lastindex(post) + 1)
+
     completebasecase_exprs = post[firstindex(post):firstcombine-1]
+    if any(isa_spec(CompleteBasecaseOp), completebasecase_exprs)
+        # If `CompleteBasecaseOp` is used, this must be the only expression:
+        let exprs = [x for x in completebasecase_exprs if !(x isa LineNumberNode)],
+            spec = extract_spec(exprs[1])
+
+            if spec isa CompleteBasecaseOp && length(exprs) == 1
+                completebasecase_exprs = Any[spec.ex]
+            elseif all(isa_spec(CompleteBasecaseOp), exprs)
+                error("Only one `@completebasecase` can be used. got:\n", join(exprs, "\n"))
+            else
+                error(
+                    "`@completebasecase` cannot be mixed with other expressions.",
+                    " Put everything in `@completebasecase begin ... end`. got:\n",
+                    join(exprs, "\n"),
+                )
+            end
+        end
+    end
 
     left_accs = []
     right_accs = []
@@ -104,7 +207,9 @@ function as_parallel_combine_loop(
         spec = extract_spec(ex)
         if !(spec isa CombineOpSpec)
             error(
-                "non-`@combine` expressions must be placed between `for` loop and the first `@combine` expression: ",
-                spec,
+                "non-`@combine` expressions must be placed between `for` loop and the",
+                " first `@combine` expression: ",
+                # `spec` is `nothing` for a plain expression; show `ex` itself then.
+                spec === nothing ? ex : spec,
             )
         end
@@ -279,3 +384,26 @@ function process_combine_op_spec(
     # TODO: use accurate line number from `@combine`
     return (; left = left, right = right, combine_body = combine_body)
 end
+
+# Check one element `ex` of `post` for a raw `for` loop.  Non-`Expr` values and
+# expressions recognized by `extract_spec` (e.g., `@completebasecase ...`) are
+# skipped; everything else is searched recursively.
+function disallow_raw_for_loop_without_completebasecase(@nospecialize(ex))
+    ex isa Expr || return
+    extract_spec(ex) === nothing || return
+    _disallow_raw_for_loop(ex)
+end
+
+# Recursively walk `ex` and throw an error at the first `for` expression found.
+function _disallow_raw_for_loop(@nospecialize(ex))
+    ex isa Expr || return
+    if isexpr(ex, :for)
+        error(
+            "`@floop begin ... end` can only contain one `for` loop.",
+            " Use `@completebasecase begin ... end` to wrap the code after the parallel",
+            " loop, including the `for` loop. Got:\n",
+            ex,
+        )
+    end
+    foreach(_disallow_raw_for_loop, ex.args)
+end
diff --git a/src/macro.jl b/src/macro.jl
index 118a8adc8..649d57fc9 100644
--- a/src/macro.jl
+++ b/src/macro.jl
@@ -69,7 +69,7 @@ end
 Goto{label}(acc::T) where {label,T} = Goto{label,T}(acc)
 gotoexpr(label::Symbol) = :($Goto{$(QuoteNode(label))})
 
-function destructure_loop_pre_post(ex)
+function destructure_loop_pre_post(ex; multiple_loop_note = "")
     pre = post = Union{}[]
     ansvar = :_
     if isexpr(ex, :for)
@@ -83,7 +83,13 @@ function destructure_loop_pre_post(ex)
         pre = args[1:i-1]
         post = args[i+1:end]
         if find_first_for_loop(post) !== nothing
-            throw(ArgumentError("Multiple top-level `for` loop found in:\n$ex"))
+            msg = string(
+                "Multiple top-level `for` loops found.",
+                multiple_loop_note,
+                " Given expression:\n",
+                ex,
+            )
+            throw(ArgumentError(msg))
         end
     else
         throw(ArgumentError("Unsupported expression:\n$ex"))
diff --git a/src/reduce.jl b/src/reduce.jl
index 68845cb9e..9afad270d 100644
--- a/src/reduce.jl
+++ b/src/reduce.jl
@@ -94,6 +94,7 @@ struct ReduceOpSpec <: OpSpec
 end
 
 ReduceOpSpec(args::Vector{Any}) = ReduceOpSpec(args, Symbol[])
+macroname(::ReduceOpSpec) = Symbol("@reduce")
 
 """
     @init begin
@@ -854,9 +855,6 @@ struct _FLoopInit end
         transduce(IdentityTransducer(), rf, DefaultInit, coll, maybe_set_simd(exc, simd)),
     )
 
-macroname(::ReduceOpSpec) = Symbol("@reduce")
-macroname(::CombineOpSpec) = Symbol("@combine")
-
 function Base.print(io::IO, spec::OpSpec)
     # TODO: print as `do` block
     print(io, macroname(spec), "(")
diff --git a/test/FLoopsTests/src/FLoopsTests.jl b/test/FLoopsTests/src/FLoopsTests.jl
index f9cfeea92..4b8af7663 100644
--- a/test/FLoopsTests/src/FLoopsTests.jl
+++ b/test/FLoopsTests/src/FLoopsTests.jl
@@ -2,6 +2,8 @@ module FLoopsTests
 
 using Test
 
+include("utils.jl")
+
 for file in
     sort([file for file in readdir(@__DIR__) if match(r"^test_.*\.jl$", file) !== nothing])
     include(file)
diff --git a/test/FLoopsTests/src/test_combine.jl b/test/FLoopsTests/src/test_combine.jl
index 75a284d3b..1dfa263b0 100644
--- a/test/FLoopsTests/src/test_combine.jl
+++ b/test/FLoopsTests/src/test_combine.jl
@@ -5,6 +5,8 @@ using MicroCollections
 using StaticArrays
 using Test
 
+using ..Utils: @macroexpand_error
+
 function count_ints_two_pass(indices, ex = nothing)
     l, h = extrema(indices)
     n = h - l + 1
@@ -18,10 +20,28 @@ function count_ints_two_pass(indices, ex = nothing)
     return hist
 end
 
+valueof(::Val{x}) where {x} = x
+
+function count_ints_two_pass2(indices, ex = nothing)
+    l, h = extrema(indices)
+    n = Val(h - l + 1)
+    @floop ex begin
+        @init hist = zero(MVector{valueof(n),Int32})
+        for i in indices
+            hist[i-l+1] += 1
+        end
+        @completebasecase hist = SVector(hist)
+        @combine hist .+= _
+    end
+    return hist
+end
+
 function test_count_ints_two_pass()
     @testset "$(repr(ex))" for ex in [SequentialEx(), nothing, ThreadedEx(basesize = 1)]
         @test count_ints_two_pass(1:3, ex) == [1, 1, 1]
         @test count_ints_two_pass([1, 2, 4, 1], ex) == [2, 1, 0, 1]
+        @test count_ints_two_pass2(1:3, ex) == [1, 1, 1]
+        @test count_ints_two_pass2([1, 2, 4, 1], ex) == [2, 1, 0, 1]
     end
 end
 
@@ -94,4 +114,59 @@ function test_count_positive_ints()
     end
 end
 
+function test_error_one_for_loop1()
+    err = @macroexpand_error @floop begin
+        @init a = nothing
+        for x in xs
+        end
+        for y in ys
+        end
+    end
+    @test err isa Exception
+    msg = sprint(showerror, err)
+    @test occursin("Wrap the expressions after the first loop", msg)
+end
+
+function test_error_one_for_loop2()
+    err = @macroexpand_error @floop begin
+        @init a = nothing
+        for x in xs
+        end
+        function f()
+            for y in ys
+            end
+        end
+    end
+    @test err isa Exception
+    msg = sprint(showerror, err)
+    @test occursin("can only contain one `for` loop", msg)
+end
+
+function test_error_mixing_plain_expr_and_completebasecase()
+    err = @macroexpand_error @floop begin
+        @init a = nothing
+        for x in xs
+        end
+        @completebasecase for y in ys
+        end
+        f(ys)
+    end
+    @test err isa Exception
+    msg = sprint(showerror, err)
+    @test occursin("cannot be mixed with other expressions", msg)
+end
+
+function test_error_two_completebasecase_macro_calls()
+    err = @macroexpand_error @floop begin
+        @init a = nothing
+        for x in xs
+        end
+        @completebasecase nothing
+        @completebasecase nothing
+    end
+    @test err isa Exception
+    msg = sprint(showerror, err)
+    @test occursin("Only one `@completebasecase` can be used", msg)
+end
+
 end  # module
diff --git a/test/FLoopsTests/src/utils.jl b/test/FLoopsTests/src/utils.jl
new file mode 100644
index 000000000..b55e8e672
--- /dev/null
+++ b/test/FLoopsTests/src/utils.jl
@@ -0,0 +1,20 @@
+module Utils
+
+# Sentinel returned by `@macroexpand_error` when macro expansion does not throw.
+struct NoError end
+
+# Run `@macroexpand` on `ex` at run time and evaluate to the exception caught
+# while doing so, or to `NoError()` if nothing is thrown.
+macro macroexpand_error(ex)
+    @gensym err
+    quote
+        try
+            $Base.@eval $Base.@macroexpand $ex
+            $NoError()
+        catch $err
+            $err
+        end
+    end |> esc
+end
+
+end  # module