Skip to content

Commit 5a16805

Browse files
authored
Merge pull request #38405 from JuliaLang/vc/distributed_ts
Make Distributed.jl `Worker` struct thread-safe.
2 parents 02807b2 + 0c073cc commit 5a16805

File tree

7 files changed

+169
-47
lines changed

7 files changed

+169
-47
lines changed

stdlib/Distributed/src/cluster.jl

Lines changed: 53 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -95,13 +95,14 @@ end
9595
@enum WorkerState W_CREATED W_CONNECTED W_TERMINATING W_TERMINATED
9696
mutable struct Worker
9797
id::Int
98+
msg_lock::Threads.ReentrantLock # Lock for del_msgs, add_msgs, and gcflag
9899
del_msgs::Array{Any,1}
99100
add_msgs::Array{Any,1}
100101
gcflag::Bool
101102
state::WorkerState
102-
c_state::Condition # wait for state changes
103-
ct_time::Float64 # creation time
104-
conn_func::Any # used to setup connections lazily
103+
c_state::Threads.Condition # wait for state changes, lock for state
104+
ct_time::Float64 # creation time
105+
conn_func::Any # used to setup connections lazily
105106

106107
r_stream::IO
107108
w_stream::IO
@@ -133,7 +134,7 @@ mutable struct Worker
133134
if haskey(map_pid_wrkr, id)
134135
return map_pid_wrkr[id]
135136
end
136-
w=new(id, [], [], false, W_CREATED, Condition(), time(), conn_func)
137+
w=new(id, Threads.ReentrantLock(), [], [], false, W_CREATED, Threads.Condition(), time(), conn_func)
137138
w.initialized = Event()
138139
register_worker(w)
139140
w
@@ -143,12 +144,16 @@ mutable struct Worker
143144
end
144145

145146
function set_worker_state(w, state)
146-
w.state = state
147-
notify(w.c_state; all=true)
147+
lock(w.c_state) do
148+
w.state = state
149+
notify(w.c_state; all=true)
150+
end
148151
end
149152

150153
function check_worker_state(w::Worker)
154+
lock(w.c_state)
151155
if w.state === W_CREATED
156+
unlock(w.c_state)
152157
if !isclusterlazy()
153158
if PGRP.topology === :all_to_all
154159
# Since higher pids connect with lower pids, the remote worker
@@ -168,6 +173,8 @@ function check_worker_state(w::Worker)
168173
errormonitor(t)
169174
wait_for_conn(w)
170175
end
176+
else
177+
unlock(w.c_state)
171178
end
172179
end
173180

@@ -186,13 +193,25 @@ function exec_conn_func(w::Worker)
186193
end
187194

188195
function wait_for_conn(w)
196+
lock(w.c_state)
189197
if w.state === W_CREATED
198+
unlock(w.c_state)
190199
timeout = worker_timeout() - (time() - w.ct_time)
191200
timeout <= 0 && error("peer $(w.id) has not connected to $(myid())")
192201

193-
@async (sleep(timeout); notify(w.c_state; all=true))
194-
wait(w.c_state)
195-
w.state === W_CREATED && error("peer $(w.id) didn't connect to $(myid()) within $timeout seconds")
202+
T = Threads.@spawn begin
203+
sleep($timeout)
204+
lock(w.c_state) do
205+
notify(w.c_state; all=true)
206+
end
207+
end
208+
errormonitor(T)
209+
lock(w.c_state) do
210+
wait(w.c_state)
211+
w.state === W_CREATED && error("peer $(w.id) didn't connect to $(myid()) within $timeout seconds")
212+
end
213+
else
214+
unlock(w.c_state)
196215
end
197216
nothing
198217
end
@@ -471,6 +490,10 @@ function addprocs_locked(manager::ClusterManager; kwargs...)
471490
# The `launch` method should add an object of type WorkerConfig for every
472491
# worker launched. It provides information required on how to connect
473492
# to it.
493+
494+
# FIXME: launched should be a Channel, launch_ntfy should be a Threads.Condition
495+
# but both are part of the public interface. This means we currently can't use
496+
# `Threads.@spawn` in the code below.
474497
launched = WorkerConfig[]
475498
launch_ntfy = Condition()
476499

@@ -483,7 +506,10 @@ function addprocs_locked(manager::ClusterManager; kwargs...)
483506
while true
484507
if isempty(launched)
485508
istaskdone(t_launch) && break
486-
@async (sleep(1); notify(launch_ntfy))
509+
@async begin
510+
sleep(1)
511+
notify(launch_ntfy)
512+
end
487513
wait(launch_ntfy)
488514
end
489515

