Add waitany and waitall functions to wait multiple tasks at once #53341

Merged
merged 30 commits into from
Mar 11, 2024
1fb6a95
Add waitany and waitall to wait multiple tasks
mrkn Jan 22, 2024
7482838
Reduce the number of working vectors
mrkn Feb 4, 2024
5a30999
Rewrite with using Channel
mrkn Feb 8, 2024
664f8a3
Add test cases to Cover all argument types
mrkn Feb 15, 2024
8ac6ddc
Remove return type of _wait_multiple
mrkn Feb 27, 2024
f459e8a
Remove 1st argument type of _wait_multiple
mrkn Feb 27, 2024
eb2bafd
Specify type of element that comes from iteration for type stability
mrkn Feb 27, 2024
d7cf9dc
Support inputs that can be iterated only once
mrkn Mar 1, 2024
9159247
Delete waiters from waitq of each remaining task
mrkn Mar 1, 2024
573ee9f
Fix for performance
mrkn Mar 1, 2024
1175779
Split type checking and examining loops
mrkn Mar 4, 2024
633cb58
Stop using needless enumerate
mrkn Mar 5, 2024
4bec8cd
Optimize for waitall with failfast=false
mrkn Mar 5, 2024
b9dd9e6
Use vector for managing waiters
mrkn Mar 5, 2024
ae0ca9d
Add channel emptiness check
mrkn Mar 5, 2024
93057e6
Insert done check in waiter creation loop
mrkn Mar 5, 2024
ed58eda
Stop using kwargs in _wait_multiple
mrkn Mar 6, 2024
f1f400e
Add throw keyword argument in waitall
mrkn Mar 6, 2024
4d8e137
Add throw keyword argument in waitany
mrkn Mar 6, 2024
1a55697
Add docstrings of waitany and waitall
mrkn Mar 7, 2024
22646dd
Wait single task synchronously
mrkn Mar 7, 2024
505d476
Use TaskFailedException
mrkn Mar 7, 2024
1c9adbf
Stop using sleep in test
mrkn Mar 7, 2024
58d1e02
Remove needless yield call
mrkn Mar 7, 2024
3c9a9c8
Wait all three tasks in teardown function in test
mrkn Mar 7, 2024
4dd8862
Use consistent declarative tense in docstring
mrkn Mar 7, 2024
34e3d41
Add waitany and waitall in doc/src/base/parallel.md
mrkn Mar 7, 2024
8a7683f
Add waitany and waitall in NEWS.md
mrkn Mar 7, 2024
d30c9c0
Change default argument values
mrkn Mar 8, 2024
0a382a3
Add usage note of waitall in docstring
mrkn Mar 9, 2024
1 change: 1 addition & 0 deletions NEWS.md
@@ -37,6 +37,7 @@ New library functions

* `logrange(start, stop; length)` makes a range of constant ratio, instead of constant step ([#39071])
* The new `isfull(c::Channel)` function can be used to check if `put!(c, some_value)` will block. ([#53159])
* `waitany(tasks; throw=true)` and `waitall(tasks; failfast=true, throw=true)` wait for multiple tasks at once ([#53341]).

New library features
--------------------
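A minimal usage sketch of the two new functions, assuming a Julia build that already includes this PR. The names `gate`, `fast`, and `slow` are hypothetical and exist only for illustration; the `Threads.Event` keeps one task pending so the split between completed and remaining tasks is deterministic.

```julia
# Assumes `waitany` and `waitall` are exported from Base (i.e. a build containing this PR).
gate = Threads.Event()            # hypothetical gate that keeps `slow` pending
fast = Threads.@spawn 1 + 1       # finishes almost immediately
slow = Threads.@spawn wait(gate)  # blocks until the gate is notified

wait(fast)                        # make sure `fast` is finished before calling waitany

# waitany returns as soon as at least one task is done.
done, pending = waitany([fast, slow]; throw=false)

# Release the blocked task, then wait for everything.
notify(gate)
done_all, pending_all = waitall([fast, slow])
```

Both calls return a pair of task vectors, so the second vector can be passed straight back into another `waitany` call if the caller wants to keep polling.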
2 changes: 2 additions & 0 deletions base/exports.jl
@@ -706,6 +706,8 @@ export
    yield,
    yieldto,
    wait,
    waitany,
    waitall,
    timedwait,
    asyncmap,
    asyncmap!,
142 changes: 142 additions & 0 deletions base/task.jl
@@ -365,6 +365,148 @@ function wait(t::Task)
    nothing
end

# Wait multiple tasks

"""
    waitany(tasks; throw=true) -> (done_tasks, remaining_tasks)

Wait until at least one of the given tasks has completed.

If `throw` is `true`, throw `CompositeException` when one of the
completed tasks has failed with an exception.

The return value consists of two task vectors. The first one consists of
completed tasks, and the other consists of uncompleted tasks.

!!! warning
    This may scale poorly compared to writing code that uses multiple individual tasks that
    each run serially, since this needs to scan the list of `tasks` each time and
    synchronize with each one every time this is called. Or consider using
    [`waitall(tasks; failfast=true)`](@ref waitall) instead.
"""
waitany(tasks; throw=true) = _wait_multiple(tasks, throw)

"""
    waitall(tasks; failfast=true, throw=true) -> (done_tasks, remaining_tasks)

Wait until all the given tasks have completed.

If `failfast` is `true`, the function returns as soon as at least one of the
given tasks fails with an exception. If `throw` is `true`, throw
`CompositeException` when one of the completed tasks has failed.

`failfast` and `throw` keyword arguments work independently; when
`throw=true` and `failfast=false`, this function waits for all the tasks to
complete before throwing.

The return value consists of two task vectors. The first one consists of
completed tasks, and the other consists of uncompleted tasks.
"""
waitall(tasks; failfast=true, throw=true) = _wait_multiple(tasks, throw, true, failfast)

function _wait_multiple(waiting_tasks, throwexc=false, all=false, failfast=false)
    tasks = Task[]

    for t in waiting_tasks
        t isa Task || error("Expected an iterator of `Task` object")
        push!(tasks, t)
    end

Comment on lines +406 to +413
Contributor Author

@mrkn mrkn Mar 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vtjnash What do you think about splitting _wait_multiple into one method for AbstractVector{Task} and another for other types, to avoid copying a vector?

Suggested change
function _wait_multiple(waiting_tasks, throwexc=false, all=false, failfast=false)
    tasks = Task[]
    for t in waiting_tasks
        t isa Task || error("Expected an iterator of `Task` object")
        push!(tasks, t)
    end

function _wait_multiple(waiting_tasks, throwexc=false, all=false, failfast=false)
    tasks = Task[]
    for t in waiting_tasks
        t isa Task || error("Expected an iterator of `Task` object")
        push!(tasks, t)
    end
    return _wait_multiple(tasks, throwexc, all, failfast)
end
function _wait_multiple(tasks::AbstractVector{Task}, throwexc=false, all=false, failfast=false)

Member
To avoid the copy? That makes sense to me. Though I don't know if all AbstractVector support the bitmasking operations, I think that should be valid.

As food for thought also, in a later PR, I am thinking we should also seek to extend this to allow any other waitable events too, not just Tasks (e.g. Channels, Event, AsyncEvent, Timer, etc.) and either add the throwexc kwarg to those other wait functions (instead of calling it _wait) or rename _wait to something like waitready

Contributor Author
> Though I don't know if all AbstractVector support the bitmasking operations, I think that should be valid.

I guess a[mask] is processed by _getindex(IndexStyle(a), a, LogicalIndex(mask)...), which then reaches _unsafe_getindex defined in multidimensional.jl, so all AbstractVector support the bitmasking operations.

> As food for thought also, in a later PR, I am thinking we should also seek to extend this to allow any other waitable events too, not just Tasks

I completely agree. As I wrote in issue #53226, I also thought that supporting waitable objects other than Tasks might be a good idea.

> either add the throwexc kwarg to those other wait functions (instead of calling it _wait) or rename _wait to something like waitready

I think the non-throw version might be used as frequently as wait in real-world application development.

Contributor Author
I submitted a new PR, #53685, that introduces a throw option for the wait(::Task) function.
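The logical-indexing point discussed in this thread can be checked with a small sketch. It uses plain strings and a range rather than tasks, and the names `items`, `done_mask`, and `steps` are hypothetical; it only illustrates that a `BitVector` mask and its broadcast negation `.~mask` split a vector into two parts, and that this also works on an `AbstractVector` that is not a `Vector`.

```julia
items = ["a", "b", "c", "d"]

# A BitVector mask, built the same way the diff builds `done_mask`.
done_mask = falses(length(items))
done_mask[2] = true
done_mask[4] = true

finished = items[done_mask]        # elements where the mask is true
unfinished = items[.~done_mask]    # elements where the mask is false

# Logical indexing is generic: a range is an AbstractVector but not a Vector.
steps = 10:10:40
picked = steps[done_mask]
```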

    if (all && !failfast) || length(tasks) <= 1
        exception = false
        # Force everything to finish synchronously for the case of waitall
        # with failfast=false
        for t in tasks
            _wait(t)
            exception |= istaskfailed(t)
        end
        if exception && throwexc
            exceptions = [TaskFailedException(t) for t in tasks if istaskfailed(t)]
            throw(CompositeException(exceptions))
        else
            return tasks, Task[]
        end
    end

    exception = false
    nremaining::Int = length(tasks)
    done_mask = falses(nremaining)
    for (i, t) in enumerate(tasks)
        if istaskdone(t)
            done_mask[i] = true
            exception |= istaskfailed(t)
            nremaining -= 1
        else
            done_mask[i] = false
        end
    end

    if nremaining == 0
        return tasks, Task[]
    elseif any(done_mask) && (!all || (failfast && exception))
        if throwexc && (!all || failfast) && exception
            exceptions = [TaskFailedException(t) for t in tasks[done_mask] if istaskfailed(t)]
            throw(CompositeException(exceptions))
        else
            return tasks[done_mask], tasks[.~done_mask]
        end
    end

    chan = Channel{Int}(Inf)
    sentinel = current_task()
    waiter_tasks = fill(sentinel, length(tasks))

    for (i, done) in enumerate(done_mask)
        done && continue
        t = tasks[i]
        if istaskdone(t)
            done_mask[i] = true
            exception |= istaskfailed(t)
            nremaining -= 1
            exception && failfast && break
        else
            waiter = @task put!(chan, i)
            waiter.sticky = false
            _wait2(t, waiter)
            waiter_tasks[i] = waiter
        end
    end

    while nremaining > 0
        i = take!(chan)
        t = tasks[i]
        waiter_tasks[i] = sentinel
        done_mask[i] = true
        exception |= istaskfailed(t)
        nremaining -= 1

        # stop early if requested, unless there is something immediately
        # ready to consume from the channel (using a race-y check)
        if (!all || (failfast && exception)) && !isready(chan)
            break
        end
    end

    close(chan)

    if nremaining == 0
        return tasks, Task[]
    else
        remaining_mask = .~done_mask
        for i in findall(remaining_mask)
            waiter = waiter_tasks[i]
            donenotify = tasks[i].donenotify::ThreadSynchronizer
            @lock donenotify Base.list_deletefirst!(donenotify.waitq, waiter)
        end
        done_tasks = tasks[done_mask]
        if throwexc && exception
            exceptions = [TaskFailedException(t) for t in done_tasks if istaskfailed(t)]
            throw(CompositeException(exceptions))
        else
            return done_tasks, tasks[remaining_mask]
        end
    end
end

"""
fetch(x::Any)

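The core idea in `_wait_multiple`, where each pending task gets a small waiter that reports its index on a shared channel, can be sketched with public API only. This is a simplification and not the PR's implementation: the hypothetical helper `first_done_index` spawns a full waiter task per input and calls `wait`, whereas the diff registers waiters through the internal `_wait2` and removes leftover waiters from each task's wait queue afterwards.

```julia
# Hypothetical helper: return the index of the first task to finish.
function first_done_index(tasks::Vector{Task})
    chan = Channel{Int}(Inf)          # unbounded, so waiters never block on put!
    for (i, t) in enumerate(tasks)
        Threads.@spawn begin
            try
                wait(t)               # throws if `t` failed; either way `t` is done
            catch
                # A failed task still counts as finished for this sketch.
            end
            put!(chan, i)             # report which task finished
        end
    end
    return take!(chan)                # block until the first report arrives
end

# Usage: `blocked` cannot finish until the gate is notified, so `quick` wins.
gate = Threads.Event()
blocked = Threads.@spawn wait(gate)
quick = Threads.@spawn nothing
winner = first_done_index([blocked, quick])

notify(gate)   # release the blocked task so nothing is left hanging
wait(blocked)
```

Unlike the code in the diff, this sketch leaves the remaining waiter tasks alive until their targets finish, which is the cost the PR avoids by deleting waiters from the wait queue.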
2 changes: 2 additions & 0 deletions doc/src/base/parallel.md
@@ -30,6 +30,8 @@ Base.schedule
Base.errormonitor
Base.@sync
Base.wait
Base.waitany
Base.waitall
Base.fetch(t::Task)
Base.fetch(x::Any)
Base.timedwait
123 changes: 123 additions & 0 deletions test/threads_exec.jl
@@ -1186,4 +1186,127 @@ end
@testset "threadcall + threads" begin
    threadcall_threads() #Shouldn't crash!
end

@testset "Wait multiple tasks" begin
    # Dispatch on the type itself: callers pass `tasks_type` (a type), not an instance.
    convert_tasks(t, x) = x
    convert_tasks(::Type{Set{Task}}, x::Vector{Task}) = Set{Task}(x)
    convert_tasks(::Type{Tuple{Task}}, x::Vector{Task}) = tuple(x...)

    function create_tasks()
        tasks = Task[]
        event = Threads.Event()
        push!(tasks,
              Threads.@spawn begin
                  sleep(0.01)
              end)
        push!(tasks,
              Threads.@spawn begin
                  sleep(0.02)
              end)
        push!(tasks,
              Threads.@spawn begin
                  wait(event)
              end)
        return tasks, event
    end

    function teardown(tasks, event)
        notify(event)
        waitall(resize!(tasks, 3), throw=true)
    end

    for tasks_type in (Vector{Task}, Set{Task}, Tuple{Task})
        @testset "waitany" begin
            @testset "throw=false" begin
                tasks, event = create_tasks()
                wait(tasks[1])
                wait(tasks[2])
                done, pending = waitany(convert_tasks(tasks_type, tasks); throw=false)
                @test length(done) == 2
                @test tasks[1] ∈ done
                @test tasks[2] ∈ done
                @test length(pending) == 1
                @test tasks[3] ∈ pending
                teardown(tasks, event)
            end

            @testset "throw=true" begin
                tasks, event = create_tasks()
                push!(tasks, Threads.@spawn error("Error"))

                @test_throws CompositeException begin
                    waitany(convert_tasks(tasks_type, tasks); throw=true)
                end

                teardown(tasks, event)
            end
        end

        @testset "waitall" begin
            @testset "All tasks succeed" begin
                tasks, event = create_tasks()

                wait(tasks[1])
                wait(tasks[2])
                waiter = Threads.@spawn waitall(convert_tasks(tasks_type, tasks))
                @test !istaskdone(waiter)

                notify(event)
                done, pending = fetch(waiter)
                @test length(done) == 3
                @test tasks[1] ∈ done
                @test tasks[2] ∈ done
                @test tasks[3] ∈ done
                @test length(pending) == 0
            end

            @testset "failfast=true, throw=false" begin
                tasks, event = create_tasks()
                push!(tasks, Threads.@spawn error("Error"))

                wait(tasks[1])
                wait(tasks[2])
                waiter = Threads.@spawn waitall(convert_tasks(tasks_type, tasks); failfast=true, throw=false)

                done, pending = fetch(waiter)
                @test length(done) == 3
                @test tasks[1] ∈ done
                @test tasks[2] ∈ done
                @test tasks[4] ∈ done
                @test length(pending) == 1
                @test tasks[3] ∈ pending

                teardown(tasks, event)
            end

            @testset "failfast=false, throw=true" begin
                tasks, event = create_tasks()
                push!(tasks, Threads.@spawn error("Error"))

                notify(event)

                @test_throws CompositeException begin
                    waitall(convert_tasks(tasks_type, tasks); failfast=false, throw=true)
                end

                @test all(istaskdone.(tasks))

                teardown(tasks, event)
            end

            @testset "failfast=true, throw=true" begin
                tasks, event = create_tasks()
                push!(tasks, Threads.@spawn error("Error"))

                @test_throws CompositeException begin
                    waitall(convert_tasks(tasks_type, tasks); failfast=true, throw=true)
                end

                @test !istaskdone(tasks[3])

                teardown(tasks, event)
            end
        end
    end
end
end # main testset
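The tests above exercise the throwing paths. As a complement, here is a hedged sketch of how a caller might use `throw=false` to collect failures and results separately. It assumes a Julia build that includes this PR, and the names `ok`, `bad`, `failed`, and `results` are hypothetical.

```julia
# Assumes `waitall` is available (a build containing this PR).
ok = Threads.@spawn 42
bad = Threads.@spawn error("boom")

# With failfast=false and throw=false, waitall waits for every task and never throws.
done, pending = waitall([ok, bad]; failfast=false, throw=false)

# Separate failed tasks from successful ones and fetch only the successes.
failed = filter(istaskfailed, done)
results = [fetch(t) for t in done if !istaskfailed(t)]
```

Calling `fetch` on a failed task would throw a `TaskFailedException`, which is why the comprehension filters on `istaskfailed` first.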