@@ -50,98 +50,103 @@ worker has been written to look for such settings:
5050
5151which will save `extra` only if `:extra` is a key in `mon`.
5252"""
function driver(outfile::AbstractString, algorithms::Vector, img, mon::Vector)
    numworkers = length(algorithms)
    length(mon) == numworkers || error("Number of monitors must equal number of algorithms")
    # Run remotely unless there is a single algorithm assigned to this process.
    use_workerprocs = numworkers > 1 || workerpid(algorithms[1]) != myid()
    pool = use_workerprocs ? map(alg -> alg.workerpid, algorithms) : [myid()]
    wpool = CachingPool(pool)   # caches closures (and their captured state) on the workers

    # Map worker ID -> algorithm index, so each worker always uses "its" algorithm.
    aindices = use_workerprocs ?
        Dict(map((alg, aidx) -> (alg.workerpid => aidx), algorithms, 1:length(algorithms))...) :
        Dict(myid() => 1)

    # Initialize the algorithms on the workers.
    # NOTE(review): pmap does not guarantee one task per worker — with
    # numworkers tasks over numworkers workers, a fast worker may run two
    # tasks while another runs none. Confirm init!/close! are safe under
    # that distribution, or target each pid explicitly with remotecall_fetch.
    println("Initializing algorithm on workers")
    pmap(wpool, 1:numworkers) do _
        wid = myid()
        init!(algorithms[aindices[wid]])
        return nothing
    end

    println("Working on algorithm and saving the result")
    jldopen(outfile, "w") do file
        dsets = Dict{Symbol,Any}()
        firstsave = Ref(true)          # datasets are created lazily on the first result
        have_unpackable = Ref(false)   # true if some values need per-frame JLD groups
        n = nimages(img)
        fs = FormatSpec("0$(ndigits(n))d")   # zero-padded frame index for group names

        # Channel for passing results from the workers to the master.
        results_ch = RemoteChannel(() -> Channel{Tuple{Int,Dict}}(32), myid())

        # Writer task (runs on the master): drains the channel and writes each
        # result to the JLD file. Funneling all writes through a single task
        # keeps file access serialized.
        writer_task = @async begin
            while true
                data = try
                    take!(results_ch)
                catch err
                    # take! throws once the channel has been closed and drained;
                    # that is the normal shutdown signal. Rethrow anything that
                    # is not a closed-channel condition so real errors are not
                    # silently swallowed.
                    if err isa InvalidStateException ||
                       (err isa RemoteException && err.captured.ex isa InvalidStateException)
                        break
                    end
                    rethrow()
                end
                movidx, monres = data

                # Create the datasets the first time we see a result.
                if firstsave[]
                    firstsave[] = false
                    have_unpackable[] = initialize_jld!(dsets, file, monres, fs, n)
                end

                g = have_unpackable[] ? file[string("stack", fmt(fs, movidx))] : nothing

                # Write all monitored values into the file.
                for (k, v) in monres
                    if isa(v, Number)
                        dsets[k][movidx] = v
                    elseif isa(v, Array) || isa(v, SharedArray)
                        vw = nicehdf5(v)
                        if eltype(vw) <: BitsType
                            colons = [Colon() for _ = 1:ndims(vw)]
                            dsets[k][colons..., movidx] = vw
                        else
                            # Non-bits arrays go into the per-frame group.
                            g[string(k)] = v
                        end
                    else
                        g[string(k)] = v
                    end
                end
            end
        end

        # Main computation: each frame goes to whichever worker is free, which
        # processes it with its own (cached) algorithm and monitor.
        pmap(wpool, 1:n) do movidx
            wid = myid()
            println("Worker $wid processing $movidx")

            # Perform computation.
            tmp = worker(algorithms[aindices[wid]], img, movidx, mon[aindices[wid]])

            # Send result back to master for writing.
            put!(results_ch, (movidx, tmp))
            !use_workerprocs && yield()   # let the writer task run in single-process mode
            return nothing
        end

        # Close the channel (take! will throw once it is drained) and wait for
        # the writer to finish; wait() rethrows any writer-task error.
        close(results_ch)
        wait(writer_task)
    end

    # Perform any needed worker cleanup.
    println("Closing algorithms on Workers")
    pmap(wpool, 1:numworkers) do _
        wid = myid()
        close!(algorithms[aindices[wid]])
        return nothing
    end

    return nothing
end
146151
# Single-algorithm convenience method: wraps the arguments in one-element
# vectors and forwards to the vector method.
function driver(outfile::AbstractString, algorithm::AbstractWorker, img, mon::Dict)
    return driver(outfile, [algorithm], img, [mon])
end
@@ -214,20 +219,19 @@ end
214219
# Single-algorithm convenience method: forwards to the vector method.
function mm_package_loader(algorithm::AbstractWorker)
    return mm_package_loader([algorithm])
end
"""
    mm_package_loader(algorithms::Vector)

Load the mismatch package selected by each algorithm's `dev` field on the
process that will run that algorithm.
"""
function mm_package_loader(algorithms::Vector)
    numworkers = length(algorithms)
    # Run remotely unless there is a single algorithm assigned to this process.
    use_workerprocs = numworkers > 1 || workerpid(algorithms[1]) != myid()
    if use_workerprocs
        # Target each worker pid explicitly. pmap over a worker pool does NOT
        # guarantee one task per worker: a fast worker could be handed two of
        # the numworkers tasks, leaving another worker with no package loaded.
        @sync for alg in algorithms
            @async remotecall_fetch(load_mm_package, workerpid(alg), alg.dev)
        end
    else
        load_mm_package(algorithms[1].dev)
    end
    nothing
end
0 commit comments