Commit f328a03

Rewrite driver function using pmap
old: 744.331 ms (8252 allocations: 28.92 MiB); new: 667.085 ms (6791 allocations: 28.82 MiB)
1 parent bfb9424 commit f328a03
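
The timings in the commit message match the output format of BenchmarkTools' `@btime` (minimum time, allocation count, total bytes allocated). A minimal, self-contained illustration of how such numbers are produced; the expression timed here is only a stand-in for a real `driver` call:

    using BenchmarkTools

    # Prints something like "1.234 ms (2 allocations: 7.63 MiB)"
    @btime sum(rand(1000, 1000))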

File tree

3 files changed: +104 additions, -100 deletions

.github/workflows/CI.yml

Lines changed: 1 addition & 2 deletions
@@ -13,8 +13,7 @@ jobs:
       fail-fast: false
       matrix:
         version:
-          - '1.6'
-          - '1.1'
+          - '1.10'
           - '1'
           # - 'nightly'
         os:

src/RegisterDriver.jl

Lines changed: 101 additions & 97 deletions
@@ -50,98 +50,103 @@ worker has been written to look for such settings:
 
 which will save `extra` only if `:extra` is a key in `mon`.
 """
-function driver(outfile::AbstractString, algorithm::Vector, img, mon::Vector)
-    nworkers = length(algorithm)
-    length(mon) == nworkers || error("Number of monitors must equal number of workers")
-    use_workerprocs = nworkers > 1 || workerpid(algorithm[1]) != myid()
-    rralgorithm = Array{RemoteChannel}(undef, nworkers)
-    if use_workerprocs
-        # Push the algorithm objects to the worker processes. This eliminates
-        # per-iteration serialization penalties, and ensures that any
-        # initialization state is retained.
-        for i = 1:nworkers
-            alg = algorithm[i]
-            rralgorithm[i] = put!(RemoteChannel(workerpid(alg)), alg)
-        end
-        # Perform any needed worker initialization
-        @sync for i = 1:nworkers
-            p = workerpid(algorithm[i])
-            @async remotecall_fetch(init!, p, rralgorithm[i])
-        end
-    else
-        init!(algorithm[1])
+function driver(outfile::AbstractString, algorithms::Vector, img, mon::Vector)
+    numworkers = length(algorithms)
+    length(mon) == numworkers || error("Number of monitors must equal number of algorithms")
+    use_workerprocs = numworkers > 1 || workerpid(algorithms[1]) != myid()
+    pool = use_workerprocs ? map(alg -> alg.workerpid, algorithms) : [myid()]
+    wpool = CachingPool(pool)  # worker pool for pmap
+
+    # Map worker ID to algorithm index
+    aindices = use_workerprocs ? Dict(map((alg, aidx) -> (alg.workerpid => aidx), algorithms, 1:length(algorithms))...) :
+                                 Dict(myid() => 1)
+
+    # Initialize algorithms on workers
+    println("Initializing algorithms on workers")
+    pmap(wpool, 1:numworkers) do _
+        wid = myid()
+        init!(algorithms[aindices[wid]])
+        return nothing
     end
-    try
+
+    println("Working on algorithm and saving the result")
+    jldopen(outfile, "w") do file
+        dsets = Dict{Symbol,Any}()
+        firstsave = Ref(true)
+        have_unpackable = Ref(false)
         n = nimages(img)
-        fs = FormatSpec("0$(ndigits(n))d") # group names of unpackable objects
-        jldopen(outfile, "w") do file
-            dsets = Dict{Symbol,Any}()
-            firstsave = SharedArray{Bool}(1)
-            firstsave[1] = true
-            have_unpackable = SharedArray{Bool}(1)
-            have_unpackable[1] = false
-            # Run the jobs
-            nextidx = 0
-            getnextidx() = nextidx += 1
-            writing_mutex = RemoteChannel()
-            @sync begin
-                for i = 1:nworkers
-                    alg = algorithm[i]
-                    @async begin
-                        while (idx = getnextidx()) <= n
-                            if use_workerprocs
-                                remotecall_fetch(println, workerpid(alg), "Worker ", workerpid(alg), " is working on ", idx)
-                                # See https://github.com/JuliaLang/julia/issues/22139
-                                tmp = remotecall_fetch(worker, workerpid(alg), rralgorithm[i], img, idx, mon[i])
-                                copy_all_but_shared!(mon[i], tmp)
-                            else
-                                println("Working on ", idx)
-                                mon[1] = worker(algorithm[1], img, idx, mon[1])
-                            end
-                            # Save the results
-                            put!(writing_mutex, true) # grab the lock
-                            try
-                                local g
-                                if firstsave[]
-                                    firstsave[] = false
-                                    have_unpackable[] = initialize_jld!(dsets, file, mon[i], fs, n)
-                                end
-                                if fetch(have_unpackable[])
-                                    g = file[string("stack", fmt(fs, idx))]
-                                end
-                                for (k,v) in mon[i]
-                                    if isa(v, Number)
-                                        dsets[k][idx] = v
-                                        continue
-                                    elseif isa(v, Array) || isa(v, SharedArray)
-                                        vw = nicehdf5(v)
-                                        if eltype(vw) <: BitsType
-                                            colons = [Colon() for i = 1:ndims(vw)]
-                                            dsets[k][colons..., idx] = vw
-                                            continue
-                                        end
-                                    end
-                                    g[string(k)] = v
-                                end
-                            finally
-                                take!(writing_mutex) # release the lock
-                            end
+        fs = FormatSpec("0$(ndigits(n))d")
+
+        # Channel for passing results from workers to master
+        results_ch = RemoteChannel(() -> Channel{Tuple{Int,Dict}}(32), myid())
+
+        # Writer task (runs on the master)
+        writer_task = @async begin
+            while true
+                data = try
+                    take!(results_ch)
+                catch
+                    break
+                end
+                movidx, monres = data
+
+                # Initialize datasets on the first save
+                if firstsave[]
+                    firstsave[] = false
+                    have_unpackable[] = initialize_jld!(dsets, file, monres, fs, n)
+                end
+
+                g = have_unpackable[] ? file[string("stack", fmt(fs, movidx))] : nothing
+
+                # Write all values into the file
+                for (k,v) in monres
+                    # isa(v, SharedArray) && (@show k)
+                    if isa(v, Number)
+                        dsets[k][movidx] = v
+                    elseif isa(v, Array) || isa(v, SharedArray)
+                        vw = nicehdf5(v)
+                        if eltype(vw) <: BitsType
+                            colons = [Colon() for _ = 1:ndims(vw)]
+                            dsets[k][colons..., movidx] = vw
+                        else
+                            g[string(k)] = v
                         end
+                    else
+                        g[string(k)] = v
                     end
                 end
+                # yield() # briefly yield control between @async iterations
             end
         end
-    finally
-        # Perform any needed worker cleanup
-        if use_workerprocs
-            @sync for i = 1:nworkers
-                p = workerpid(algorithm[i])
-                @async remotecall_fetch(close!, p, rralgorithm[i])
-            end
-        else
-            close!(algorithm[1])
+
+        # Main computation with pmap
+        pmap(wpool, 1:n) do movidx
+            wid = myid()
+            println("Worker $wid processing $movidx")
+
+            # Perform the computation
+            tmp = worker(algorithms[aindices[wid]], img, movidx, mon[aindices[wid]])
+
+            # Send the result back to the master for writing
+            put!(results_ch, (movidx, tmp))
+            !use_workerprocs && yield()  # needed when running in a single process
+            return nothing
         end
+
+        # Close the channel and wait for the writer to finish
+        close(results_ch)  # causes take!(results_ch) to throw, ending the writer loop
+        wait(writer_task)
     end
+
+    # Close algorithms on workers
+    println("Closing algorithms on workers")
+    pmap(wpool, 1:numworkers) do _
+        wid = myid()
+        close!(algorithms[aindices[wid]])
+        return nothing
+    end
+
+    return nothing
 end
 
 driver(outfile::AbstractString, algorithm::AbstractWorker, img, mon::Dict) = driver(outfile, [algorithm], img, [mon])
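
The core of the rewrite is a producer/consumer split: `pmap` over a `CachingPool` runs the computation on the workers, each result is `put!` into a `RemoteChannel` owned by the master, and a single `@async` writer task on the master drains the channel, so file writes stay serialized without the old explicit `writing_mutex` lock. Closing the channel doubles as the shutdown signal for the writer. A self-contained sketch of that pattern, with a placeholder computation and placeholder "writing" (all names here are illustrative):

    using Distributed

    function run_pipeline(n)
        wpool = CachingPool(workers())  # workers() is [1] if no workers were added
        results_ch = RemoteChannel(() -> Channel{Tuple{Int,Float64}}(32), myid())

        # Single consumer on the master: drains results until the channel closes
        writer = @async while true
            idx, val = try
                take!(results_ch)
            catch              # channel closed: no more results coming
                break
            end
            println("writing result $val for index $idx")
        end

        # Producers: one pmap job per index, executed on the pool
        pmap(wpool, 1:n) do idx
            put!(results_ch, (idx, idx / 2))  # stand-in for the real computation
            nothing
        end

        close(results_ch)  # makes the writer's take! throw, ending its loop
        wait(writer)
    end

    run_pipeline(4)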
@@ -214,20 +219,19 @@ end
 
 mm_package_loader(algorithm::AbstractWorker) = mm_package_loader([algorithm])
 function mm_package_loader(algorithms::Vector)
-    nworkers = length(algorithms)
-    use_workerprocs = nworkers > 1 || workerpid(algorithms[1]) != myid()
-    rrdev = Array{RemoteChannel}(undef, nworkers)
-    if use_workerprocs
-        for i = 1:nworkers
-            dev = algorithms[i].dev
-            rrdev[i] = put!(RemoteChannel(workerpid(algorithms[i])), dev)
-        end
-        @sync for i = 1:nworkers
-            p = workerpid(algorithms[i])
-            @async remotecall_fetch(load_mm_package, p, rrdev[i])
-        end
-    else
-        load_mm_package(algorithms[1].dev)
+    numworkers = length(algorithms)
+    use_workerprocs = numworkers > 1 || workerpid(algorithms[1]) != myid()
+    pool = use_workerprocs ? map(alg -> alg.workerpid, algorithms) : [myid()]
+    wpool = CachingPool(pool)  # worker pool for pmap
+
+    # Map worker ID to algorithm index
+    aindices = use_workerprocs ? Dict(map((alg, aidx) -> (alg.workerpid => aidx), algorithms, 1:length(algorithms))...) :
+                                 Dict(myid() => 1)
+    # Load a mismatch package on workers
+    pmap(wpool, 1:numworkers) do _
+        wid = myid()
+        load_mm_package(algorithms[aindices[wid]].dev)
+        return nothing
     end
     nothing
 end
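
Both rewritten functions route work through a `CachingPool`. Unlike a plain `WorkerPool`, a `CachingPool` sends each closure (and the data it captures, here the `algorithms` vector) to a worker once and caches it there, instead of re-serializing it for every `pmap` element; that is presumably where part of the allocation savings in the commit message comes from. A small illustration with made-up data:

    using Distributed
    addprocs(2)

    bigdata = rand(10^6)            # captured by the closure below
    wpool = CachingPool(workers())

    # bigdata is shipped to each worker once, not once per element of 1:8
    totals = pmap(i -> sum(bigdata) + i, wpool, 1:8)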

test/runtests.jl

Lines changed: 2 additions & 1 deletion
@@ -1,13 +1,14 @@
 using Test, Distributed, SharedArrays
 using ImageCore, JLD
-using RegisterDriver, RegisterWorkerShell
+using RegisterWorkerShell
 using AxisArrays: AxisArray
 
 driverprocs = addprocs(2)
 push!(LOAD_PATH, pwd())
 @sync for p in driverprocs
     @spawnat p push!(LOAD_PATH, pwd())
 end
+@everywhere using RegisterDriver
 using WorkerDummy
 
 workdir = tempname()

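The test change reflects a requirement of the pmap-based driver: code that `pmap` executes on workers must be defined there, and a plain `using` on the master does not load a module on the worker processes. Hence `using RegisterDriver` moves after `addprocs(2)` and becomes `@everywhere using RegisterDriver`. The same rule in miniature, with the Statistics stdlib as a stand-in:

    using Distributed
    addprocs(2)

    @everywhere using Statistics    # load on the master and all workers
    pmap(_ -> mean(rand(10)), 1:4)  # works because mean is defined on every worker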