Skip to content

Commit 96a0c46

Browse files
committed
Wait for input/output handlers to finish when closing a StreamStore
Otherwise it's possible that the scheduler will close the DTask before all outputs have been sent, which would cause the downstream tasks to hang. This is how it could happen: 1. A streaming task starts. 2. The output handler task calls `take!(::ProcessRingBuffer)` on an output buffer, finds it empty, and calls `yield()`. 3. The task executes, pushes its output to the output buffers, reaches `max_evals` and finishes. 4. The scheduler finishes the corresponding DTask. 5. The `take!(::ProcessRingBuffer)` call resumes. The buffer isn't empty anymore, but it calls `task_may_cancel(; must_force=true)` before continuing and throws an exception since the scheduler has finished the DTask. The result is that the last output is never sent, and the exception is swallowed by the output handler started by `initialize_output_stream!()`. 6. Downstream tasks don't get that last result, so they never reach `max_evals` and spin forever. Fixed by storing the handler tasks in the `StreamStore` and waiting for them to finish in `close(::StreamStore)`. Also increased the timeout of the 'Single task running forever' test because it will sometimes time out before the default 10s is up.
1 parent 09bfd56 commit 96a0c46

File tree

2 files changed

+34
-6
lines changed

2 files changed

+34
-6
lines changed

src/stream.jl

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,18 @@ mutable struct StreamStore{T,B}
1212
open::Bool
1313
migrating::Bool
1414
lock::Threads.Condition
15+
16+
input_handlers::Dict{UInt, Task}
17+
output_handlers::Dict{UInt, Task}
18+
1519
StreamStore{T,B}(uid::UInt, input_buffer_amount::Integer, output_buffer_amount::Integer) where {T,B} =
1620
new{T,B}(uid, zeros(Int, 0),
1721
Dict{UInt,Any}(), Dict{UInt,Any}(),
1822
Dict{UInt,B}(), Dict{UInt,B}(),
1923
input_buffer_amount, output_buffer_amount,
2024
Dict{UInt,Any}(), Dict{UInt,Any}(),
21-
true, false, Threads.Condition())
25+
true, false, Threads.Condition(),
26+
Dict{UInt, Task}(), Dict{UInt, Task}())
2227
end
2328

2429
function tid_to_uid(thunk_id)
@@ -164,6 +169,19 @@ function Base.close(store::StreamStore)
164169
end
165170
notify(store.lock)
166171
end
172+
173+
# We have to close the input fetchers for the input handlers to finish
174+
for fetcher in values(store.input_fetchers)
175+
close(fetcher.chan)
176+
end
177+
178+
# Wait for the handlers to finish
179+
for handler in values(store.input_handlers)
180+
wait(handler)
181+
end
182+
for handler in values(store.output_handlers)
183+
wait(handler)
184+
end
167185
end
168186

169187
# FIXME: Just pass Stream directly, rather than its uid
@@ -229,7 +247,7 @@ function initialize_input_stream!(our_store::StreamStore{OT,OB}, input_stream::S
229247
end
230248
thunk_id = STREAM_THUNK_ID[]
231249
tls = get_tls()
232-
Sch.errormonitor_tracked("streaming input: $input_uid -> $our_uid", Threads.@spawn begin
250+
t = Sch.errormonitor_tracked("streaming input: $input_uid -> $our_uid", Threads.@spawn begin
233251
set_tls!(tls)
234252
STREAM_THUNK_ID[] = thunk_id
235253
try
@@ -247,9 +265,14 @@ function initialize_input_stream!(our_store::StreamStore{OT,OB}, input_stream::S
247265
@dagdebug STREAM_THUNK_ID[] :stream "input stream closed"
248266
end
249267
end)
268+
269+
our_store.input_handlers[input_uid] = t
270+
250271
return StreamingValue(buffer)
251272
end
273+
252274
initialize_input_stream!(our_store::StreamStore, arg) = arg
275+
253276
function initialize_output_stream!(our_store::StreamStore{T,B}, output_uid::UInt) where {T,B}
254277
@assert islocked(our_store.lock)
255278
@dagdebug STREAM_THUNK_ID[] :stream "initializing output stream $output_uid"
@@ -260,7 +283,7 @@ function initialize_output_stream!(our_store::StreamStore{T,B}, output_uid::UInt
260283
output_fetcher = our_store.output_fetchers[output_uid]
261284
thunk_id = STREAM_THUNK_ID[]
262285
tls = get_tls()
263-
Sch.errormonitor_tracked("streaming output: $our_uid -> $output_uid", Threads.@spawn begin
286+
t = Sch.errormonitor_tracked("streaming output: $our_uid -> $output_uid", Threads.@spawn begin
264287
set_tls!(tls)
265288
STREAM_THUNK_ID[] = thunk_id
266289
try
@@ -282,6 +305,8 @@ function initialize_output_stream!(our_store::StreamStore{T,B}, output_uid::UInt
282305
@dagdebug thunk_id :stream "output stream closed"
283306
end
284307
end)
308+
309+
our_store.output_handlers[output_uid] = t
285310
end
286311

287312
Base.put!(stream::Stream, @nospecialize(value)) = put!(stream.store, value)

test/streaming.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,14 @@ function catch_interrupt(f)
3737
rethrow(err)
3838
end
3939
end
40+
4041
function merge_testset!(inner::Test.DefaultTestSet)
4142
outer = Test.get_testset()
4243
append!(outer.results, inner.results)
4344
outer.n_passed += inner.n_passed
4445
end
45-
function test_finishes(f, message::String; ignore_timeout=false, max_evals=10)
46+
47+
function test_finishes(f, message::String; timeout=10, ignore_timeout=false, max_evals=10)
4648
t = @eval Threads.@spawn begin
4749
tset = nothing
4850
try
@@ -61,7 +63,7 @@ function test_finishes(f, message::String; ignore_timeout=false, max_evals=10)
6163
end
6264
return tset
6365
end
64-
timed_out = timedwait(()->istaskdone(t), 10) == :timed_out
66+
timed_out = timedwait(()->istaskdone(t), timeout) == :timed_out
6567
if timed_out
6668
if !ignore_timeout
6769
@warn "Testing task timed out: $message"
@@ -89,7 +91,7 @@ for idx in 1:5
8991
end
9092

9193
@testset "Single Task Control Flow ($scope_str)" begin
92-
@test !test_finishes("Single task running forever"; max_evals=1_000_000, ignore_timeout=true) do
94+
@test !test_finishes("Single task running forever"; timeout=15, max_evals=1_000_000, ignore_timeout=true) do
9395
local x
9496
Dagger.spawn_streaming() do
9597
x = Dagger.@spawn scope=rand(scopes) () -> begin
@@ -98,6 +100,7 @@ for idx in 1:5
98100
return y
99101
end
100102
end
103+
101104
@test_throws_unwrap InterruptException fetch(x)
102105
end
103106

0 commit comments

Comments
 (0)