
RFC/WIP: "for-loop" compliant @parallel for [ci skip] #20094


Closed · wants to merge 3 commits
2 changes: 2 additions & 0 deletions base/exports.jl
@@ -80,6 +80,7 @@ export
    ObjectIdDict,
    OrdinalRange,
    Pair,
+    ParallelAccumulator,
    PartialQuickSort,
    PollingFileWatcher,
    QuickSort,
@@ -1359,6 +1360,7 @@ export
    @threadcall,

# multiprocessing
+    @accumulate,
Review comment (Contributor): no longer needed?

Reply (Author): Not required. A final cleanup, docs, and tests are pending.

    @spawn,
    @spawnat,
    @fetch,
168 changes: 150 additions & 18 deletions base/multi.jl
@@ -2037,42 +2037,34 @@ end
eval_ew_expr(ex) = (eval(Main, ex); nothing)

# Statically split range [1,N] into equal sized chunks for np processors
-function splitrange(N::Int, np::Int)
+function splitrange(N::Int, wlist::Array)
+    np = length(wlist)
    each = div(N,np)
    extras = rem(N,np)
    nchunks = each > 0 ? np : extras
-    chunks = Array{UnitRange{Int}}(nchunks)
+    chunks = Dict{Int, UnitRange{Int}}()
    lo = 1
    for i in 1:nchunks
        hi = lo + each - 1
        if extras > 0
            hi += 1
            extras -= 1
        end
-        chunks[i] = lo:hi
+        chunks[wlist[i]] = lo:hi
        lo = hi+1
    end
    return chunks
end
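For illustration (pids and range assumed, not part of the diff): with the new pid-keyed return, a 10-element range over a worker pool [2, 3, 4] splits as

    # each = div(10, 3) = 3, extras = rem(10, 3) = 1; the first chunk takes the extra.
    splitrange(10, [2, 3, 4])
    # => Dict(2 => 1:4, 3 => 5:7, 4 => 8:10)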

function preduce(reducer, f, R)
-    N = length(R)
-    chunks = splitrange(N, nworkers())
-    all_w = workers()[1:length(chunks)]
-
    w_exec = Task[]
-    for (idx,pid) in enumerate(all_w)
-        t = Task(()->remotecall_fetch(f, pid, reducer, R, first(chunks[idx]), last(chunks[idx])))
-        schedule(t)
+    for (pid, r) in splitrange(length(R), workers())
+        t = @schedule remotecall_fetch(f, pid, reducer, R, first(r), last(r))
        push!(w_exec, t)
    end
    reduce(reducer, [wait(t) for t in w_exec])
end
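A sketch of what preduce now does per worker, assuming a two-worker pool with pids 2 and 3 and R = 1:10:

    # splitrange(10, [2, 3]) => Dict(2 => 1:5, 3 => 6:10). One task per pid runs
    # remotecall_fetch(f, pid, reducer, 1:10, lo, hi); the partial results are
    # then folded locally: reduce(reducer, [partial_from_2, partial_from_3]).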

-function pfor(f, R)
-    [@spawn f(R, first(c), last(c)) for c in splitrange(length(R), nworkers())]
-end

function make_preduce_body(var, body)
    quote
        function (reducer, R, lo::Int, hi::Int)
@@ -2088,6 +2080,44 @@ function make_preduce_body(var, body)
    end
end

function pfor(f, R)
    lenR = length(R)
    chunks = splitrange(lenR, workers())

    # identify all accumulators
    accs = ParallelAccumulator[]

    # locals closed over
    for i in 1:nfields(f)
        v = getfield(f, i)
        isa(v, ParallelAccumulator) && push!(accs, v)
    end

    # globals referenced
    for x in code_lowered(f, (UnitRange, Int, Int))[1].code
        isa(x, Expr) && search_glb_accs(x, accs)
    end

    for acc in accs
        lenR != acc.length && throw(AssertionError("loop length must equal ParallelAccumulator length"))
        set_f_len_at_pid!(acc, p->length(chunks[p]))
    end

    [remotecall(f, p, R, first(c), last(c)) for (p,c) in chunks]
end
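The two scans above cover the two ways a loop body can reach an accumulator; a sketch under assumed names (acc is a local, gacc would be a global in Main):

    # Local: `acc` is captured as a field of the generated loop-body closure,
    # so the nfields/getfield scan finds it.
    acc = ParallelAccumulator{Int}(+, 100)
    @parallel for i in 1:100
        push!(acc, i)
    end

    # Global: a reference to `gacc` defined in Main appears as a GlobalRef in
    # the lowered code and is found by search_glb_accs, defined next.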

function search_glb_accs(ex::Expr, accs)
    for x in ex.args
        if isa(x, GlobalRef)
            if x.mod == Main && isdefined(Main, x.name)
                v = getfield(Main, x.name)
                isa(v, ParallelAccumulator) && push!(accs, v)
            end
        end
        isa(x, Expr) && search_glb_accs(x, accs)
    end
end

function make_pfor_body(var, body)
    quote
        function (R, lo::Int, hi::Int)
@@ -2121,9 +2151,10 @@ completion. To wait for completion, prefix the call with [`@sync`](@ref), like :
"""
macro parallel(args...)
    na = length(args)
-    if na==1
+    if na == 1
        loop = args[1]
-    elseif na==2
+    elseif na == 2
+        depwarn("@parallel with a reducer is deprecated. Use ParallelAccumulators for reduction.", :@parallel)
Review comment (Contributor): It'll be a conflict magnet so it can wait a bit, but whenever this proceeds it would be good to leave a comment in the appropriate section of deprecated.jl as a reminder to remove this code.
        reducer = args[1]
        loop = args[2]
    else
@@ -2135,14 +2166,115 @@ macro parallel(args...)
    var = loop.args[1].args[1]
    r = loop.args[1].args[2]
    body = loop.args[2]
-    if na==1
-        thecall = :(pfor($(make_pfor_body(var, body)), $(esc(r))))
+    if na == 1
+        thecall = :(foreach(wait, pfor($(make_pfor_body(var, body)), $(esc(r)))))
    else
        thecall = :(preduce($(esc(reducer)), $(make_preduce_body(var, body)), $(esc(r))))
    end
    localize_vars(thecall)
end
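To make the depwarn concrete, a migration sketch (bounds and reducer are arbitrary; the accumulator length must equal the loop length, as pfor asserts):

    # Deprecated reducer form, still routed through preduce:
    s = @parallel (+) for i in 1:100
        i
    end

    # Replacement using a ParallelAccumulator:
    acc = ParallelAccumulator{Int}(+, 100)
    @parallel for i in 1:100
        push!(acc, i)
    end
    s = wait(acc)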

type ParallelAccumulator{T}
    f::Function

    length::Int
    pending::Int

    initial::Nullable{T}
    value::Nullable{T}

    # A function which, given a destination pid, returns the length of the range
    # assigned to that pid. Each worker processes a subset of a parallel for-loop;
    # during serialization, f_len_at_pid is called to retrieve the length of the
    # range to be processed at that pid. On the remote node, the locally accumulated
    # value is written to the remote channel once len_at_pid values have been
    # processed. On the destination node this field is null, which loosely
    # differentiates the original instance on the caller from the deserialized
    # instances on the workers.
    f_len_at_pid::Nullable{Function}

    chnl::RemoteChannel

    ParallelAccumulator(f, len) = ParallelAccumulator{T}(f, len, Nullable{T}())

    ParallelAccumulator(f, len, initial::T) =
        ParallelAccumulator{T}(f, len, Nullable{T}(initial))

    ParallelAccumulator(f, len, initial::Nullable{T}) =
        ParallelAccumulator{T}(f, len, initial, RemoteChannel(()->Channel{Tuple{Int, T}}(Inf)))

    ParallelAccumulator(f, len, initial, chnl) =
        ParallelAccumulator{T}(f, len, initial, Nullable{Function}(), chnl)

    ParallelAccumulator(f, len, initial, f_len_at_pid, chnl) =
        new(f, len, len, initial, initial, f_len_at_pid, chnl)
end
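A sketch of the f_len_at_pid round-trip described in the comment above, with assumed chunks:

    # Caller: pfor computes chunks = Dict(2 => 1:4, 3 => 5:7, 4 => 8:10) and calls
    # set_f_len_at_pid!(acc, p->length(chunks[p])) (helper defined just below).
    # Serializing acc to pid 3 then embeds len == 3, so the deserialized copy on
    # pid 3 counts down from pending == 3 and reports to acc.chnl on its third push!.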

set_f_len_at_pid!(pacc::ParallelAccumulator, f::Function) = (pacc.f_len_at_pid = f; pacc)

function serialize(s::AbstractSerializer, pacc::ParallelAccumulator)
    serialize_cycle(s, pacc) && return
    serialize_type(s, typeof(pacc))

    if isnull(pacc.f_len_at_pid)
        error("Cannot serialize a ParallelAccumulator from a destination node.")
    end

    len = get(pacc.f_len_at_pid)(worker_id_from_socket(s.io))

    serialize(s, pacc.f)
    serialize(s, len)
    serialize(s, pacc.initial)
    serialize(s, pacc.chnl)
    nothing
end

function deserialize(s::AbstractSerializer, t::Type{T}) where T <: ParallelAccumulator
    f = deserialize(s)
    len = deserialize(s)
    initial = deserialize(s)
    chnl = deserialize(s)

    return T(f, len, initial, chnl)
end

function push!(pacc::ParallelAccumulator, v)
    if pacc.pending <= 0
        throw(AssertionError("Reusing a ParallelAccumulator is not allowed. reset(acc)?"))
    end

    if !isnull(pacc.value)
        pacc.value = pacc.f(get(pacc.value), v)
    else
        pacc.value = pacc.f(v)
    end
    pacc.pending -= 1

    if pacc.pending == 0
        put!(pacc.chnl, (pacc.length, get(pacc.value)))
    end
    pacc
end

function wait(pacc::ParallelAccumulator)
    while pacc.pending > 0
        (n, v) = take!(pacc.chnl)
        pacc.pending -= n
        if isnull(pacc.value)
            pacc.value = pacc.f(v)
        else
            pacc.value = pacc.f(get(pacc.value), v)
        end
    end
    return get(pacc.value)
end
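The worker-to-caller flow, sketched with assumed values and f == +:

    # Worker pid 2, chunk of length 2:
    #   push!(acc, 10); push!(acc, 20)    # pending: 2 -> 1 -> 0
    #   put!(acc.chnl, (2, 30))           # (chunk length, locally accumulated value)
    # Caller:
    #   wait(acc) takes (2, 30), folds 30 into its value via f, decrements
    #   pending by 2, and returns get(pacc.value) once pending reaches 0.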

function reset(pacc::ParallelAccumulator)
Review comment (Contributor): !

Reply (Author): Another specialization of the existing exported reset. I am OK with a new export reset! too.

    pacc.pending = pacc.length
    pacc.value = pacc.initial
    pacc.f_len_at_pid = Nullable{Function}()
    pacc
end
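Usage note (sketch, with acc as above): after wait(acc) has returned, reset(acc) restores pending and value from length and initial, so the same accumulator can drive another loop of equal length:

    wait(acc)
    reset(acc)  # ready for another @parallel loop of length acc.length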


function check_master_connect()
    timeout = worker_timeout()