Skip to content

Commit c720d73

Browse files
committed
remove checks for acc types. [ci skip]
1 parent 2f21f3a commit c720d73

File tree

1 file changed

+20
-34
lines changed

1 file changed

+20
-34
lines changed

base/multi.jl

Lines changed: 20 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2083,13 +2083,13 @@ end
20832083
function pfor(f, R)
20842084
lenR = length(R)
20852085
chunks = splitrange(lenR, workers())
2086-
accums = get(task_local_storage(), :JULIA_ACCUMULATOR, ())
2087-
if accums !== ()
2088-
accums = accums[1]
2089-
accums = isa(accums, ParallelAccumulator) ? [accums] : accums
2090-
for acc in accums
2086+
tls_acc = get(task_local_storage(), :JULIA_ACCUMULATOR, ())
2087+
if tls_acc !== ()
2088+
acc_current = tls_acc[1]
2089+
acc_coll = isa(acc_current, ParallelAccumulator) ? [acc_current] : acc_current
2090+
for acc in acc_coll
20912091
lenR != acc.length && throw(AssertionError("loop length must equal ParallelAccumulator length"))
2092-
set_destf(acc, p->length(chunks[p]))
2092+
set_f_len_at_pid!(acc, p->length(chunks[p]))
20932093
end
20942094
end
20952095

@@ -2162,9 +2162,13 @@ type ParallelAccumulator{T}
21622162
value::Nullable{T}
21632163

21642164
# A function which returns a length value when input the destination pid.
2165-
# Used to serialize the same object with different length values depending
2166-
# on the destination pid.
2167-
destf::Nullable{Function}
2165+
# Each worker processes a subset of a paralle for-loop. During serialization
2166+
# f_len_at_pid is called to retrieve the length of the range that needs to be
2167+
# processed at pid. On the remote node, we write the locally accumulated value
2168+
# to the remote channel once len_at_pid values are processed.
2169+
# On the destination node, this field will be NULL and is used to loosely differentiate
2170+
# between the original instance on the caller and the deserialized instances on the workers.
2171+
f_len_at_pid::Nullable{Function}
21682172

21692173
chnl::RemoteChannel
21702174

@@ -2179,21 +2183,21 @@ type ParallelAccumulator{T}
21792183
ParallelAccumulator(f, len, initial, chnl) =
21802184
ParallelAccumulator{T}(f, len, initial, Nullable{Function}(), chnl)
21812185

2182-
ParallelAccumulator(f, len, initial, destf, chnl) =
2183-
new(f, len, len, initial, initial, destf, chnl)
2186+
ParallelAccumulator(f, len, initial, f_len_at_pid, chnl) =
2187+
new(f, len, len, initial, initial, f_len_at_pid, chnl)
21842188
end
21852189

2186-
set_destf(pacc::ParallelAccumulator, f::Function) = (pacc.destf = f; pacc)
2190+
set_f_len_at_pid!(pacc::ParallelAccumulator, f::Function) = (pacc.f_len_at_pid = f; pacc)
21872191

21882192
function serialize(s::AbstractSerializer, pacc::ParallelAccumulator)
21892193
serialize_cycle(s, pacc) && return
21902194
serialize_type(s, typeof(pacc))
21912195

2192-
if isnull(pacc.destf)
2196+
if isnull(pacc.f_len_at_pid)
21932197
error("Cannot serialize a ParallelAccumulator from a destination node.")
21942198
end
21952199

2196-
len = get(pacc.destf)(worker_id_from_socket(s.io))
2200+
len = get(pacc.f_len_at_pid)(worker_id_from_socket(s.io))
21972201

21982202
serialize(s, pacc.f)
21992203
serialize(s, len)
@@ -2214,7 +2218,7 @@ end
22142218

22152219
function push!(pacc::ParallelAccumulator, v)
22162220
if pacc.pending <= 0
2217-
throw(AssertionError("Reusing a ParallelAccumulator is not allowed. reset(p::ParallelAccumulator)?"))
2221+
throw(AssertionError("Reusing a ParallelAccumulator is not allowed. reset(acc)?"))
22182222
end
22192223

22202224
if !isnull(pacc.value)
@@ -2246,31 +2250,13 @@ end
22462250
function reset(pacc::ParallelAccumulator)
22472251
pacc.pending = pacc.length
22482252
pacc.value = pacc.initial
2249-
pacc.destf = Nullable{Function}()
2253+
pacc.f_len_at_pid = Nullable{Function}()
22502254
pacc
22512255
end
22522256

22532257
macro accumulate(acc, expr)
2254-
if !(isa(acc, Symbol) || (isa(acc, Expr) && acc.head == :vect))
2255-
throw(ArgumentError(string(
2256-
"@accumulate : ",
2257-
"First argument must be a variable name pointing to a ParallelAccumulator ",
2258-
"or a vector of variable names pointing to ParallelAccumulators. ",
2259-
"Found : ", typeof(acc))))
2260-
end
2261-
22622258
quote
22632259
esc_acc = $(esc(acc))
2264-
if !(isa(esc_acc, ParallelAccumulator) ||
2265-
isa(esc_acc, Array{ParallelAccumulator}) ||
2266-
(isa(esc_acc, Array) && all(x->isa(x, ParallelAccumulator), esc_acc)))
2267-
2268-
throw(ArgumentError(string(
2269-
"@accumulate : First argument must be a ParallelAccumulator ",
2270-
"or a vector of ParallelAccumulators. ",
2271-
"Found : ", typeof(esc_acc))))
2272-
2273-
end
22742260

22752261
old_list = get(task_local_storage(), :JULIA_ACCUMULATOR, ())
22762262
task_local_storage(:JULIA_ACCUMULATOR, ($(esc(acc)), old_list))

0 commit comments

Comments
 (0)