@@ -636,7 +662,12 @@ function create_worker(manager, wconfig)
636662
# require the value of config.connect_at which is set only upon connection completion
637663
for jw in PGRP.workers
638664
if (jw.id != 1) && (jw.id < w.id)
639-
(jw.state === W_CREATED) && wait(jw.c_state)
665+
# wait for wl to join
666+
lock(jw.c_state) do
667+
if jw.state === W_CREATED
668+
wait(jw.c_state)
669+
end
670+
end
640671
push!(join_list, jw)
641672
end
642673
end
@@ -659,7 +690,12 @@ function create_worker(manager, wconfig)
659690
end
660691

661692
for wl in wlist
662-
(wl.state === W_CREATED) && wait(wl.c_state)
693+
if wl.state === W_CREATED
694+
# wait for wl to join
695+
lock(wl.c_state) do
696+
wait(wl.c_state)
697+
end
698+
end
663699
push!(join_list, wl)
664700
end
665701
end
@@ -676,7 +712,11 @@ function create_worker(manager, wconfig)
676712
@async manage(w.manager, w.id, w.config, :register)
677713
# wait for rr_ntfy_join with timeout
678714
timedout = false
679-
@async (sleep($timeout); timedout = true; put!(rr_ntfy_join, 1))
715+
@async begin
716+
sleep($timeout)
717+
timedout = true
718+
put!(rr_ntfy_join, 1)
719+
end
680720
wait(rr_ntfy_join)
681721
if timedout
682722
error("worker did not connect within $timeout seconds")

stdlib/Distributed/src/macros.jl

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,10 @@
11
# This file is a part of Julia. License is MIT: https://julialang.org/license
22

3-
let nextidx = 0
3+
let nextidx = Threads.Atomic{Int}(0)
44
global nextproc
55
function nextproc()
6-
p = -1
7-
if p == -1
8-
p = workers()[(nextidx % nworkers()) + 1]
9-
nextidx += 1
10-
end
11-
p
6+
idx = Threads.atomic_add!(nextidx, 1)
7+
return workers()[(idx % nworkers()) + 1]
128
end
139
end
1410

