Skip to content

Commit

Permalink
Channel: upgrade to make threadsafe internally (JuliaLang#30186)
Browse files Browse the repository at this point in the history
This drops the internal `notify_error` API, use `close` instead.
  • Loading branch information
vtjnash authored Jan 7, 2019
1 parent 699163f commit 9619079
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 127 deletions.
220 changes: 118 additions & 102 deletions base/channels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,39 +23,30 @@ Other constructors:
* `Channel(sz)`: equivalent to `Channel{Any}(sz)`
"""
mutable struct Channel{T} <: AbstractChannel{T}
cond_take::Condition # waiting for data to become available
cond_put::Condition # waiting for a writeable slot
cond_take::Threads.Condition # waiting for data to become available
cond_wait::Threads.Condition # waiting for data to become maybe available
cond_put::Threads.Condition # waiting for a writeable slot
state::Symbol
excp::Union{Exception, Nothing} # exception to be thrown when state != :open
excp::Union{Exception, Nothing} # exception to be thrown when state != :open

data::Vector{T}
sz_max::Int # maximum size of channel

# Used when sz_max == 0, i.e., an unbuffered channel.
waiters::Int
takers::Vector{Task}
putters::Vector{Task}

function Channel{T}(sz::Float64) where T
if sz == Inf
Channel{T}(typemax(Int))
else
Channel{T}(convert(Int, sz))
end
end
function Channel{T}(sz::Integer) where T
if sz < 0
throw(ArgumentError("Channel size must be either 0, a positive integer or Inf"))
end
ch = new(Condition(), Condition(), :open, nothing, Vector{T}(), sz, 0)
if sz == 0
ch.takers = Vector{Task}()
ch.putters = Vector{Task}()
end
return ch
lock = ReentrantLock()
cond_put, cond_take = Threads.Condition(lock), Threads.Condition(lock)
cond_wait = (sz == 0 ? Threads.Condition(lock) : cond_take) # wait is distinct from take iff unbuffered
return new(cond_take, cond_wait, cond_put, :open, nothing, Vector{T}(), sz)
end
end

function Channel{T}(sz::Float64) where T
sz = (sz == Inf ? typemax(Int) : convert(Int, sz))
return Channel{T}(sz)
end
Channel(sz) = Channel{Any}(sz)

# special constructors
Expand Down Expand Up @@ -122,22 +113,30 @@ isbuffered(c::Channel) = c.sz_max==0 ? false : true

function check_channel_state(c::Channel)
if !isopen(c)
c.excp !== nothing && throw(c.excp)
excp = c.excp
excp !== nothing && throw(excp)
throw(closed_exception())
end
end
"""
close(c::Channel)
close(c::Channel[, excp::Exception])
Close a channel. An exception is thrown by:
Close a channel. An exception (optionally given by `excp`), is thrown by:
* [`put!`](@ref) on a closed channel.
* [`take!`](@ref) and [`fetch`](@ref) on an empty, closed channel.
"""
function close(c::Channel)
c.state = :closed
c.excp = closed_exception()
notify_error(c)
function close(c::Channel, excp::Exception=closed_exception())
lock(c)
try
c.state = :closed
c.excp = excp
notify_error(c.cond_take, excp)
notify_error(c.cond_wait, excp)
notify_error(c.cond_put, excp)
finally
unlock(c)
end
nothing
end
isopen(c::Channel) = (c.state == :open)
Expand Down Expand Up @@ -195,7 +194,7 @@ Stacktrace:
function bind(c::Channel, task::Task)
ref = WeakRef(c)
register_taskdone_hook(task, tsk->close_chnl_on_taskdone(tsk, ref))
c
return c
end

"""
Expand Down Expand Up @@ -225,17 +224,34 @@ function channeled_tasks(n::Int, funcs...; ctypes=fill(Any,n), csizes=fill(0,n))
end

function close_chnl_on_taskdone(t::Task, ref::WeakRef)
if ref.value !== nothing
c = ref.value
!isopen(c) && return
if istaskfailed(t)
c.state = :closed
c.excp = task_result(t)
notify_error(c)
c = ref.value
if c isa Channel
isopen(c) || return
cleanup = () -> try
isopen(c) || return
if istaskfailed(t)
excp = task_result(t)
if excp isa Exception
close(c, excp)
return
end
end
close(c)
return
finally
unlock(c)
end
if trylock(c)
# can't use `lock`, since attempts to task-switch to wait for it
# will just silently fail and leave us with broken state
cleanup()
else
close(c)
# so schedule this to happen once we are finished destroying our task
# (on a new Task)
@async (lock(c); cleanup())
end
end
nothing
end

struct InvalidStateException <: Exception
Expand All @@ -257,33 +273,39 @@ task.
function put!(c::Channel{T}, v) where T
check_channel_state(c)
v = convert(T, v)
isbuffered(c) ? put_buffered(c,v) : put_unbuffered(c,v)
return isbuffered(c) ? put_buffered(c, v) : put_unbuffered(c, v)
end

function put_buffered(c::Channel, v)
while length(c.data) == c.sz_max
wait(c.cond_put)
lock(c)
try
while length(c.data) == c.sz_max
check_channel_state(c)
wait(c.cond_put)
end
push!(c.data, v)
# notify all, since some of the waiters may be on a "fetch" call.
notify(c.cond_take, nothing, true, false)
finally
unlock(c)
end
push!(c.data, v)

# notify all, since some of the waiters may be on a "fetch" call.
notify(c.cond_take, nothing, true, false)
v
return v
end

function put_unbuffered(c::Channel, v)
if length(c.takers) == 0
push!(c.putters, current_task())
c.waiters > 0 && notify(c.cond_take, nothing, false, false)

try
wait()
catch
filter!(x->x!=current_task(), c.putters)
rethrow()
lock(c)
taker = try
while isempty(c.cond_take.waitq)
check_channel_state(c)
notify(c.cond_wait)
wait(c.cond_put)
end
# unfair scheduled version of: notify(c.cond_take, v, false, false); yield()
popfirst!(c.cond_take.waitq)
finally
unlock(c)
end
taker = popfirst!(c.takers)
# unfair version of: schedule(taker, v); yield()
yield(taker, v) # immediately give taker a chance to run, but don't block the current task
return v
end
Expand All @@ -298,8 +320,16 @@ remove the item. `fetch` is unsupported on an unbuffered (0-size) channel.
"""
fetch(c::Channel) = isbuffered(c) ? fetch_buffered(c) : fetch_unbuffered(c)
function fetch_buffered(c::Channel)
wait(c)
c.data[1]
lock(c)
try
while isempty(c.data)
check_channel_state(c)
wait(c.cond_take)
end
return c.data[1]
finally
unlock(c)
end
end
fetch_unbuffered(c::Channel) = throw(ErrorException("`fetch` is not supported on an unbuffered Channel."))

Expand All @@ -314,32 +344,31 @@ task.
"""
take!(c::Channel) = isbuffered(c) ? take_buffered(c) : take_unbuffered(c)
function take_buffered(c::Channel)
wait(c)
v = popfirst!(c.data)
notify(c.cond_put, nothing, false, false) # notify only one, since only one slot has become available for a put!.
v
lock(c)
try
while isempty(c.data)
check_channel_state(c)
wait(c.cond_take)
end
v = popfirst!(c.data)
notify(c.cond_put, nothing, false, false) # notify only one, since only one slot has become available for a put!.
return v
finally
unlock(c)
end
end

popfirst!(c::Channel) = take!(c)

# 0-size channel
function take_unbuffered(c::Channel{T}) where T
check_channel_state(c)
push!(c.takers, current_task())
lock(c)
try
if length(c.putters) > 0
let refputter = Ref(popfirst!(c.putters))
return Base.try_yieldto(refputter) do putter
# if we fail to start putter, put it back in the queue
putter === current_task || pushfirst!(c.putters, putter)
end::T
end
else
return wait()::T
end
catch
filter!(x->x!=current_task(), c.takers)
rethrow()
check_channel_state(c)
notify(c.cond_put, nothing, false, false)
return wait(c.cond_take)::T
finally
unlock(c)
end
end

Expand All @@ -353,39 +382,26 @@ For unbuffered channels returns `true` if there are tasks waiting
on a [`put!`](@ref).
"""
isready(c::Channel) = n_avail(c) > 0
n_avail(c::Channel) = isbuffered(c) ? length(c.data) : length(c.putters)
n_avail(c::Channel) = isbuffered(c) ? length(c.data) : length(c.cond_put.waitq)

wait(c::Channel) = isbuffered(c) ? wait_impl(c) : wait_unbuffered(c)
function wait_impl(c::Channel)
while !isready(c)
check_channel_state(c)
wait(c.cond_take)
end
nothing
end
lock(c::Channel) = lock(c.cond_take)
unlock(c::Channel) = unlock(c.cond_take)
trylock(c::Channel) = trylock(c.cond_take)

function wait_unbuffered(c::Channel)
c.waiters += 1
function wait(c::Channel)
isready(c) && return
lock(c)
try
wait_impl(c)
while !isready(c)
check_channel_state(c)
wait(c.cond_wait)
end
finally
c.waiters -= 1
unlock(c)
end
nothing
end

function notify_error(c::Channel, err)
notify_error(c.cond_take, err)
notify_error(c.cond_put, err)

# release tasks on a `wait()/yieldto()` call (on unbuffered channels)
if !isbuffered(c)
waiters = filter!(t->(t.state == :runnable), vcat(c.takers, c.putters))
foreach(t->schedule(t, err; error=true), waiters)
end
end
notify_error(c::Channel) = notify_error(c, c.excp)

eltype(::Type{Channel{T}}) where {T} = T

show(io::IO, c::Channel) = print(io, "$(typeof(c))(sz_max:$(c.sz_max),sz_curr:$(n_avail(c)))")
Expand All @@ -394,7 +410,7 @@ function iterate(c::Channel, state=nothing)
try
return (take!(c), nothing)
catch e
if isa(e, InvalidStateException) && e.state==:closed
if isa(e, InvalidStateException) && e.state == :closed
return nothing
else
rethrow()
Expand Down
2 changes: 1 addition & 1 deletion stdlib/Distributed/src/Distributed.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import Base: getindex, wait, put!, take!, fetch, isready, push!, length,

# imports for use
using Base: Process, Semaphore, JLOptions, AnyDict, buffer_writes, wait_connected,
VERSION_STRING, binding_module, notify_error, atexit, julia_exename,
VERSION_STRING, binding_module, atexit, julia_exename,
julia_cmd, AsyncGenerator, acquire, release, invokelatest,
shell_escape_posixly, uv_error, something, notnothing, isbuffered

Expand Down
12 changes: 7 additions & 5 deletions stdlib/Distributed/src/cluster.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1083,24 +1083,25 @@ function deregister_worker(pg, pid)
ids = []
tonotify = []
lock(client_refs) do
for (id,rv) in pg.refs
if in(pid,rv.clientset)
for (id, rv) in pg.refs
if in(pid, rv.clientset)
push!(ids, id)
end
if rv.waitingfor == pid
push!(tonotify, (id,rv))
push!(tonotify, (id, rv))
end
end
for id in ids
del_client(pg, id, pid)
end

# throw exception to tasks waiting for this pid
for (id,rv) in tonotify
notify_error(rv.c, ProcessExitedException())
for (id, rv) in tonotify
close(rv.c, ProcessExitedException())
delete!(pg.refs, id)
end
end
return
end


Expand All @@ -1110,6 +1111,7 @@ function interrupt(pid::Integer)
if isa(w, Worker)
manage(w.manager, w.id, w.config, :interrupt)
end
return
end

"""
Expand Down
Loading

0 comments on commit 9619079

Please sign in to comment.