@@ -2083,13 +2083,13 @@ end
2083
2083
function pfor (f, R)
2084
2084
lenR = length (R)
2085
2085
chunks = splitrange (lenR, workers ())
2086
- accums = get (task_local_storage (), :JULIA_ACCUMULATOR , ())
2087
- if accums != = ()
2088
- accums = accums [1 ]
2089
- accums = isa (accums , ParallelAccumulator) ? [accums ] : accums
2090
- for acc in accums
2086
+ tls_acc = get (task_local_storage (), :JULIA_ACCUMULATOR , ())
2087
+ if tls_acc != = ()
2088
+ acc_current = tls_acc [1 ]
2089
+ acc_coll = isa (acc_current , ParallelAccumulator) ? [acc_current ] : acc_current
2090
+ for acc in acc_coll
2091
2091
lenR != acc. length && throw (AssertionError (" loop length must equal ParallelAccumulator length" ))
2092
- set_destf (acc, p-> length (chunks[p]))
2092
+ set_f_len_at_pid! (acc, p-> length (chunks[p]))
2093
2093
end
2094
2094
end
2095
2095
@@ -2162,9 +2162,13 @@ type ParallelAccumulator{T}
2162
2162
value:: Nullable{T}
2163
2163
2164
2164
# A function which returns a length value when input the destination pid.
2165
- # Used to serialize the same object with different length values depending
2166
- # on the destination pid.
2167
- destf:: Nullable{Function}
2165
+ # Each worker processes a subset of a paralle for-loop. During serialization
2166
+ # f_len_at_pid is called to retrieve the length of the range that needs to be
2167
+ # processed at pid. On the remote node, we write the locally accumulated value
2168
+ # to the remote channel once len_at_pid values are processed.
2169
+ # On the destination node, this field will be NULL and is used to loosely differentiate
2170
+ # between the original instance on the caller and the deserialized instances on the workers.
2171
+ f_len_at_pid:: Nullable{Function}
2168
2172
2169
2173
chnl:: RemoteChannel
2170
2174
@@ -2179,21 +2183,21 @@ type ParallelAccumulator{T}
2179
2183
ParallelAccumulator (f, len, initial, chnl) =
2180
2184
ParallelAccumulator {T} (f, len, initial, Nullable {Function} (), chnl)
2181
2185
2182
- ParallelAccumulator (f, len, initial, destf , chnl) =
2183
- new (f, len, len, initial, initial, destf , chnl)
2186
+ ParallelAccumulator (f, len, initial, f_len_at_pid , chnl) =
2187
+ new (f, len, len, initial, initial, f_len_at_pid , chnl)
2184
2188
end
2185
2189
2186
- set_destf (pacc:: ParallelAccumulator , f:: Function ) = (pacc. destf = f; pacc)
2190
+ set_f_len_at_pid! (pacc:: ParallelAccumulator , f:: Function ) = (pacc. f_len_at_pid = f; pacc)
2187
2191
2188
2192
function serialize (s:: AbstractSerializer , pacc:: ParallelAccumulator )
2189
2193
serialize_cycle (s, pacc) && return
2190
2194
serialize_type (s, typeof (pacc))
2191
2195
2192
- if isnull (pacc. destf )
2196
+ if isnull (pacc. f_len_at_pid )
2193
2197
error (" Cannot serialize a ParallelAccumulator from a destination node." )
2194
2198
end
2195
2199
2196
- len = get (pacc. destf )(worker_id_from_socket (s. io))
2200
+ len = get (pacc. f_len_at_pid )(worker_id_from_socket (s. io))
2197
2201
2198
2202
serialize (s, pacc. f)
2199
2203
serialize (s, len)
@@ -2214,7 +2218,7 @@ end
2214
2218
2215
2219
function push! (pacc:: ParallelAccumulator , v)
2216
2220
if pacc. pending <= 0
2217
- throw (AssertionError (" Reusing a ParallelAccumulator is not allowed. reset(p::ParallelAccumulator )?" ))
2221
+ throw (AssertionError (" Reusing a ParallelAccumulator is not allowed. reset(acc )?" ))
2218
2222
end
2219
2223
2220
2224
if ! isnull (pacc. value)
@@ -2246,31 +2250,13 @@ end
2246
2250
function reset (pacc:: ParallelAccumulator )
2247
2251
pacc. pending = pacc. length
2248
2252
pacc. value = pacc. initial
2249
- pacc. destf = Nullable {Function} ()
2253
+ pacc. f_len_at_pid = Nullable {Function} ()
2250
2254
pacc
2251
2255
end
2252
2256
2253
2257
macro accumulate (acc, expr)
2254
- if ! (isa (acc, Symbol) || (isa (acc, Expr) && acc. head == :vect ))
2255
- throw (ArgumentError (string (
2256
- " @accumulate : " ,
2257
- " First argument must be a variable name pointing to a ParallelAccumulator " ,
2258
- " or a vector of variable names pointing to ParallelAccumulators. " ,
2259
- " Found : " , typeof (acc))))
2260
- end
2261
-
2262
2258
quote
2263
2259
esc_acc = $ (esc (acc))
2264
- if ! (isa (esc_acc, ParallelAccumulator) ||
2265
- isa (esc_acc, Array{ParallelAccumulator}) ||
2266
- (isa (esc_acc, Array) && all (x-> isa (x, ParallelAccumulator), esc_acc)))
2267
-
2268
- throw (ArgumentError (string (
2269
- " @accumulate : First argument must be a ParallelAccumulator " ,
2270
- " or a vector of ParallelAccumulators. " ,
2271
- " Found : " , typeof (esc_acc))))
2272
-
2273
- end
2274
2260
2275
2261
old_list = get (task_local_storage (), :JULIA_ACCUMULATOR , ())
2276
2262
task_local_storage (:JULIA_ACCUMULATOR , ($ (esc (acc)), old_list))
0 commit comments