From baf0caa4bbe1da954c1ed84f48385b11a73601fd Mon Sep 17 00:00:00 2001 From: Jameson Nash Date: Mon, 18 Mar 2019 15:54:28 -0400 Subject: [PATCH] make Workqueue threadsafe (#30838) --- base/boot.jl | 4 +- base/task.jl | 167 ++++++++++++++++++++++++++++-------- base/threadingconstructs.jl | 7 +- src/gc-debug.c | 8 +- src/gc.c | 6 +- src/init.c | 11 ++- src/julia.h | 5 +- src/task.c | 50 ++++++----- test/threads.jl | 24 +++++- 9 files changed, 210 insertions(+), 72 deletions(-) diff --git a/base/boot.jl b/base/boot.jl index 63c7f4b5b4821..8961fcd751c46 100644 --- a/base/boot.jl +++ b/base/boot.jl @@ -380,8 +380,8 @@ eval(Core, :(LineInfoNode(mod::Module, method::Symbol, file::Symbol, line::Int, Module(name::Symbol=:anonymous, std_imports::Bool=true) = ccall(:jl_f_new_module, Ref{Module}, (Any, Bool), name, std_imports) -function Task(@nospecialize(f), reserved_stack::Int=0) - return ccall(:jl_new_task, Ref{Task}, (Any, Int), f, reserved_stack) +function _Task(@nospecialize(f), reserved_stack::Int, completion_future) + return ccall(:jl_new_task, Ref{Task}, (Any, Any, Int), f, completion_future, reserved_stack) end # simple convert for use by constructors of types in Core diff --git a/base/task.jl b/base/task.jl index 9a854048fd910..fb73e433db850 100644 --- a/base/task.jl +++ b/base/task.jl @@ -2,6 +2,9 @@ ## basic task functions and TLS +const ThreadSynchronizer = GenericCondition{Threads.SpinLock} +Core.Task(@nospecialize(f), reserved_stack::Int=0) = Core._Task(f, reserved_stack, ThreadSynchronizer()) + # Container for a captured exception and its backtrace. Can be serialized. struct CapturedException <: Exception ex::Any @@ -135,6 +138,8 @@ istaskstarted(t::Task) = ccall(:jl_is_task_started, Cint, (Any,), t) != 0 istaskfailed(t::Task) = (t.state == :failed) +Threads.threadid(t::Task) = Int(ccall(:jl_get_task_tid, Int16, (Any,), t)+1) + task_result(t::Task) = t.result task_local_storage() = get_task_tls(current_task()) @@ -181,13 +186,15 @@ end # NOTE: you can only wait for scheduled tasks function wait(t::Task) if !istaskdone(t) - if t.donenotify === nothing - t.donenotify = Condition() + lock(t.donenotify) + try + while !istaskdone(t) + wait(t.donenotify) + end + finally + unlock(t.donenotify) end end - while !istaskdone(t) - wait(t.donenotify) - end if istaskfailed(t) throw(t.exception) end @@ -273,7 +280,7 @@ end function register_taskdone_hook(t::Task, hook) tls = get_task_tls(t) push!(get!(tls, :TASKDONE_HOOKS, []), hook) - t + return t end # runtime system hook called when a task finishes @@ -286,9 +293,17 @@ function task_done_hook(t::Task) t.backtrace = catch_backtrace() end - if isa(t.donenotify, Condition) && !isempty(t.donenotify.waitq) - handled = true - notify(t.donenotify, result, true, err) + donenotify = t.donenotify + if isa(donenotify, ThreadSynchronizer) + lock(donenotify) + try + if !isempty(donenotify.waitq) + handled = true + notify(donenotify, result, true, err) + end + finally + unlock(donenotify) + end end # Execute any other hooks registered in the TLS @@ -298,8 +313,8 @@ function task_done_hook(t::Task) handled = true end - if err && !handled - if isa(result,InterruptException) && isdefined(Base,:active_repl_backend) && + if err && !handled && Threads.threadid() == 1 + if isa(result, InterruptException) && isdefined(Base, :active_repl_backend) && active_repl_backend.backend_task.state == :runnable && isempty(Workqueue) && active_repl_backend.in_eval throwto(active_repl_backend.backend_task, result) # this terminates the task @@ -313,7 +328,8 @@ function task_done_hook(t::Task) # If an InterruptException happens while blocked in the event loop, try handing # the exception to the REPL task since the current task is done. # issue #19467 - if isa(e,InterruptException) && isdefined(Base,:active_repl_backend) && + if Threads.threadid() == 1 && + isa(e, InterruptException) && isdefined(Base, :active_repl_backend) && active_repl_backend.backend_task.state == :runnable && isempty(Workqueue) && active_repl_backend.in_eval throwto(active_repl_backend.backend_task, e) @@ -360,12 +376,78 @@ end ## scheduler and work queue -global const Workqueue = InvasiveLinkedList{Task}() +struct InvasiveLinkedListSynchronized{T} + queue::InvasiveLinkedList{T} + lock::Threads.SpinLock + InvasiveLinkedListSynchronized{T}() where {T} = new(InvasiveLinkedList{T}(), Threads.SpinLock()) +end +isempty(W::InvasiveLinkedListSynchronized) = isempty(W.queue) +length(W::InvasiveLinkedListSynchronized) = length(W.queue) +function push!(W::InvasiveLinkedListSynchronized{T}, t::T) where T + lock(W.lock) + try + push!(W.queue, t) + finally + unlock(W.lock) + end + return W +end +function pushfirst!(W::InvasiveLinkedListSynchronized{T}, t::T) where T + lock(W.lock) + try + pushfirst!(W.queue, t) + finally + unlock(W.lock) + end + return W +end +function pop!(W::InvasiveLinkedListSynchronized) + lock(W.lock) + try + return pop!(W.queue) + finally + unlock(W.lock) + end +end +function popfirst!(W::InvasiveLinkedListSynchronized) + lock(W.lock) + try + return popfirst!(W.queue) + finally + unlock(W.lock) + end +end +function list_deletefirst!(W::InvasiveLinkedListSynchronized{T}, t::T) where T + lock(W.lock) + try + list_deletefirst!(W.queue, t) + finally + unlock(W.lock) + end + return W +end + +const StickyWorkqueue = InvasiveLinkedListSynchronized{Task} +global const Workqueues = [StickyWorkqueue()] +global const Workqueue = Workqueues[1] # default work queue is thread 1 +function __preinit_threads__() + if length(Workqueues) < Threads.nthreads() + resize!(Workqueues, Threads.nthreads()) + for i = 2:length(Workqueues) + Workqueues[i] = StickyWorkqueue() + end + end + nothing +end function enq_work(t::Task) (t.state == :runnable && t.queue === nothing) || error("schedule: Task not runnable") - ccall(:uv_stop, Cvoid, (Ptr{Cvoid},), eventloop()) - push!(Workqueue, t) + tid = (t.sticky ? Threads.threadid(t) : 0) + if tid == 0 + tid = Threads.threadid() + end + push!(Workqueues[tid], t) + tid == 1 && ccall(:uv_stop, Cvoid, (Ptr{Cvoid},), eventloop()) return t end @@ -418,11 +500,12 @@ end # fast version of `schedule(t, arg); wait()` function schedule_and_wait(t::Task, @nospecialize(arg)=nothing) (t.state == :runnable && t.queue === nothing) || error("schedule: Task not runnable") - if isempty(Workqueue) + W = Workqueues[Threads.threadid()] + if isempty(W) return yieldto(t, arg) else t.result = arg - push!(Workqueue, t) + push!(W, t) end return wait() end @@ -487,23 +570,24 @@ end function ensure_rescheduled(othertask::Task) ct = current_task() + W = Workqueues[Threads.threadid()] if ct !== othertask && othertask.state == :runnable # we failed to yield to othertask - # return it to the head of the queue to be scheduled later - pushfirst!(Workqueue, othertask) - end - if ct.queue === Workqueue - # if the current task was queued, - # also need to return it to the runnable state - # before throwing an error - list_deletefirst!(Workqueue, ct) + # return it to the head of a queue to be retried later + tid = Threads.threadid(othertask) + Wother = tid == 0 ? W : Workqueues[tid] + pushfirst!(Wother, othertask) end + # if the current task was queued, + # also need to return it to the runnable state + # before throwing an error + list_deletefirst!(W, ct) nothing end -function trypoptask() - isempty(Workqueue) && return - t = popfirst!(Workqueue) +function trypoptask(W::StickyWorkqueue) + isempty(W) && return + t = popfirst!(W) if t.state != :runnable # assume this somehow got queued twice, # probably broken now, but try discarding this switch and keep going @@ -516,17 +600,25 @@ function trypoptask() return t end -@noinline function poptaskref() +@noinline function poptaskref(W::StickyWorkqueue) local task while true - task = trypoptask() + task = trypoptask(W) task === nothing || break - if process_events(true) == 0 - task = trypoptask() - task === nothing || break - # if there are no active handles and no runnable tasks, just - # wait for signals. - pause() + if !Threads.in_threaded_loop[] && Threads.threadid() == 1 + if process_events(true) == 0 + task = trypoptask(W) + task === nothing || break + # if there are no active handles and no runnable tasks, just + # wait for signals. + pause() + end + else + if Threads.threadid() == 1 + process_events(false) + end + ccall(:jl_gc_safepoint, Cvoid, ()) + ccall(:jl_cpu_pause, Cvoid, ()) end end return Ref(task) @@ -534,7 +626,8 @@ end function wait() - reftask = poptaskref() + W = Workqueues[Threads.threadid()] + reftask = poptaskref(W) result = try_yieldto(ensure_rescheduled, reftask) process_events(false) # return when we come out of the queue diff --git a/base/threadingconstructs.jl b/base/threadingconstructs.jl index 61a1f598546a6..834fc672a2e1c 100644 --- a/base/threadingconstructs.jl +++ b/base/threadingconstructs.jl @@ -68,16 +68,17 @@ function _threadsfor(iter,lbody) # Hack to make nested threaded loops kinda work if threadid() != 1 || in_threaded_loop[] # We are in a nested threaded loop - threadsfor_fun(true) + Base.invokelatest(threadsfor_fun, true) else in_threaded_loop[] = true # the ccall is not expected to throw - ccall(:jl_threading_run, Ref{Cvoid}, (Any,), threadsfor_fun) + ccall(:jl_threading_run, Cvoid, (Any,), threadsfor_fun) in_threaded_loop[] = false end nothing end end + """ Threads.@threads @@ -96,7 +97,7 @@ macro threads(args...) throw(ArgumentError("need an expression argument to @threads")) end if ex.head === :for - return _threadsfor(ex.args[1],ex.args[2]) + return _threadsfor(ex.args[1], ex.args[2]) else throw(ArgumentError("unrecognized argument to @threads")) end diff --git a/src/gc-debug.c b/src/gc-debug.c index 40dc55a4f0550..30e63ca37754b 100644 --- a/src/gc-debug.c +++ b/src/gc-debug.c @@ -578,11 +578,13 @@ static void gc_scrub_task(jl_task_t *ta) { int16_t tid = ta->tid; jl_ptls_t ptls = jl_get_ptls_states(); - jl_ptls_t ptls2 = jl_all_tls_states[tid]; + jl_ptls_t ptls2 = NULL; + if (tid != -1) + ptls2 = jl_all_tls_states[tid]; char *low; char *high; - if (ta->copy_stack && ta == ptls2->current_task) { + if (ta->copy_stack && ptls2 && ta == ptls2->current_task) { low = (char*)ptls2->stackbase - ptls2->stacksize; high = (char*)ptls2->stackbase; } @@ -593,7 +595,7 @@ static void gc_scrub_task(jl_task_t *ta) else return; - if (ptls == ptls2 && ta == ptls2->current_task) { + if (ptls == ptls2 && ptls2 && ta == ptls2->current_task) { // scan up to current `sp` for current thread and task low = (char*)jl_get_frame_addr(); } diff --git a/src/gc.c b/src/gc.c index e5fb03cf95cc0..519d087f3c0b4 100644 --- a/src/gc.c +++ b/src/gc.c @@ -2331,7 +2331,9 @@ mark: { gc_scrub_record_task(ta); void *stkbuf = ta->stkbuf; int16_t tid = ta->tid; - jl_ptls_t ptls2 = jl_all_tls_states[tid]; + jl_ptls_t ptls2 = NULL; + if (tid != -1) + ptls2 = jl_all_tls_states[tid]; if (gc_cblist_task_scanner) { export_gc_state(ptls, &sp); gc_invoke_callbacks(jl_gc_cb_task_scanner_t, @@ -2347,7 +2349,7 @@ mark: { uintptr_t offset = 0; uintptr_t lb = 0; uintptr_t ub = (uintptr_t)-1; - if (ta == ptls2->current_task) { + if (ptls2 && ta == ptls2->current_task) { s = ptls2->pgcstack; } else if (stkbuf) { diff --git a/src/init.c b/src/init.c index eebc96a4540c0..d6475fa51042e 100644 --- a/src/init.c +++ b/src/init.c @@ -761,8 +761,6 @@ void _julia_init(JL_IMAGE_SEARCH rel) jl_init_codegen(); - jl_start_threads(); - jl_an_empty_vec_any = (jl_value_t*)jl_alloc_vec_any(0); jl_init_serializer(); jl_init_intrinsic_properties(); @@ -818,7 +816,16 @@ void _julia_init(JL_IMAGE_SEARCH rel) // it does "using Base" if Base is available. if (jl_base_module != NULL) { jl_add_standard_imports(jl_main_module); + // Do initialization needed before starting child threads + jl_value_t *f = jl_get_global(jl_base_module, jl_symbol("__preinit_threads__")); + if (f) { + size_t last_age = ptls->world_age; + ptls->world_age = jl_get_world_counter(); + jl_apply(&f, 1); + ptls->world_age = last_age; + } } + jl_start_threads(); // This needs to be after jl_start_threads if (jl_options.handle_signals == JL_OPTIONS_HANDLE_SIGNALS_ON) diff --git a/src/julia.h b/src/julia.h index b3d26a2b0b4df..4baa8e3e517bc 100644 --- a/src/julia.h +++ b/src/julia.h @@ -1595,8 +1595,8 @@ JL_DLLEXPORT void jl_sigatomic_end(void); // tasks and exceptions ------------------------------------------------------- - typedef struct _jl_timing_block_t jl_timing_block_t; + // info describing an exception handler typedef struct _jl_handler_t { jl_jmp_buf eh_ctx; @@ -1624,6 +1624,7 @@ typedef struct _jl_task_t { jl_value_t *backtrace; jl_value_t *logstate; jl_function_t *start; + uint8_t sticky; // record whether this Task can be migrated to a new thread // hidden state: jl_ucontext_t ctx; // saved thread state @@ -1651,7 +1652,7 @@ typedef struct _jl_task_t { jl_timing_block_t *timing_stack; } jl_task_t; -JL_DLLEXPORT jl_task_t *jl_new_task(jl_function_t *start, size_t ssize); +JL_DLLEXPORT jl_task_t *jl_new_task(jl_function_t*, jl_value_t*, size_t); JL_DLLEXPORT void jl_switchto(jl_task_t **pt); JL_DLLEXPORT void JL_NORETURN jl_throw(jl_value_t *e JL_MAYBE_UNROOTED); JL_DLLEXPORT void JL_NORETURN jl_rethrow(void); diff --git a/src/task.c b/src/task.c index 3cd0e22c3a2cd..c95452dbe4bae 100644 --- a/src/task.c +++ b/src/task.c @@ -49,7 +49,7 @@ volatile int jl_in_stackwalk = 0; #endif #endif -// empirically, finish_task needs about 64k stack space to infer/run +// empirically, jl_finish_task needs about 64k stack space to infer/run // and additionally, gc-stack reserves 64k for the guard pages #if defined(MINSIGSTKSZ) && MINSIGSTKSZ > 131072 #define MINSTKSZ MINSIGSTKSZ @@ -101,6 +101,7 @@ static void NOINLINE save_stack(jl_ptls_t ptls, jl_task_t *lastt, jl_task_t **pt } *pt = lastt; // clear the gc-root for the target task before copying the stack for saving lastt->copy_stack = nb; + lastt->sticky = 1; memcpy_a16((uint64_t*)buf, (uint64_t*)frame_addr, nb); // this task's stack could have been modified after // it was marked by an incremental collection @@ -139,7 +140,7 @@ static void restore_stack2(jl_ptls_t ptls, jl_task_t *lastt) static jl_function_t *task_done_hook_func = NULL; -static void JL_NORETURN finish_task(jl_task_t *t, jl_value_t *resultval JL_MAYBE_UNROOTED) +void JL_NORETURN jl_finish_task(jl_task_t *t, jl_value_t *resultval JL_MAYBE_UNROOTED) { jl_ptls_t ptls = jl_get_ptls_states(); JL_SIGATOMIC_BEGIN(); @@ -155,15 +156,7 @@ static void JL_NORETURN finish_task(jl_task_t *t, jl_value_t *resultval JL_MAYBE ptls->in_finalizer = 0; ptls->in_pure_callback = 0; jl_get_ptls_states()->world_age = jl_world_counter; - if (ptls->tid != 0) { - // For now, only thread 0 runs the task scheduler. - // The others return to the thread loop - ptls->root_task->result = jl_nothing; - jl_task_t *task = ptls->root_task; - jl_switchto(&task); - gc_debug_critical_error(); - abort(); - } + // let the runtime know this task is dead and find a new task to run if (task_done_hook_func == NULL) { task_done_hook_func = (jl_function_t*)jl_get_global(jl_base_module, jl_symbol("task_done_hook")); @@ -254,6 +247,7 @@ static void ctx_switch(jl_ptls_t ptls, jl_task_t **pt) #ifdef COPY_STACKS // fall back to stack copying if mmap fails t->copy_stack = 1; + t->sticky = 1; t->bufsz = 0; memcpy(&t->ctx, &ptls->base_ctx, sizeof(t->ctx)); #else @@ -276,6 +270,7 @@ static void ctx_switch(jl_ptls_t ptls, jl_task_t **pt) if (lastt->copy_stack) { // save the old copy-stack save_stack(ptls, lastt, pt); // allocates (gc-safepoint, and can also fail) if (jl_setjmp(lastt->ctx.uc_mcontext, 0)) { + // TODO: mutex unlock the thread we just switched from #ifdef ENABLE_TIMINGS assert(blk == ptls->current_task->timing_stack); if (blk) @@ -296,6 +291,10 @@ static void ctx_switch(jl_ptls_t ptls, jl_task_t **pt) ptls->world_age = t->world_age; t->gcstack = NULL; ptls->current_task = t; + if (!lastt->sticky) + // release lastt to run on any tid + lastt->tid = -1; + t->tid = ptls->tid; jl_ucontext_t *lastt_ctx = (killed ? NULL : &lastt->ctx); #ifdef COPY_STACKS @@ -326,6 +325,7 @@ static void ctx_switch(jl_ptls_t ptls, jl_task_t **pt) else { jl_start_fiber(lastt_ctx, &t->ctx); } + // TODO: mutex unlock the thread we just switched from #ifdef ENABLE_TIMINGS assert(blk == ptls->current_task->timing_stack); if (blk) @@ -452,7 +452,7 @@ JL_DLLEXPORT void jl_rethrow_other(jl_value_t *e JL_MAYBE_UNROOTED) throw_internal(NULL); } -JL_DLLEXPORT jl_task_t *jl_new_task(jl_function_t *start, size_t ssize) +JL_DLLEXPORT jl_task_t *jl_new_task(jl_function_t *start, jl_value_t *completion_future, size_t ssize) { jl_ptls_t ptls = jl_get_ptls_states(); jl_task_t *t = (jl_task_t*)jl_gc_alloc(ptls, sizeof(jl_task_t), jl_task_type); @@ -481,18 +481,19 @@ JL_DLLEXPORT jl_task_t *jl_new_task(jl_function_t *start, size_t ssize) t->state = runnable_sym; t->start = start; t->result = jl_nothing; - t->donenotify = jl_nothing; + t->donenotify = completion_future; t->exception = jl_nothing; t->backtrace = jl_nothing; // Inherit logger state from parent task t->logstate = ptls->current_task->logstate; // there is no active exception handler available on this stack yet t->eh = NULL; - t->tid = 0; + // TODO: allow non-sticky tasks + t->tid = ptls->tid; + t->sticky = 1; t->gcstack = NULL; t->excstack = NULL; t->stkbuf = NULL; - t->tid = 0; t->started = 0; #ifdef ENABLE_TIMINGS t->timing_stack = NULL; @@ -526,7 +527,7 @@ void jl_init_tasks(void) JL_GC_DISABLED NULL, jl_any_type, jl_emptysvec, - jl_perm_symsvec(10, + jl_perm_symsvec(11, "next", "queue", "storage", @@ -536,8 +537,9 @@ void jl_init_tasks(void) JL_GC_DISABLED "exception", "backtrace", "logstate", - "code"), - jl_svec(10, + "code", + "sticky"), + jl_svec(11, jl_any_type, jl_any_type, jl_any_type, @@ -547,7 +549,8 @@ void jl_init_tasks(void) JL_GC_DISABLED jl_any_type, jl_any_type, jl_any_type, - jl_any_type), + jl_any_type, + jl_bool_type), 0, 1, 9); jl_value_t *listt = jl_new_struct(jl_uniontype_type, jl_task_type, jl_void_type); jl_svecset(jl_task_type->types, 0, listt); @@ -587,7 +590,7 @@ static void NOINLINE JL_NORETURN start_task(void) } skip_pop_exception:; } - finish_task(t, res); + jl_finish_task(t, res); gc_debug_critical_error(); abort(); } @@ -945,6 +948,7 @@ void jl_init_root_task(void *stack_lo, void *stack_hi) ptls->current_task->gcstack = NULL; ptls->current_task->excstack = NULL; ptls->current_task->tid = ptls->tid; + ptls->current_task->sticky = 1; #ifdef JULIA_ENABLE_THREADING arraylist_new(&ptls->current_task->locks, 0); #endif @@ -959,6 +963,12 @@ JL_DLLEXPORT int jl_is_task_started(jl_task_t *t) return t->started; } +JL_DLLEXPORT int16_t jl_get_task_tid(jl_task_t *t) +{ + return t->tid; +} + + #ifdef _OS_WINDOWS_ #if defined(_CPU_X86_) extern DWORD32 __readgsdword(int); diff --git a/test/threads.jl b/test/threads.jl index 884f8e37d8525..3fdcfd791871e 100644 --- a/test/threads.jl +++ b/test/threads.jl @@ -6,6 +6,24 @@ using Base.Threads: SpinLock, Mutex # threading constructs +let a = zeros(Int, 2 * nthreads()) + @threads for i = 1:length(a) + @sync begin + @async begin + @async (Libc.systemsleep(1); a[i] += 1) + yield() + a[i] += 1 + end + @async begin + yield() + @async (Libc.systemsleep(1); a[i] += 1) + a[i] += 1 + end + end + end + @test all(isequal(4), a) +end + # parallel loop with parallel atomic addition function threaded_loop(a, r, x) @threads for i in r @@ -434,7 +452,11 @@ function test_thread_cfunction() end @test sum(ok) == 10000 end -test_thread_cfunction() +if nthreads() == 1 + test_thread_cfunction() +else + @test_broken "cfunction trampoline code not thread-safe" +end # Compare the two ways of checking if threading is enabled. # `jl_tls_states` should only be defined on non-threading build.