Skip to content

make thread 1 interactive when there is an interactive pool, so it can run the event loop #49094

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions base/task.jl
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ istaskfailed(t::Task) = (load_state_acquire(t) === task_state_failed)
Threads.threadid(t::Task) = Int(ccall(:jl_get_task_tid, Int16, (Any,), t)+1)
function Threads.threadpool(t::Task)
tpid = ccall(:jl_get_task_threadpoolid, Int8, (Any,), t)
return tpid == 0 ? :default : :interactive
return Threads._tpid_to_sym(tpid)
end

task_result(t::Task) = t.result
Expand Down Expand Up @@ -786,7 +786,7 @@ function enq_work(t::Task)
if Threads.threadpoolsize(tp) == 1
# There's only one thread in the task's assigned thread pool;
# use its work queue.
tid = (tp === :default) ? 1 : Threads.threadpoolsize(:default)+1
tid = (tp === :interactive) ? 1 : Threads.threadpoolsize(:interactive)+1
ccall(:jl_set_task_tid, Cint, (Any, Cint), t, tid-1)
push!(workqueue_for(tid), t)
else
Expand Down
58 changes: 43 additions & 15 deletions base/threadingconstructs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,22 @@ function _nthreads_in_pool(tpid::Int8)
return Int(unsafe_load(p, tpid + 1))
end

function _tpid_to_sym(tpid::Int8)
return tpid == 0 ? :interactive : :default
end

function _sym_to_tpid(tp::Symbol)
return tp === :interactive ? Int8(0) : Int8(1)
end

"""
Threads.threadpool(tid = threadid()) -> Symbol

Returns the specified thread's threadpool; either `:default` or `:interactive`.
"""
function threadpool(tid = threadid())
tpid = ccall(:jl_threadpoolid, Int8, (Int16,), tid-1)
return tpid == 0 ? :default : :interactive
return _tpid_to_sym(tpid)
end

"""
Expand All @@ -67,24 +75,39 @@ See also: `BLAS.get_num_threads` and `BLAS.set_num_threads` in the
[`Distributed`](@ref man-distributed) standard library.
"""
function threadpoolsize(pool::Symbol = :default)
if pool === :default
tpid = Int8(0)
elseif pool === :interactive
tpid = Int8(1)
if pool === :default || pool === :interactive
tpid = _sym_to_tpid(pool)
else
error("invalid threadpool specified")
end
return _nthreads_in_pool(tpid)
end

"""
threadpooltids(pool::Symbol)

Returns a vector of IDs of threads in the given pool.
"""
function threadpooltids(pool::Symbol)
ni = _nthreads_in_pool(Int8(0))
if pool === :interactive
return collect(1:ni)
elseif pool === :default
return collect(ni+1:ni+_nthreads_in_pool(Int8(1)))
else
error("invalid threadpool specified")
end
end

function threading_run(fun, static)
ccall(:jl_enter_threaded_region, Cvoid, ())
n = threadpoolsize()
tid_offset = threadpoolsize(:interactive)
tasks = Vector{Task}(undef, n)
for i = 1:n
t = Task(() -> fun(i)) # pass in tid
t.sticky = static
static && ccall(:jl_set_task_tid, Cint, (Any, Cint), t, i-1)
static && ccall(:jl_set_task_tid, Cint, (Any, Cint), t, tid_offset + i-1)
tasks[i] = t
schedule(t)
end
Expand Down Expand Up @@ -287,6 +310,15 @@ macro threads(args...)
return _threadsfor(ex.args[1], ex.args[2], sched)
end

function _spawn_set_thrpool(t::Task, tp::Symbol)
tpid = _sym_to_tpid(tp)
if _nthreads_in_pool(tpid) == 0
tpid = _sym_to_tpid(:default)
end
ccall(:jl_set_task_threadpoolid, Cint, (Any, Int8), t, tpid)
nothing
end

"""
Threads.@spawn [:default|:interactive] expr

Expand Down Expand Up @@ -315,7 +347,7 @@ the variable's value in the current task.
A threadpool may be specified as of Julia 1.9.
"""
macro spawn(args...)
tpid = Int8(0)
tp = :default
na = length(args)
if na == 2
ttype, ex = args
Expand All @@ -325,9 +357,9 @@ macro spawn(args...)
# TODO: allow unquoted symbols
ttype = nothing
end
if ttype === :interactive
tpid = Int8(1)
elseif ttype !== :default
if ttype === :interactive || ttype === :default
tp = ttype
else
throw(ArgumentError("unsupported threadpool in @spawn: $ttype"))
end
elseif na == 1
Expand All @@ -344,11 +376,7 @@ macro spawn(args...)
let $(letargs...)
local task = Task($thunk)
task.sticky = false
local tpid_actual = $tpid
if _nthreads_in_pool(tpid_actual) == 0
tpid_actual = Int8(0)
end
ccall(:jl_set_task_threadpoolid, Cint, (Any, Int8), task, tpid_actual)
_spawn_set_thrpool(task, $(QuoteNode(tp)))
if $(Expr(:islocal, var))
put!($var, task)
end
Expand Down
11 changes: 4 additions & 7 deletions src/threading.c
Original file line number Diff line number Diff line change
Expand Up @@ -600,17 +600,16 @@ void jl_init_threading(void)
// specified on the command line (and so are in `jl_options`) or by the
// environment variable. Set the globals `jl_n_threadpools`, `jl_n_threads`
// and `jl_n_threads_per_pool`.
jl_n_threadpools = 1;
jl_n_threadpools = 2;
int16_t nthreads = JULIA_NUM_THREADS;
int16_t nthreadsi = 0;
char *endptr, *endptri;

if (jl_options.nthreads != 0) { // --threads specified
jl_n_threadpools = jl_options.nthreadpools;
nthreads = jl_options.nthreads_per_pool[0];
if (nthreads < 0)
nthreads = jl_effective_threads();
if (jl_n_threadpools == 2)
if (jl_options.nthreadpools == 2)
nthreadsi = jl_options.nthreads_per_pool[1];
}
else if ((cp = getenv(NUM_THREADS_NAME))) { // ENV[NUM_THREADS_NAME] specified
Expand All @@ -635,15 +634,13 @@ void jl_init_threading(void)
if (errno != 0 || endptri == cp || nthreadsi < 0)
nthreadsi = 0;
}
if (nthreadsi > 0)
jl_n_threadpools++;
}
}

jl_all_tls_states_size = nthreads + nthreadsi;
jl_n_threads_per_pool = (int*)malloc_s(2 * sizeof(int));
jl_n_threads_per_pool[0] = nthreads;
jl_n_threads_per_pool[1] = nthreadsi;
jl_n_threads_per_pool[0] = nthreadsi;
jl_n_threads_per_pool[1] = nthreads;

jl_atomic_store_release(&jl_all_tls_states, (jl_ptls_t*)calloc(jl_all_tls_states_size, sizeof(jl_ptls_t)));
jl_atomic_store_release(&jl_n_threads, jl_all_tls_states_size);
Expand Down
6 changes: 4 additions & 2 deletions test/threadpool_use.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ using Test
using Base.Threads

@test nthreadpools() == 2
@test threadpool() === :default
@test threadpool(2) === :interactive
@test threadpool() === :interactive
@test threadpool(2) === :default
@test fetch(Threads.@spawn Threads.threadpool()) === :default
@test fetch(Threads.@spawn :default Threads.threadpool()) === :default
@test fetch(Threads.@spawn :interactive Threads.threadpool()) === :interactive
@test Threads.threadpooltids(:interactive) == [1]
@test Threads.threadpooltids(:default) == [2]