stdlib/Distributed/src/managers.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ function launch(manager::SSHManager, params::Dict, launched::Array, launch_ntfy:
163163
# Wait for all launches to complete.
164164
@sync for (i, (machine, cnt)) in enumerate(manager.machines)
165165
let machine=machine, cnt=cnt
166-
@async try
166+
@async try
167167
launch_on_machine(manager, $machine, $cnt, params, launched, launch_ntfy)
168168
catch e
169169
print(stderr, "exception launching on machine $(machine) : $(e)\n")

stdlib/Distributed/src/messages.jl

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -126,22 +126,20 @@ function flush_gc_msgs(w::Worker)
126126
if !isdefined(w, :w_stream)
127127
return
128128
end
129-
w.gcflag = false
130-
new_array = Any[]
131-
msgs = w.add_msgs
132-
w.add_msgs = new_array
133-
if !isempty(msgs)
134-
remote_do(add_clients, w, msgs)
135-
end
129+
lock(w.msg_lock) do
130+
w.gcflag || return # early exit if someone else got to this
131+
w.gcflag = false
132+
msgs = w.add_msgs
133+
w.add_msgs = Any[]
134+
if !isempty(msgs)
135+
remote_do(add_clients, w, msgs)
136+
end
136137

137-
# del_msgs gets populated by finalizers, so be very careful here about ordering of allocations
138-
# XXX: threading requires this to be atomic
139-
new_array = Any[]
140-
msgs = w.del_msgs
141-
w.del_msgs = new_array
142-
if !isempty(msgs)
143-
#print("sending delete of $msgs\n")
144-
remote_do(del_clients, w, msgs)
138+
msgs = w.del_msgs
139+
w.del_msgs = Any[]
140+
if !isempty(msgs)
141+
remote_do(del_clients, w, msgs)
142+
end
145143
end
146144
end
147145

stdlib/Distributed/src/remotecall.jl

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -247,22 +247,42 @@ function del_clients(pairs::Vector)
247247
end
248248
end
249249

250-
const any_gc_flag = Condition()
250+
# The task below is coalescing the `flush_gc_msgs` call
251+
# across multiple producers, see `send_del_client`,
252+
# and `send_add_client`.
253+
# XXX: Is this worth the additional complexity?
254+
# `flush_gc_msgs` has to iterate over all connected workers.
255+
const any_gc_flag = Threads.Condition()
251256
function start_gc_msgs_task()
252-
errormonitor(@async while true
253-
wait(any_gc_flag)
254-
flush_gc_msgs()
255-
end)
257+
errormonitor(
258+
Threads.@spawn begin
259+
while true
260+
lock(any_gc_flag) do
261+
wait(any_gc_flag)
262+
flush_gc_msgs() # handles throws internally
263+
end
264+
end
265+
end
266+
)
256267
end
257268

269+
# Function can be called within a finalizer
258270
function send_del_client(rr)
259271
if rr.where == myid()
260272
del_client(rr)
261273
elseif id_in_procs(rr.where) # process only if a valid worker
262274
w = worker_from_id(rr.where)::Worker
263-
push!(w.del_msgs, (remoteref_id(rr), myid()))
264-
w.gcflag = true
265-
notify(any_gc_flag)
275+
msg = (remoteref_id(rr), myid())
276+
# We cannot acquire locks from finalizers
277+
Threads.@spawn begin
278+
lock(w.msg_lock) do
279+
push!(w.del_msgs, msg)
280+
w.gcflag = true
281+
end
282+
lock(any_gc_flag) do
283+
notify(any_gc_flag)
284+
end
285+
end
266286
end
267287
end
268288

@@ -288,9 +308,13 @@ function send_add_client(rr::AbstractRemoteRef, i)
288308
# to the processor that owns the remote ref. it will add_client
289309
# itself inside deserialize().
290310
w = worker_from_id(rr.where)
291-
push!(w.add_msgs, (remoteref_id(rr), i))
292-
w.gcflag = true
293-
notify(any_gc_flag)
311+
lock(w.msg_lock) do
312+
push!(w.add_msgs, (remoteref_id(rr), i))
313+
w.gcflag = true
314+
end
315+
lock(any_gc_flag) do
316+
notify(any_gc_flag)
317+
end
294318
end
295319
end
296320

stdlib/Distributed/test/distributed_exec.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1696,4 +1696,5 @@ include("splitrange.jl")
16961696
# Run topology tests last after removing all workers, since a given
16971697
# cluster at any time only supports a single topology.
16981698
rmprocs(workers())
1699+
include("threads.jl")
16991700
include("topology.jl")

stdlib/Distributed/test/threads.jl

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
using Test
2+
using Distributed, Base.Threads
3+
using Base.Iterators: product
4+
5+
exeflags = ("--startup-file=no",
6+
"--check-bounds=yes",
7+
"--depwarn=error",
8+
"--threads=2")
9+
10+
function call_on(f, wid, tid)
11+
remotecall(wid) do
12+
t = Task(f)
13+
ccall(:jl_set_task_tid, Cvoid, (Any, Cint), t, tid - 1)
14+
schedule(t)
15+
@assert threadid(t) == tid
16+
t
17+
end
18+
end
19+
20+
# Run function on process holding the data to only serialize the result of f.
21+
# This becomes useful for things that cannot be serialized (e.g. running tasks)
22+
# or that would be unnecessarily big if serialized.
23+
fetch_from_owner(f, rr) = remotecall_fetch(f ∘ fetch, rr.where, rr)
24+
25+
isdone(rr) = fetch_from_owner(istaskdone, rr)
26+
isfailed(rr) = fetch_from_owner(istaskfailed, rr)
27+
28+
@testset "RemoteChannel allows put!/take! from thread other than 1" begin
29+
ws = ts = product(1:2, 1:2)
30+
@testset "from worker $w1 to $w2 via 1" for (w1, w2) in ws
31+
@testset "from thread $w1.$t1 to $w2.$t2" for (t1, t2) in ts
32+
# We want (the default) lazyness, so that we wait for `Worker.c_state`!
33+
procs_added = addprocs(2; exeflags, lazy=true)
34+
@everywhere procs_added using Base.Threads
35+
36+
p1 = procs_added[w1]
37+
p2 = procs_added[w2]
38+
chan_id = first(procs_added)
39+
chan = RemoteChannel(chan_id)
40+
send = call_on(p1, t1) do
41+
put!(chan, nothing)
42+
end
43+
recv = call_on(p2, t2) do
44+
take!(chan)
45+
end
46+
47+
# Wait on the spawned tasks on the owner
48+
@sync begin
49+
Threads.@spawn fetch_from_owner(wait, recv)
50+
Threads.@spawn fetch_from_owner(wait, send)
51+
end
52+
53+
# Check the tasks
54+
@test isdone(send)
55+
@test isdone(recv)
56+
57+
@test !isfailed(send)
58+
@test !isfailed(recv)
59+
60+
rmprocs(procs_added)
61+
end
62+
end
63+
end

0 commit comments

Comments (0)