Skip to content

Commit

Permalink
Add @completebasecase macro
Browse files Browse the repository at this point in the history
  • Loading branch information
tkf committed Feb 3, 2022
1 parent bcaf1ad commit a163ae9
Show file tree
Hide file tree
Showing 8 changed files with 257 additions and 21 deletions.
6 changes: 6 additions & 0 deletions docs/src/reference/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ FLoops.@combine
FLoops.@init
```

## `@completebasecase`

```@docs
FLoops.@completebasecase
```

## [`SequentialEx`, `ThreadedEx` and `DistributedEx` executors](@id executor)

An *executor* controls how a given `@floop` is executed. FLoops.jl re-exports
Expand Down
11 changes: 10 additions & 1 deletion src/FLoops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,16 @@ module FLoops
doc
end FLoops

export @floop, @init, @combine, @reduce, DistributedEx, SequentialEx, ThreadedEx
#! format: off
export @floop,
@init,
@combine,
@reduce,
@completebasecase,
DistributedEx,
SequentialEx,
ThreadedEx
#! format: on

using BangBang.Extras: broadcast_inplace!!
using BangBang: materialize!!, push!!
Expand Down
153 changes: 138 additions & 15 deletions src/combine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,104 @@ struct CombineOpSpec <: OpSpec
end

CombineOpSpec(args::Vector{Any}) = CombineOpSpec(args, Symbol[])
macroname(::CombineOpSpec) = Symbol("@combine")

# Without a macro like `@completebasecase`, it'd be confusing to have an
# expression such as
#
# @floop begin
# ...
# for x in xs
# ... # executed in parallel loop body
# end
# for y in ys # executed in completebasecase hook
# ...
# end
# ...
# end
#
# i.e., two similar loops have drastically different semantics. The difference
# can be clarified by using the syntax:
#
# @floop begin
# ...
# for x in xs
# ... # executed in parallel loop body
# end
# @completebasecase begin
# for y in ys # executed in completebasecase hook
# ...
# end
# end
# ...
# end
"""
@completebasecase ex
Evaluate expression `ex` at the end of each basecase. The expression `ex` can
only refer to the variables declared by `@init`.
`@completebasecase` can be omitted if `ex` does not contain a `for` loop.
# Examples
```jldoctest
julia> using FLoops
julia> pidigits = string(BigFloat(π; precision = 2^20))[3:end];
julia> @floop begin
@init hist = zeros(Int, 10)
for c in pidigits
i = c - '0' + 1
hist[i] += 1
end
@completebasecase begin
j = 0
y = 0
for (i, x) in pairs(hist) # pretending we don't have `argmax`
if x > y
j = i
y = x
end
end
peaks = [j]
nchunks = [sum(hist)]
end
@combine hist .+= _
@combine peaks = append!(_, _)
@combine nchunks = append!(_, _)
end
```
"""
macro completebasecase(ex)
ex = Expr(:block, __source__, ex)
:(throw($(CompleteBasecaseOp(ex))))
end

struct CompleteBasecaseOp
ex::Expr
end

function extract_spec(ex)
@match ex begin
Expr(:call, throw′, spec::ReduceOpSpec) => spec
Expr(:call, throw′, spec::CombineOpSpec) => spec
Expr(:call, throw′, spec::InitSpec) => spec
Expr(:call, throw′, spec::CompleteBasecaseOp) => spec
_ => nothing
end
end

isa_spec(::Type{T}) where {T} = x -> extract_spec(x) isa T

function combine_parallel_loop(ctx::MacroContext, ex::Expr, simd, executor = nothing)
iterspec, body, ansvar, pre, post = destructure_loop_pre_post(ex)
iterspec, body, ansvar, pre, post = destructure_loop_pre_post(
ex;
multiple_loop_note = string(
" Wrap the expressions after the first loop (parallel loop) with",
" `@completebasecase`.",
),
)
@assert ansvar == :_

parallel_loop_ex = @match iterspec begin
Expand All @@ -50,15 +145,6 @@ function combine_parallel_loop(ctx::MacroContext, ex::Expr, simd, executor = not
return parallel_loop_ex
end

function extract_spec(ex)
@match ex begin
Expr(:call, throw′, spec::ReduceOpSpec) => spec
Expr(:call, throw′, spec::CombineOpSpec) => spec
Expr(:call, throw′, spec::InitSpec) => spec
_ => nothing
end
end

function as_parallel_combine_loop(
ctx::MacroContext,
pre::Vector,
Expand All @@ -70,6 +156,7 @@ function as_parallel_combine_loop(
executor,
)
@assert simd in (false, true, :ivdep)
foreach(disalow_raw_for_loop_without_completebasecase, post)

init_exprs = []
all_rf_accs = []
Expand All @@ -89,11 +176,27 @@ function as_parallel_combine_loop(
# `next` reducing step function:
base_accs = mapcat(identity, all_rf_accs)

firstcombine = something(
findfirst(x -> extract_spec(x) isa CombineOpSpec, post),
lastindex(post) + 1,
)
firstcombine = something(findfirst(isa_spec(CombineOpSpec), post), lastindex(post) + 1)

completebasecase_exprs = post[firstindex(post):firstcombine-1]
if any(isa_spec(CompleteBasecaseOp), completebasecase_exprs)
# If `CompleteBasecaseOp` is used, this must be the only expression:
let exprs = [x for x in completebasecase_exprs if !(x isa LineNumberNode)],
spec = extract_spec(exprs[1])

if spec isa CompleteBasecaseOp && length(exprs) == 1
completebasecase_exprs = Any[spec.ex]
elseif all(isa_spec(CompleteBasecaseOp), exprs)
error("Only one `@completebasecase` can be used. got:\n", join(exprs, "\n"))
else
error(
"`@completebasecase` cannot be mixed with other expressions.",
" Put everything in `@completebasecase begin ... end`. got:\n",
join(exprs, "\n"),
)
end
end
end

left_accs = []
right_accs = []
Expand All @@ -104,7 +207,8 @@ function as_parallel_combine_loop(
spec = extract_spec(ex)
if !(spec isa CombineOpSpec)
error(
"non-`@combine` expressions must be placed between `for` loop and the first `@combine` expression: ",
"non-`@combine` expressions must be placed between `for` loop and the",
" first `@combine` expression: ",
spec,
)
end
Expand Down Expand Up @@ -279,3 +383,22 @@ function process_combine_op_spec(
# TODO: use accurate line number from `@combine`
return (; left = left, right = right, combine_body = combine_body)
end

function disalow_raw_for_loop_without_completebasecase(@nospecialize(ex))
ex isa Expr || return
extract_spec(ex) === nothing || return
_disalow_raw_for_loop(ex)
end

function _disalow_raw_for_loop(@nospecialize(ex))
ex isa Expr || return
if isexpr(ex, :for)
error(
"`@floop begin ... end` can only contain one `for` loop.",
" Use `@completebasecase begin ... end` to wrap the code after the parallel",
" loop, including the `for` loop. Got:\n",
ex,
)
end
foreach(_disalow_raw_for_loop, ex.args)
end
10 changes: 8 additions & 2 deletions src/macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ end
Goto{label}(acc::T) where {label,T} = Goto{label,T}(acc)
gotoexpr(label::Symbol) = :($Goto{$(QuoteNode(label))})

function destructure_loop_pre_post(ex)
function destructure_loop_pre_post(ex; multiple_loop_note = "")
pre = post = Union{}[]
ansvar = :_
if isexpr(ex, :for)
Expand All @@ -83,7 +83,13 @@ function destructure_loop_pre_post(ex)
pre = args[1:i-1]
post = args[i+1:end]
if find_first_for_loop(post) !== nothing
throw(ArgumentError("Multiple top-level `for` loop found in:\n$ex"))
msg = string(
"Multiple top-level `for` loops found.",
multiple_loop_note,
" Given expression:\n",
ex,
)
throw(ArgumentError(msg))
end
else
throw(ArgumentError("Unsupported expression:\n$ex"))
Expand Down
4 changes: 1 addition & 3 deletions src/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ struct ReduceOpSpec <: OpSpec
end

ReduceOpSpec(args::Vector{Any}) = ReduceOpSpec(args, Symbol[])
macroname(::ReduceOpSpec) = Symbol("@reduce")

"""
@init begin
Expand Down Expand Up @@ -854,9 +855,6 @@ struct _FLoopInit end
transduce(IdentityTransducer(), rf, DefaultInit, coll, maybe_set_simd(exc, simd)),
)

macroname(::ReduceOpSpec) = Symbol("@reduce")
macroname(::CombineOpSpec) = Symbol("@combine")

function Base.print(io::IO, spec::OpSpec)
# TODO: print as `do` block
print(io, macroname(spec), "(")
Expand Down
2 changes: 2 additions & 0 deletions test/FLoopsTests/src/FLoopsTests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ module FLoopsTests

using Test

include("utils.jl")

for file in
sort([file for file in readdir(@__DIR__) if match(r"^test_.*\.jl$", file) !== nothing])
include(file)
Expand Down
75 changes: 75 additions & 0 deletions test/FLoopsTests/src/test_combine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ using MicroCollections
using StaticArrays
using Test

using ..Utils: @macroexpand_error

function count_ints_two_pass(indices, ex = nothing)
l, h = extrema(indices)
n = h - l + 1
Expand All @@ -18,10 +20,28 @@ function count_ints_two_pass(indices, ex = nothing)
return hist
end

valueof(::Val{x}) where {x} = x

function count_ints_two_pass2(indices, ex = nothing)
l, h = extrema(indices)
n = Val(h - l + 1)
@floop ex begin
@init hist = zero(MVector{valueof(n),Int32})
for i in indices
hist[i-l+1] += 1
end
@completebasecase hist = SVector(hist)
@combine hist .+= _
end
return hist
end

function test_count_ints_two_pass()
@testset "$(repr(ex))" for ex in [SequentialEx(), nothing, ThreadedEx(basesize = 1)]
@test count_ints_two_pass(1:3, ex) == [1, 1, 1]
@test count_ints_two_pass([1, 2, 4, 1], ex) == [2, 1, 0, 1]
@test count_ints_two_pass2(1:3, ex) == [1, 1, 1]
@test count_ints_two_pass2([1, 2, 4, 1], ex) == [2, 1, 0, 1]
end
end

Expand Down Expand Up @@ -94,4 +114,59 @@ function test_count_positive_ints()
end
end

function test_error_one_for_loop1()
err = @macroexpand_error @floop begin
@init a = nothing
for x in xs
end
for y in ys
end
end
@test err isa Exception
msg = sprint(showerror, err)
@test occursin("Wrap the expressions after the first loop", msg)
end

function test_error_one_for_loop2()
err = @macroexpand_error @floop begin
@init a = nothing
for x in xs
end
function f()
for y in ys
end
end
end
@test err isa Exception
msg = sprint(showerror, err)
@test occursin("can only contain one `for` loop", msg)
end

function test_error_mixing_plain_expr_and_completebasecase()
err = @macroexpand_error @floop begin
@init a = nothing
for x in xs
end
@completebasecase for y in ys
end
f(ys)
end
@test err isa Exception
msg = sprint(showerror, err)
@test occursin("cannot be mixed with other expressions", msg)
end

function test_error_two_completebasecase_macro_calls()
err = @macroexpand_error @floop begin
@init a = nothing
for x in xs
end
@completebasecase nothing
@completebasecase nothing
end
@test err isa Exception
msg = sprint(showerror, err)
@test occursin("Only one `@completebasecase` can be used", msg)
end

end # module
17 changes: 17 additions & 0 deletions test/FLoopsTests/src/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
module Utils

struct NoError end

macro macroexpand_error(ex)
@gensym err
quote
try
$Base.@eval $Base.@macroexpand $ex
$NoError()
catch $err
$err
end
end |> esc
end

end # module

0 comments on commit a163ae9

Please sign in to comment.