Skip to content

streaming: Add DAG teardown option #584

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
2 commits merged on Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -427,4 +427,5 @@ wait(t)
The above example demonstrates a streaming region that generates random numbers
continuously and writes each random number to a file. The streaming region is
terminated when a random number less than 0.01 is generated, which is done by
calling `Dagger.finish_stream()` (this exits the current streaming task).
calling `Dagger.finish_stream()` (this terminates the current task, and will
also terminate all streaming tasks launched by `spawn_streaming`).
5 changes: 2 additions & 3 deletions docs/src/streaming.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,8 @@ end
```

If you want to stop the streaming DAG and tear it all down, you can call
`Dagger.cancel!.(all_vals)` and `Dagger.cancel!.(all_vals_written)` to
terminate each streaming task. In the future, a more convenient way to tear
down a full DAG will be added; for now, each task must be cancelled individually.
`Dagger.cancel!(all_vals[1])` (or with any other task in the streaming DAG) to
terminate all streaming tasks.

Alternatively, tasks can stop themselves from the inside with
`finish_stream`, optionally returning a value that can be `fetch`'d. Let's
Expand Down
26 changes: 26 additions & 0 deletions src/dtask.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,32 @@ function Base.fetch(t::DTask; raw=false)
end
return fetch(t.future; raw)
end
"""
    waitany(tasks::Vector{DTask})

Block until at least one task in `tasks` has finished (successfully or not),
then return `nothing`. Returns immediately if `tasks` is empty.
"""
function waitany(tasks::Vector{DTask})
    if isempty(tasks)
        return
    end
    cond = Threads.Condition()
    for task in tasks
        Sch.errormonitor_tracked("waitany listener", Threads.@spawn begin
            # Notify even if `wait(task)` throws (task failed/cancelled);
            # otherwise a failing task would never wake the waiter and
            # `waitany` could block forever when all tasks fail.
            try
                wait(task)
            finally
                @lock cond notify(cond)
            end
        end)
    end
    # Wake up on the first listener to fire, whichever task that was.
    @lock cond wait(cond)
    return
end
"""
    waitall(tasks::Vector{DTask})

Block until every task in `tasks` has finished, then return `nothing`.
Returns immediately if `tasks` is empty.
"""
function waitall(tasks::Vector{DTask})
    if isempty(tasks)
        return
    end
    # `@sync` joins all spawned waiters; no condition variable is needed.
    # (The previous version referenced an undefined `cond`, which raised
    # an `UndefVarError` inside the spawned tasks.)
    @sync for task in tasks
        Threads.@spawn wait(task)
    end
    return
end
function Base.show(io::IO, t::DTask)
status = if istaskstarted(t)
isready(t) ? "finished" : "running"
Expand Down
27 changes: 26 additions & 1 deletion src/stream.jl
Original file line number Diff line number Diff line change
Expand Up @@ -426,12 +426,37 @@ function initialize_streaming!(self_streams, spec, task)
end
end

function spawn_streaming(f::Base.Callable)
"""
Starts a streaming region, within which all tasks run continuously and
concurrently. Any `DTask` argument that is itself a streaming task will be
treated as a streaming input/output. The streaming region will automatically
handle the buffering and synchronization of these tasks' values.

# Keyword Arguments
- `teardown::Bool=true`: If `true`, the streaming region will automatically
cancel all tasks if any task fails or is cancelled. Otherwise, a failing task
will not cancel the other tasks, which will continue running.
"""
function spawn_streaming(f::Base.Callable; teardown::Bool=true)
queue = StreamingTaskQueue()
result = with_options(f; task_queue=queue)
if length(queue.tasks) > 0
finalize_streaming!(queue.tasks, queue.self_streams)
enqueue!(queue.tasks)

if teardown
# Start teardown monitor
dtasks = map(last, queue.tasks)::Vector{DTask}
Sch.errormonitor_tracked("streaming teardown", Threads.@spawn begin
# Wait for any task to finish
waitany(dtasks)

# Cancel all tasks
for task in dtasks
cancel!(task; graceful=false)
end
end)
end
end
return result
end
Expand Down
112 changes: 112 additions & 0 deletions src/utils/tasks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,115 @@ function set_task_tid!(task::Task, tid::Integer)
end
@assert Threads.threadid(task) == tid "jl_set_task_tid failed!"
end

# On Julia versions that already provide `Base.waitany`/`Base.waitall`
# (1.12+), reuse them; otherwise vendor the implementation below.
if isdefined(Base, :waitany)
    import Base: waitany, waitall
else
    # Vendored from Base
    # License is MIT
    waitany(tasks; throw=true) = _wait_multiple(tasks, throw)
    waitall(tasks; failfast=true, throw=true) = _wait_multiple(tasks, throw, true, failfast)
    # Shared driver for `waitany` (all=false) and `waitall` (all=true).
    # Returns `(done_tasks, remaining_tasks)`; may throw a
    # `CompositeException` of the failed tasks when `throwexc` is set.
    function _wait_multiple(waiting_tasks, throwexc=false, all=false, failfast=false)
        tasks = Task[]

        for t in waiting_tasks
            t isa Task || error("Expected an iterator of `Task` object")
            push!(tasks, t)
        end

        # Trivial cases: waiting for everything without failfast, or for at
        # most one task, can just wait on each task sequentially.
        if (all && !failfast) || length(tasks) <= 1
            exception = false
            # Force everything to finish synchronously for the case of waitall
            # with failfast=false
            for t in tasks
                _wait(t)
                exception |= istaskfailed(t)
            end
            if exception && throwexc
                exceptions = [TaskFailedException(t) for t in tasks if istaskfailed(t)]
                throw(CompositeException(exceptions))
            else
                return tasks, Task[]
            end
        end

        # First pass: record which tasks are already done before we block.
        exception = false
        nremaining::Int = length(tasks)
        done_mask = falses(nremaining)
        for (i, t) in enumerate(tasks)
            if istaskdone(t)
                done_mask[i] = true
                exception |= istaskfailed(t)
                nremaining -= 1
            else
                done_mask[i] = false
            end
        end

        # If enough tasks were already done, we can return without waiting.
        if nremaining == 0
            return tasks, Task[]
        elseif any(done_mask) && (!all || (failfast && exception))
            if throwexc && (!all || failfast) && exception
                exceptions = [TaskFailedException(t) for t in tasks[done_mask] if istaskfailed(t)]
                throw(CompositeException(exceptions))
            else
                return tasks[done_mask], tasks[.~done_mask]
            end
        end

        # Attach a lightweight waiter task to each pending task; each waiter
        # pushes its task's index onto `chan` when that task completes.
        # `sentinel` (the current task) marks slots without an active waiter.
        chan = Channel{Int}(Inf)
        sentinel = current_task()
        waiter_tasks = fill(sentinel, length(tasks))

        for (i, done) in enumerate(done_mask)
            done && continue
            t = tasks[i]
            if istaskdone(t)
                # Task finished between the first pass and now.
                done_mask[i] = true
                exception |= istaskfailed(t)
                nremaining -= 1
                exception && failfast && break
            else
                waiter = @task put!(chan, i)
                waiter.sticky = false
                _wait2(t, waiter)
                waiter_tasks[i] = waiter
            end
        end

        # Consume completion notifications until the wait condition is met.
        while nremaining > 0
            i = take!(chan)
            t = tasks[i]
            waiter_tasks[i] = sentinel
            done_mask[i] = true
            exception |= istaskfailed(t)
            nremaining -= 1

            # stop early if requested, unless there is something immediately
            # ready to consume from the channel (using a race-y check)
            if (!all || (failfast && exception)) && !isready(chan)
                break
            end
        end

        close(chan)

        if nremaining == 0
            return tasks, Task[]
        else
            # Detach the waiters still registered on unfinished tasks so they
            # don't fire into a closed channel later.
            remaining_mask = .~done_mask
            for i in findall(remaining_mask)
                waiter = waiter_tasks[i]
                donenotify = tasks[i].donenotify::ThreadSynchronizer
                @lock donenotify Base.list_deletefirst!(donenotify.waitq, waiter)
            end
            done_tasks = tasks[done_mask]
            if throwexc && exception
                exceptions = [TaskFailedException(t) for t in done_tasks if istaskfailed(t)]
                throw(CompositeException(exceptions))
            else
                return done_tasks, tasks[remaining_mask]
            end
        end
    end
end
Loading
Loading