Skip to content

Commit

Permalink
make Workqueue threadsafe (#30838)
Browse files Browse the repository at this point in the history
  • Loading branch information
vtjnash authored and JeffBezanson committed Mar 18, 2019
1 parent 956858d commit baf0caa
Show file tree
Hide file tree
Showing 9 changed files with 210 additions and 72 deletions.
4 changes: 2 additions & 2 deletions base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -380,8 +380,8 @@ eval(Core, :(LineInfoNode(mod::Module, method::Symbol, file::Symbol, line::Int,

Module(name::Symbol=:anonymous, std_imports::Bool=true) = ccall(:jl_f_new_module, Ref{Module}, (Any, Bool), name, std_imports)

function Task(@nospecialize(f), reserved_stack::Int=0)
return ccall(:jl_new_task, Ref{Task}, (Any, Int), f, reserved_stack)
function _Task(@nospecialize(f), reserved_stack::Int, completion_future)
return ccall(:jl_new_task, Ref{Task}, (Any, Any, Int), f, completion_future, reserved_stack)
end

# simple convert for use by constructors of types in Core
Expand Down
167 changes: 130 additions & 37 deletions base/task.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

## basic task functions and TLS

const ThreadSynchronizer = GenericCondition{Threads.SpinLock}
Core.Task(@nospecialize(f), reserved_stack::Int=0) = Core._Task(f, reserved_stack, ThreadSynchronizer())

# Container for a captured exception and its backtrace. Can be serialized.
struct CapturedException <: Exception
ex::Any
Expand Down Expand Up @@ -135,6 +138,8 @@ istaskstarted(t::Task) = ccall(:jl_is_task_started, Cint, (Any,), t) != 0

istaskfailed(t::Task) = (t.state == :failed)

Threads.threadid(t::Task) = Int(ccall(:jl_get_task_tid, Int16, (Any,), t)+1)

task_result(t::Task) = t.result

task_local_storage() = get_task_tls(current_task())
Expand Down Expand Up @@ -181,13 +186,15 @@ end
# NOTE: you can only wait for scheduled tasks
function wait(t::Task)
if !istaskdone(t)
if t.donenotify === nothing
t.donenotify = Condition()
lock(t.donenotify)
try
while !istaskdone(t)
wait(t.donenotify)
end
finally
unlock(t.donenotify)
end
end
while !istaskdone(t)
wait(t.donenotify)
end
if istaskfailed(t)
throw(t.exception)
end
Expand Down Expand Up @@ -273,7 +280,7 @@ end
function register_taskdone_hook(t::Task, hook)
tls = get_task_tls(t)
push!(get!(tls, :TASKDONE_HOOKS, []), hook)
t
return t
end

# runtime system hook called when a task finishes
Expand All @@ -286,9 +293,17 @@ function task_done_hook(t::Task)
t.backtrace = catch_backtrace()
end

if isa(t.donenotify, Condition) && !isempty(t.donenotify.waitq)
handled = true
notify(t.donenotify, result, true, err)
donenotify = t.donenotify
if isa(donenotify, ThreadSynchronizer)
lock(donenotify)
try
if !isempty(donenotify.waitq)
handled = true
notify(donenotify, result, true, err)
end
finally
unlock(donenotify)
end
end

# Execute any other hooks registered in the TLS
Expand All @@ -298,8 +313,8 @@ function task_done_hook(t::Task)
handled = true
end

if err && !handled
if isa(result,InterruptException) && isdefined(Base,:active_repl_backend) &&
if err && !handled && Threads.threadid() == 1
if isa(result, InterruptException) && isdefined(Base, :active_repl_backend) &&
active_repl_backend.backend_task.state == :runnable && isempty(Workqueue) &&
active_repl_backend.in_eval
throwto(active_repl_backend.backend_task, result) # this terminates the task
Expand All @@ -313,7 +328,8 @@ function task_done_hook(t::Task)
# If an InterruptException happens while blocked in the event loop, try handing
# the exception to the REPL task since the current task is done.
# issue #19467
if isa(e,InterruptException) && isdefined(Base,:active_repl_backend) &&
if Threads.threadid() == 1 &&
isa(e, InterruptException) && isdefined(Base, :active_repl_backend) &&
active_repl_backend.backend_task.state == :runnable && isempty(Workqueue) &&
active_repl_backend.in_eval
throwto(active_repl_backend.backend_task, e)
Expand Down Expand Up @@ -360,12 +376,78 @@ end

## scheduler and work queue

global const Workqueue = InvasiveLinkedList{Task}()
struct InvasiveLinkedListSynchronized{T}
queue::InvasiveLinkedList{T}
lock::Threads.SpinLock
InvasiveLinkedListSynchronized{T}() where {T} = new(InvasiveLinkedList{T}(), Threads.SpinLock())
end
isempty(W::InvasiveLinkedListSynchronized) = isempty(W.queue)
length(W::InvasiveLinkedListSynchronized) = length(W.queue)
function push!(W::InvasiveLinkedListSynchronized{T}, t::T) where T
lock(W.lock)
try
push!(W.queue, t)
finally
unlock(W.lock)
end
return W
end
function pushfirst!(W::InvasiveLinkedListSynchronized{T}, t::T) where T
lock(W.lock)
try
pushfirst!(W.queue, t)
finally
unlock(W.lock)
end
return W
end
function pop!(W::InvasiveLinkedListSynchronized)
lock(W.lock)
try
return pop!(W.queue)
finally
unlock(W.lock)
end
end
function popfirst!(W::InvasiveLinkedListSynchronized)
lock(W.lock)
try
return popfirst!(W.queue)
finally
unlock(W.lock)
end
end
function list_deletefirst!(W::InvasiveLinkedListSynchronized{T}, t::T) where T
lock(W.lock)
try
list_deletefirst!(W.queue, t)
finally
unlock(W.lock)
end
return W
end

const StickyWorkqueue = InvasiveLinkedListSynchronized{Task}
global const Workqueues = [StickyWorkqueue()]
global const Workqueue = Workqueues[1] # default work queue is thread 1
function __preinit_threads__()
if length(Workqueues) < Threads.nthreads()
resize!(Workqueues, Threads.nthreads())
for i = 2:length(Workqueues)
Workqueues[i] = StickyWorkqueue()
end
end
nothing
end

function enq_work(t::Task)
(t.state == :runnable && t.queue === nothing) || error("schedule: Task not runnable")
ccall(:uv_stop, Cvoid, (Ptr{Cvoid},), eventloop())
push!(Workqueue, t)
tid = (t.sticky ? Threads.threadid(t) : 0)
if tid == 0
tid = Threads.threadid()
end
push!(Workqueues[tid], t)
tid == 1 && ccall(:uv_stop, Cvoid, (Ptr{Cvoid},), eventloop())
return t
end

Expand Down Expand Up @@ -418,11 +500,12 @@ end
# fast version of `schedule(t, arg); wait()`
function schedule_and_wait(t::Task, @nospecialize(arg)=nothing)
(t.state == :runnable && t.queue === nothing) || error("schedule: Task not runnable")
if isempty(Workqueue)
W = Workqueues[Threads.threadid()]
if isempty(W)
return yieldto(t, arg)
else
t.result = arg
push!(Workqueue, t)
push!(W, t)
end
return wait()
end
Expand Down Expand Up @@ -487,23 +570,24 @@ end

function ensure_rescheduled(othertask::Task)
ct = current_task()
W = Workqueues[Threads.threadid()]
if ct !== othertask && othertask.state == :runnable
# we failed to yield to othertask
# return it to the head of the queue to be scheduled later
pushfirst!(Workqueue, othertask)
end
if ct.queue === Workqueue
# if the current task was queued,
# also need to return it to the runnable state
# before throwing an error
list_deletefirst!(Workqueue, ct)
# return it to the head of a queue to be retried later
tid = Threads.threadid(othertask)
Wother = tid == 0 ? W : Workqueues[tid]
pushfirst!(Wother, othertask)
end
# if the current task was queued,
# also need to return it to the runnable state
# before throwing an error
list_deletefirst!(W, ct)
nothing
end

function trypoptask()
isempty(Workqueue) && return
t = popfirst!(Workqueue)
function trypoptask(W::StickyWorkqueue)
isempty(W) && return
t = popfirst!(W)
if t.state != :runnable
# assume this somehow got queued twice,
# probably broken now, but try discarding this switch and keep going
Expand All @@ -516,25 +600,34 @@ function trypoptask()
return t
end

@noinline function poptaskref()
@noinline function poptaskref(W::StickyWorkqueue)
local task
while true
task = trypoptask()
task = trypoptask(W)
task === nothing || break
if process_events(true) == 0
task = trypoptask()
task === nothing || break
# if there are no active handles and no runnable tasks, just
# wait for signals.
pause()
if !Threads.in_threaded_loop[] && Threads.threadid() == 1
if process_events(true) == 0
task = trypoptask(W)
task === nothing || break
# if there are no active handles and no runnable tasks, just
# wait for signals.
pause()
end
else
if Threads.threadid() == 1
process_events(false)
end
ccall(:jl_gc_safepoint, Cvoid, ())
ccall(:jl_cpu_pause, Cvoid, ())
end
end
return Ref(task)
end


function wait()
reftask = poptaskref()
W = Workqueues[Threads.threadid()]
reftask = poptaskref(W)
result = try_yieldto(ensure_rescheduled, reftask)
process_events(false)
# return when we come out of the queue
Expand Down
7 changes: 4 additions & 3 deletions base/threadingconstructs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,16 +68,17 @@ function _threadsfor(iter,lbody)
# Hack to make nested threaded loops kinda work
if threadid() != 1 || in_threaded_loop[]
# We are in a nested threaded loop
threadsfor_fun(true)
Base.invokelatest(threadsfor_fun, true)
else
in_threaded_loop[] = true
# the ccall is not expected to throw
ccall(:jl_threading_run, Ref{Cvoid}, (Any,), threadsfor_fun)
ccall(:jl_threading_run, Cvoid, (Any,), threadsfor_fun)
in_threaded_loop[] = false
end
nothing
end
end

"""
Threads.@threads
Expand All @@ -96,7 +97,7 @@ macro threads(args...)
throw(ArgumentError("need an expression argument to @threads"))
end
if ex.head === :for
return _threadsfor(ex.args[1],ex.args[2])
return _threadsfor(ex.args[1], ex.args[2])
else
throw(ArgumentError("unrecognized argument to @threads"))
end
Expand Down
8 changes: 5 additions & 3 deletions src/gc-debug.c
Original file line number Diff line number Diff line change
Expand Up @@ -578,11 +578,13 @@ static void gc_scrub_task(jl_task_t *ta)
{
int16_t tid = ta->tid;
jl_ptls_t ptls = jl_get_ptls_states();
jl_ptls_t ptls2 = jl_all_tls_states[tid];
jl_ptls_t ptls2 = NULL;
if (tid != -1)
ptls2 = jl_all_tls_states[tid];

char *low;
char *high;
if (ta->copy_stack && ta == ptls2->current_task) {
if (ta->copy_stack && ptls2 && ta == ptls2->current_task) {
low = (char*)ptls2->stackbase - ptls2->stacksize;
high = (char*)ptls2->stackbase;
}
Expand All @@ -593,7 +595,7 @@ static void gc_scrub_task(jl_task_t *ta)
else
return;

if (ptls == ptls2 && ta == ptls2->current_task) {
if (ptls == ptls2 && ptls2 && ta == ptls2->current_task) {
// scan up to current `sp` for current thread and task
low = (char*)jl_get_frame_addr();
}
Expand Down
6 changes: 4 additions & 2 deletions src/gc.c
Original file line number Diff line number Diff line change
Expand Up @@ -2331,7 +2331,9 @@ mark: {
gc_scrub_record_task(ta);
void *stkbuf = ta->stkbuf;
int16_t tid = ta->tid;
jl_ptls_t ptls2 = jl_all_tls_states[tid];
jl_ptls_t ptls2 = NULL;
if (tid != -1)
ptls2 = jl_all_tls_states[tid];
if (gc_cblist_task_scanner) {
export_gc_state(ptls, &sp);
gc_invoke_callbacks(jl_gc_cb_task_scanner_t,
Expand All @@ -2347,7 +2349,7 @@ mark: {
uintptr_t offset = 0;
uintptr_t lb = 0;
uintptr_t ub = (uintptr_t)-1;
if (ta == ptls2->current_task) {
if (ptls2 && ta == ptls2->current_task) {
s = ptls2->pgcstack;
}
else if (stkbuf) {
Expand Down
11 changes: 9 additions & 2 deletions src/init.c
Original file line number Diff line number Diff line change
Expand Up @@ -761,8 +761,6 @@ void _julia_init(JL_IMAGE_SEARCH rel)

jl_init_codegen();

jl_start_threads();

jl_an_empty_vec_any = (jl_value_t*)jl_alloc_vec_any(0);
jl_init_serializer();
jl_init_intrinsic_properties();
Expand Down Expand Up @@ -818,7 +816,16 @@ void _julia_init(JL_IMAGE_SEARCH rel)
// it does "using Base" if Base is available.
if (jl_base_module != NULL) {
jl_add_standard_imports(jl_main_module);
// Do initialization needed before starting child threads
jl_value_t *f = jl_get_global(jl_base_module, jl_symbol("__preinit_threads__"));
if (f) {
size_t last_age = ptls->world_age;
ptls->world_age = jl_get_world_counter();
jl_apply(&f, 1);
ptls->world_age = last_age;
}
}
jl_start_threads();

// This needs to be after jl_start_threads
if (jl_options.handle_signals == JL_OPTIONS_HANDLE_SIGNALS_ON)
Expand Down
Loading

0 comments on commit baf0caa

Please sign in to comment.