3 changes: 1 addition & 2 deletions .github/workflows/CI.yml
@@ -13,8 +13,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.6'
- '1.1'
- '1.10'
- '1'
# - 'nightly'
os:
198 changes: 101 additions & 97 deletions src/RegisterDriver.jl
@@ -50,98 +50,103 @@ worker has been written to look for such settings:

which will save `extra` only if `:extra` is a key in `mon`.
"""
function driver(outfile::AbstractString, algorithm::Vector, img, mon::Vector)
nworkers = length(algorithm)
length(mon) == nworkers || error("Number of monitors must equal number of workers")
use_workerprocs = nworkers > 1 || workerpid(algorithm[1]) != myid()
rralgorithm = Array{RemoteChannel}(undef, nworkers)
if use_workerprocs
# Push the algorithm objects to the worker processes. This eliminates
# per-iteration serialization penalties, and ensures that any
# initialization state is retained.
for i = 1:nworkers
alg = algorithm[i]
rralgorithm[i] = put!(RemoteChannel(workerpid(alg)), alg)
end
# Perform any needed worker initialization
@sync for i = 1:nworkers
p = workerpid(algorithm[i])
@async remotecall_fetch(init!, p, rralgorithm[i])
end
else
init!(algorithm[1])
function driver(outfile::AbstractString, algorithms::Vector, img, mon::Vector)
numworkers = length(algorithms)
length(mon) == numworkers || error("Number of monitors must equal number of algorithms")
use_workerprocs = numworkers > 1 || workerpid(algorithms[1]) != myid()
pool = use_workerprocs ? map(alg -> alg.workerpid, algorithms) : [myid()]
wpool = CachingPool(pool) # worker pool for pmap

# Map worker ID to algorithm index
aindices = use_workerprocs ? Dict(workerpid(alg) => i for (i, alg) in enumerate(algorithms)) :
Dict(myid() => 1)

# Initialize algorithms on workers
println("Initializing algorithm on workers")
pmap(wpool, 1:numworkers) do _
wid = myid()
init!(algorithms[aindices[wid]])
return nothing
end
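
A self-contained sketch of the pid-to-index lookup used above (all names are illustrative): each pmap task asks myid() which worker it landed on, then indexes into the per-worker state. One caveat: pmap over 1:numworkers does not guarantee one task per worker, so a fast worker can take two init! items while another gets none; the per-pid dispatch sketched after mm_package_loader below avoids that.

    using Distributed

    pids = addprocs(2)
    aindices = Dict(p => i for (i, p) in enumerate(pids))   # pid -> state index
    wp = CachingPool(pids)

    pmap(wp, 1:4) do item
        wid = myid()   # which worker is running this task?
        println("item $item ran on worker $wid (state #$(aindices[wid]))")
        nothing
    end
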
try

println("Working on algorithm and saving the result")
jldopen(outfile, "w") do file
dsets = Dict{Symbol,Any}()
firstsave = Ref(true)
have_unpackable = Ref(false)
n = nimages(img)
fs = FormatSpec("0$(ndigits(n))d") # group names of unpackable objects
jldopen(outfile, "w") do file
dsets = Dict{Symbol,Any}()
firstsave = SharedArray{Bool}(1)
firstsave[1] = true
have_unpackable = SharedArray{Bool}(1)
have_unpackable[1] = false
# Run the jobs
nextidx = 0
getnextidx() = nextidx += 1
writing_mutex = RemoteChannel()
@sync begin
for i = 1:nworkers
alg = algorithm[i]
@async begin
while (idx = getnextidx()) <= n
if use_workerprocs
remotecall_fetch(println, workerpid(alg), "Worker ", workerpid(alg), " is working on ", idx)
# See https://github.com/JuliaLang/julia/issues/22139
tmp = remotecall_fetch(worker, workerpid(alg), rralgorithm[i], img, idx, mon[i])
copy_all_but_shared!(mon[i], tmp)
else
println("Working on ", idx)
mon[1] = worker(algorithm[1], img, idx, mon[1])
end
# Save the results
put!(writing_mutex, true) # grab the lock
try
local g
if firstsave[]
firstsave[] = false
have_unpackable[] = initialize_jld!(dsets, file, mon[i], fs, n)
end
if fetch(have_unpackable[])
g = file[string("stack", fmt(fs, idx))]
end
for (k,v) in mon[i]
if isa(v, Number)
dsets[k][idx] = v
continue
elseif isa(v, Array) || isa(v, SharedArray)
vw = nicehdf5(v)
if eltype(vw) <: BitsType
colons = [Colon() for i = 1:ndims(vw)]
dsets[k][colons..., idx] = vw
continue
end
end
g[string(k)] = v
end
finally
take!(writing_mutex) # release the lock
end
fs = FormatSpec("0$(ndigits(n))d")   # zero-padded group names for unpackable objects

# Channel for passing results from workers to master
results_ch = RemoteChannel(()->Channel{Tuple{Int,Dict}}(32), myid())

# Writer task (runs on master)
writer_task = @async begin
while true
data = try
take!(results_ch)
catch
break
end
movidx, monres = data

# Initialize datasets on first save
if firstsave[]
firstsave[] = false
have_unpackable[] = initialize_jld!(dsets, file, monres, fs, n)
end

g = have_unpackable[] ? file[string("stack", fmt(fs, movidx))] : nothing

# Write all values into the file
for (k,v) in monres
if isa(v, Number)
dsets[k][movidx] = v
elseif isa(v, Array) || isa(v, SharedArray)
vw = nicehdf5(v)
if eltype(vw) <: BitsType
colons = [Colon() for _=1:ndims(vw)]
dsets[k][colons..., movidx] = vw
else
g[string(k)] = v
end
else
g[string(k)] = v
end
end
end
end
finally
# Perform any needed worker cleanup
if use_workerprocs
@sync for i = 1:nworkers
p = workerpid(algorithm[i])
@async remotecall_fetch(close!, p, rralgorithm[i])
end
else
close!(algorithm[1])

# Main computation with pmap
pmap(wpool, 1:n) do movidx
wid = myid()
println("Worker $wid processing $movidx")

# Perform computation
tmp = worker(algorithms[aindices[wid]], img, movidx, mon[aindices[wid]])

# Send result back to master for writing
put!(results_ch, (movidx, tmp))
!use_workerprocs && yield() # this needed if single process
return nothing
end

# Close channel and wait for writer to finish
close(results_ch) # makes the writer's take!(results_ch) throw once drained, ending its loop
wait(writer_task)
end
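
The channel-plus-writer shutdown above is a standard Distributed pattern; here it is in isolation (names illustrative): producers put! onto a RemoteChannel hosted on the master, a single local task drains it, and close() ends the loop by making the blocked take! throw once the buffer is empty.

    using Distributed

    results = RemoteChannel(() -> Channel{Tuple{Int,String}}(32))

    writer = @async while true
        idx, payload = try
            take!(results)
        catch
            break   # channel closed and fully drained
        end
        println("writing item $idx: $payload")
    end

    for i in 1:5
        put!(results, (i, "result-$i"))
    end

    close(results)   # wake the writer once the remaining items are taken
    wait(writer)
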

# Close the algorithms on the workers
println("Closing algorithms on workers")
pmap(wpool, 1:numworkers) do _
wid = myid()
close!(algorithms[aindices[wid]])
return nothing
end

return nothing
end
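
One design note on the switch to CachingPool: with a plain pool, pmap reserializes the closure (which captures img) for every item, whereas CachingPool ships it once per worker and reuses it. A minimal sketch of that behavior, with made-up data:

    using Distributed

    addprocs(2)

    bigdata = rand(10_000)   # stands in for `img`
    wp = CachingPool(workers())

    # `bigdata` is captured by the closure; the pool caches it on each worker
    # after the first item instead of resending it per item.
    sums = pmap(wp, 1:8) do i
        sum(bigdata) + i
    end

    clear!(wp)   # drop the cached closures on the workers
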

driver(outfile::AbstractString, algorithm::AbstractWorker, img, mon::Dict) = driver(outfile, [algorithm], img, [mon])
@@ -214,20 +219,19 @@ end

mm_package_loader(algorithm::AbstractWorker) = mm_package_loader([algorithm])
function mm_package_loader(algorithms::Vector)
nworkers = length(algorithms)
use_workerprocs = nworkers > 1 || workerpid(algorithms[1]) != myid()
rrdev = Array{RemoteChannel}(undef, nworkers)
if use_workerprocs
for i = 1:nworkers
dev = algorithms[i].dev
rrdev[i] = put!(RemoteChannel(workerpid(algorithms[i])), dev)
end
@sync for i = 1:nworkers
p = workerpid(algorithms[i])
@async remotecall_fetch(load_mm_package, p, rrdev[i])
end
else
load_mm_package(algorithms[1].dev)
numworkers = length(algorithms)
use_workerprocs = numworkers > 1 || workerpid(algorithms[1]) != myid()
pool = use_workerprocs ? map(alg -> alg.workerpid, algorithms) : [myid()]
wpool = CachingPool(pool) # worker pool for pmap

# Map worker ID to algorithm index
aindices = use_workerprocs ? Dict(workerpid(alg) => i for (i, alg) in enumerate(algorithms)) :
Dict(myid() => 1)
# Load the appropriate mismatch package on each worker
pmap(wpool, 1:numworkers) do _
wid = myid()
load_mm_package(algorithms[aindices[wid]].dev)
return nothing
end
nothing
end
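
driver and mm_package_loader now repeat the same pool/aindices scaffolding. A hypothetical refactor (the helper name and the workerpid stub are mine, not the package's) that also guarantees exactly one call per algorithm by addressing each pid directly:

    using Distributed

    workerpid(alg) = alg.workerpid   # stub mirroring the package's accessor

    # Run f(alg) on the worker that owns alg, exactly once per algorithm.
    function on_each_algorithm(f, algorithms::Vector)
        @sync for alg in algorithms
            @async remotecall_fetch(f, workerpid(alg), alg)
        end
        return nothing
    end

    # Hypothetical usage: on_each_algorithm(init!, algorithms)
    #                     on_each_algorithm(alg -> load_mm_package(alg.dev), algorithms)
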
3 changes: 2 additions & 1 deletion test/runtests.jl
@@ -1,13 +1,14 @@
using Test, Distributed, SharedArrays
using ImageCore, JLD
using RegisterDriver, RegisterWorkerShell
using RegisterWorkerShell
using AxisArrays: AxisArray

driverprocs = addprocs(2)
push!(LOAD_PATH, pwd())
@sync for p in driverprocs
@spawnat p push!(LOAD_PATH, pwd())
end
@everywhere using RegisterDriver
using WorkerDummy

workdir = tempname()
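
The master-plus-workers LOAD_PATH setup above can also be written with @everywhere's $-interpolation, which evaluates pwd() on the master before shipping the expression; a sketch:

    using Distributed

    addprocs(2)

    # $(pwd()) is evaluated locally, so every process pushes the same directory.
    @everywhere push!(LOAD_PATH, $(pwd()))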