RFC: "for-loop" compliant @parallel for.... take 2 #20259

Closed · wants to merge 1 commit

3 changes: 3 additions & 0 deletions base/deprecated.jl
@@ -1272,6 +1272,9 @@ for f in (:airyai, :airyaiprime, :airybi, :airybiprime, :airyaix, :airyaiprimex,
end
end

# TODO: reducer mode from `@parallel for` is now deprecated. Should be removed from
# the implementation in distributed/macros.jl

# END 0.6 deprecations

# BEGIN 1.0 deprecations
5 changes: 3 additions & 2 deletions base/distributed/Distributed.jl
@@ -3,8 +3,8 @@
module Distributed

# imports for extension
import Base: getindex, wait, put!, take!, fetch, isready, push!, length,
hash, ==, connect, kill, serialize, deserialize, close
import Base: getindex, setindex!, wait, put!, take!, fetch, isready, push!, length,
hash, ==, connect, kill, serialize, deserialize, close, reduce

# imports for use
using Base: Process, Semaphore, JLOptions, AnyDict, buffer_writes, wait_connected,
@@ -27,6 +27,7 @@ export
clear!,
ClusterManager,
default_worker_pool,
DistributedRef,
init_worker,
interrupt,
launch,
235 changes: 201 additions & 34 deletions base/distributed/macros.jl
@@ -98,42 +98,160 @@ end
eval_ew_expr(ex) = (eval(Main, ex); nothing)

# Statically split range [1,N] into equal sized chunks for np processors
function splitrange(N::Int, np::Int)
function splitrange(N::Int, wlist::Array)
np = length(wlist)
each = div(N,np)
extras = rem(N,np)
nchunks = each > 0 ? np : extras
chunks = Array{UnitRange{Int}}(nchunks)
chunks = Dict{Int, UnitRange{Int}}()
lo = 1
for i in 1:nchunks
hi = lo + each - 1
if extras > 0
hi += 1
extras -= 1
end
chunks[i] = lo:hi
chunks[wlist[i]] = lo:hi
lo = hi+1
end
return chunks
end
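
For illustration, a hedged sketch of what the reworked `splitrange` produces (worker ids 2, 3 and 4 are assumed):

```julia
# Assume three workers with ids 2, 3 and 4; split 10 iterations among them.
chunks = splitrange(10, [2, 3, 4])
# chunks == Dict(2 => 1:4, 3 => 5:7, 4 => 8:10)
# The first rem(N, np) workers receive one extra iteration each.
```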

function preduce(reducer, f, R)
N = length(R)
chunks = splitrange(N, nworkers())
all_w = workers()[1:length(chunks)]
"""
DistributedRef(initial)

Constructs a reducing accumulator designed to be used in conjunction with `@parallel`
for-loops. Takes a single argument specifying an initial value.

The body of the `@parallel` for-loop can refer to multiple `DistributedRef`s.
The 0-argument indexing syntax `[]` is used to fetch from and assign to an accumulator in the loop body.
For example, `acc[] = acc[] + i` fetches the current, locally accumulated value in `acc`, adds
the value of `i`, and stores the new accumulated value back in `acc`.

See [`@parallel`](@ref) for more details.

DistributedRefs can also be used independently of `@parallel` loops.

```julia
acc = DistributedRef(1)
@sync for p in workers()
@spawnat p begin
for i in 1:10
acc[] += 1 # Local accumulation on each worker
end
push!(acc) # Explicit push of local accumulation to driver node (typically node 1)
end
end
reduce(+, acc)
```

Usage of DistributedRefs independent of a `@parallel` construct must observe the following:
- All remote tasks must be completed before calling `reduce` to retrieve the accumulated value. In the
example above, this is achieved by [`@sync`](@ref).
- `push!(acc)` must be explicitly called once on each worker. This pushes the locally accumulated value
to the node driving the computation.

Note that the optional `initial` value is used on all workers. For example, if the reducing function is `+`,
`DistributedRef(25)` will add a total of `25*nworkers()` to the final result.
"""
mutable struct DistributedRef{T}
initial::T
value::T
workers::Set{Int} # Used on caller to detect arrival of all parts
chnl::RemoteChannel
hval::Int # change hash value to ensure globals are serialized every time

DistributedRef{T}(initial, chnl) where T = new(initial, initial, Set{Int}(), chnl, 0)
end

DistributedRef{T}() where T = DistributedRef{T}(zero(T))
DistributedRef{T}(initial::T) = DistributedRef{T}(initial, RemoteChannel(()->Channel{Tuple}(Inf)))

const dref_registry=Dict{RRID, Array{DistributedRef}}()

getindex(dref::DistributedRef) = dref.value
setindex!(dref::DistributedRef, v) = (dref.value = v)

hash(dref::DistributedRef, h::UInt) = hash(dref.hval, hash(dref.chnl, h))


"""
push!(dref::DistributedRef)

Pushes the locally accumulated value to the calling node. Must be called once on each worker
when a DistributedRef is used independently of a [`@parallel`](@ref) construct.
"""
push!(dref::DistributedRef) = put_nowait!(dref.chnl, (myid(), dref.value))
Review comment (Contributor):
push doesn't make sense, it's not a growing collection - send would be better

Reply (Contributor Author):
I agree.


function serialize(s::AbstractSerializer, dref::DistributedRef)
serialize_type(s, typeof(dref))

serialize(s, dref.initial)
serialize(s, dref.chnl)
rrid = get(task_local_storage(), :JULIA_PACC_TRACKER, ())
serialize(s, rrid)

destpid = worker_id_from_socket(s.io)
push!(dref.workers, destpid)
nothing
end

function deserialize(s::AbstractSerializer, t::Type{T}) where T <: DistributedRef
initial = deserialize(s)
chnl = deserialize(s)
rrid = deserialize(s)

dref = T(initial, chnl)

global dref_registry
rrid != () && push!(get!(dref_registry, rrid, []), dref)
dref
end


"""
reduce(op, dref::DistributedRef)

Performs, on the calling node, a final reduction of the values accumulated by
a [`@parallel`](@ref) invocation and returns the reduced value.
"""
reduce(op, dref::DistributedRef) = reduce(op, dref.initial, dref)

function reduce{T}(op, v0, dref::DistributedRef{T})
length(dref.workers) == 0 && return dref.value # local execution, no workers present

results = T[]
while length(dref.workers) > 0
(pid, v) = take!(dref.chnl)
@assert pid in dref.workers
delete!(dref.workers, pid)
push!(results, v)
end
dref.hval += 1
dref.value = reduce(op, v0, results)
dref.value
end

"""
clear!(dref::DistributedRef)

Clears a DistributedRef object enabling its reuse in a subsequent call.
"""
function clear!(dref::DistributedRef)
dref.value = dref.initial
dref.hval += 1
dref
end
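
A hedged sketch of reusing an accumulator across two `@parallel` loops via `clear!` (the results shown assume the default initial value of 0):

```julia
acc = DistributedRef(0)

@parallel for i in 1:10
    acc[] += i
end
first_total = reduce(+, acc)    # 55

clear!(acc)                     # reset to the initial value before reuse

@parallel for i in 1:10
    acc[] += 2i
end
second_total = reduce(+, acc)   # 110
```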

function preduce(reducer, f, R)
w_exec = Task[]
for (idx,pid) in enumerate(all_w)
t = Task(()->remotecall_fetch(f, pid, reducer, R, first(chunks[idx]), last(chunks[idx])))
schedule(t)
for (pid, r) in splitrange(length(R), workers())
t = @schedule remotecall_fetch(f, pid, reducer, R, first(r), last(r))
push!(w_exec, t)
end
reduce(reducer, [wait(t) for t in w_exec])
end

function pfor(f, R)
[@spawn f(R, first(c), last(c)) for c in splitrange(length(R), nworkers())]
end

function make_preduce_body(var, body)
quote
function (reducer, R, lo::Int, hi::Int)
@@ -149,11 +267,32 @@ function make_preduce_body(var, body)
end
end

function pfor(f, R)
lenR = length(R)
chunks = splitrange(lenR, workers())
rrid = RRID()
task_local_storage(:JULIA_PACC_TRACKER, rrid)
[remotecall(f, p, R, first(c), last(c), rrid) for (p,c) in chunks]
end

function make_pfor_body(var, body)
quote
function (R, lo::Int, hi::Int)
for $(esc(var)) in R[lo:hi]
$(esc(body))
function (R, lo::Int, hi::Int, rrid)
global dref_registry
dref_list = get(dref_registry, rrid, DistributedRef[])
delete!(dref_registry, rrid)
try
for $(esc(var)) in R[lo:hi]
$(esc(body))
end
catch e
for p2 in dref_list
put_nowait!(p2.chnl, (0, p2.initial))
end
rethrow(e)
end
for p2 in dref_list
put_nowait!(p2.chnl, (myid(), p2.value))
end
end
end
@@ -164,27 +303,56 @@ end

A parallel for loop of the form:

@parallel [reducer] for var = range
@parallel for var = range
body
end

The specified range is partitioned and locally executed across all workers. In case an
optional reducer function is specified, `@parallel` performs local reductions on each worker
with a final reduction on the calling process.

Note that without a reducer function, `@parallel` executes asynchronously, i.e. it spawns
independent tasks on all available workers and returns immediately without waiting for
completion. To wait for completion, prefix the call with [`@sync`](@ref), like :

@sync @parallel for var = range
body
end
The loop is executed in parallel across all workers, with each worker executing a subset
of the range. The call waits for completion of all iterations on all workers before returning.
Any updates to variables outside the loop body are not reflected on the calling node.
However, this is a common requirement and can be achieved in one of two ways. First, the loop body
can update shared arrays, in which case the updates are visible on all nodes mapping the array. Second,
[`DistributedRef`](@ref) objects can be used to collect computed values efficiently.
The former can be used only on a single node (with multiple workers mapping the same shared segment), while
the latter can be used when a computation is distributed across nodes.

```jldoctest
julia> a = SharedArray{Float64}(4);

julia> c = 10;

julia> @parallel for i=1:4
a[i] = i + c
end

julia> a
4-element SharedArray{Float64,1}:
11.0
12.0
13.0
14.0
```

```jldoctest
julia> acc = DistributedRef(0);

julia> c = 100;

julia> @parallel for i in 1:10
j = 2i + c
acc[] += j
end;

julia> reduce(+, acc)
1110
```
"""
macro parallel(args...)
na = length(args)
if na==1
if na == 1
loop = args[1]
elseif na==2
elseif na == 2
depwarn("@parallel with a reducer is deprecated. Use DistributedRefs for reduction.", Symbol("@parallel"))
reducer = args[1]
loop = args[2]
else
@@ -196,11 +364,10 @@ macro parallel(args...)
var = loop.args[1].args[1]
r = loop.args[1].args[2]
body = loop.args[2]
if na==1
thecall = :(pfor($(make_pfor_body(var, body)), $(esc(r))))
if na == 1
thecall = :(foreach(wait, pfor($(make_pfor_body(var, body)), $(esc(r)))))
else
thecall = :(preduce($(esc(reducer)), $(make_preduce_body(var, body)), $(esc(r))))
end
thecall
end
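
For reference, a hedged sketch of migrating a now-deprecated reducer-style loop to the `DistributedRef` pattern documented above:

```julia
# Deprecated form (emits the depwarn added in this change):
s = @parallel (+) for i in 1:100
    i^2
end

# Equivalent form using a DistributedRef accumulator:
acc = DistributedRef(0)
@parallel for i in 1:100
    acc[] += i^2
end
s = reduce(+, acc)    # both compute sum(i^2 for i = 1:100)
```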

16 changes: 10 additions & 6 deletions base/distributed/remotecall.jl
@@ -433,12 +433,13 @@ remote_do(f, id::Integer, args...; kwargs...) = remote_do(f, worker_from_id(id),

# have the owner of rr call f on it
function call_on_owner(f, rr::AbstractRemoteRef, args...)
rid = remoteref_id(rr)
if rr.where == myid()
f(rid, args...)
else
remotecall_fetch(f, rr.where, rid, args...)
end
rr.where == myid() && return f(remoteref_id(rr), args...)
return remotecall_fetch(f, rr.where, remoteref_id(rr), args...)
end

function call_on_owner_nowait(f, rr::AbstractRemoteRef, args...)
rr.where == myid() && return f(remoteref_id(rr), args...)
return remote_do(f, rr.where, remoteref_id(rr), args...)
end

function wait_ref(rid, callee, args...)
@@ -525,6 +526,9 @@ Returns its first argument.
"""
put!(rr::RemoteChannel, args...) = (call_on_owner(put_ref, rr, args...); rr)

# Returns immediately, does not guarantee a successful put!
put_nowait!(rr::RemoteChannel, args...) = (call_on_owner_nowait(put_ref, rr, args...); rr)
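
A minimal sketch of the intended difference from `put!` (assumes a worker with id 2 has been added via `addprocs`):

```julia
rc = RemoteChannel(() -> Channel{Int}(10), 2)

put!(rc, 1)         # blocks until the value has been stored on the owner (worker 2)
put_nowait!(rc, 2)  # issues a remote_do and returns immediately; no delivery guarantee
```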

# take! is not supported on Future

take!(rv::RemoteValue, args...) = take!(rv.c, args...)
1 change: 1 addition & 0 deletions base/exports.jl
@@ -1354,6 +1354,7 @@ export
clear!,
ClusterManager,
default_worker_pool,
DistributedRef,
init_worker,
interrupt,
launch,