Skip to content

Commit 20fa695

Browse files
committed
make thread 1 interactive when there is an interactive pool, so it can run the event loop
1 parent 5b49c03 commit 20fa695

File tree

4 files changed

+53
-26
lines changed

4 files changed

+53
-26
lines changed

base/task.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ istaskfailed(t::Task) = (load_state_acquire(t) === task_state_failed)
253253
Threads.threadid(t::Task) = Int(ccall(:jl_get_task_tid, Int16, (Any,), t)+1)
254254
function Threads.threadpool(t::Task)
255255
tpid = ccall(:jl_get_task_threadpoolid, Int8, (Any,), t)
256-
return tpid == 0 ? :default : :interactive
256+
return Threads._tpid_to_sym(tpid)
257257
end
258258

259259
task_result(t::Task) = t.result
@@ -786,7 +786,7 @@ function enq_work(t::Task)
786786
if Threads.threadpoolsize(tp) == 1
787787
# There's only one thread in the task's assigned thread pool;
788788
# use its work queue.
789-
tid = (tp === :default) ? 1 : Threads.threadpoolsize(:default)+1
789+
tid = (tp === :interactive) ? 1 : Threads.threadpoolsize(:interactive)+1
790790
ccall(:jl_set_task_tid, Cint, (Any, Cint), t, tid-1)
791791
push!(workqueue_for(tid), t)
792792
else

base/threadingconstructs.jl

Lines changed: 43 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,22 @@ function _nthreads_in_pool(tpid::Int8)
3939
return Int(unsafe_load(p, tpid + 1))
4040
end
4141

42+
function _tpid_to_sym(tpid::Int8)
43+
return tpid == 0 ? :interactive : :default
44+
end
45+
46+
function _sym_to_tpid(tp::Symbol)
47+
return tp === :interactive ? Int8(0) : Int8(1)
48+
end
49+
4250
"""
4351
Threads.threadpool(tid = threadid()) -> Symbol
4452
4553
Returns the specified thread's threadpool; either `:default` or `:interactive`.
4654
"""
4755
function threadpool(tid = threadid())
4856
tpid = ccall(:jl_threadpoolid, Int8, (Int16,), tid-1)
49-
return tpid == 0 ? :default : :interactive
57+
return _tpid_to_sym(tpid)
5058
end
5159

5260
"""
@@ -67,24 +75,39 @@ See also: `BLAS.get_num_threads` and `BLAS.set_num_threads` in the
6775
[`Distributed`](@ref man-distributed) standard library.
6876
"""
6977
function threadpoolsize(pool::Symbol = :default)
70-
if pool === :default
71-
tpid = Int8(0)
72-
elseif pool === :interactive
73-
tpid = Int8(1)
78+
if pool === :default || pool === :interactive
79+
tpid = _sym_to_tpid(pool)
7480
else
7581
error("invalid threadpool specified")
7682
end
7783
return _nthreads_in_pool(tpid)
7884
end
7985

86+
"""
87+
threadpooltids(pool::Symbol)
88+
89+
Returns a vector of IDs of threads in the given pool.
90+
"""
91+
function threadpooltids(pool::Symbol)
92+
ni = _nthreads_in_pool(Int8(0))
93+
if pool === :interactive
94+
return collect(1:ni)
95+
elseif pool === :default
96+
return collect(ni+1:ni+_nthreads_in_pool(Int8(1)))
97+
else
98+
error("invalid threadpool specified")
99+
end
100+
end
101+
80102
function threading_run(fun, static)
81103
ccall(:jl_enter_threaded_region, Cvoid, ())
82104
n = threadpoolsize()
105+
tid_offset = threadpoolsize(:interactive)
83106
tasks = Vector{Task}(undef, n)
84107
for i = 1:n
85108
t = Task(() -> fun(i)) # pass in tid
86109
t.sticky = static
87-
static && ccall(:jl_set_task_tid, Cint, (Any, Cint), t, i-1)
110+
static && ccall(:jl_set_task_tid, Cint, (Any, Cint), t, tid_offset + i-1)
88111
tasks[i] = t
89112
schedule(t)
90113
end
@@ -287,6 +310,15 @@ macro threads(args...)
287310
return _threadsfor(ex.args[1], ex.args[2], sched)
288311
end
289312

313+
function _spawn_set_thrpool(t::Task, tp::Symbol)
314+
tpid = _sym_to_tpid(tp)
315+
if _nthreads_in_pool(tpid) == 0
316+
tpid = _sym_to_tpid(:default)
317+
end
318+
ccall(:jl_set_task_threadpoolid, Cint, (Any, Int8), t, tpid)
319+
nothing
320+
end
321+
290322
"""
291323
Threads.@spawn [:default|:interactive] expr
292324
@@ -315,7 +347,7 @@ the variable's value in the current task.
315347
A threadpool may be specified as of Julia 1.9.
316348
"""
317349
macro spawn(args...)
318-
tpid = Int8(0)
350+
tp = :default
319351
na = length(args)
320352
if na == 2
321353
ttype, ex = args
@@ -325,9 +357,9 @@ macro spawn(args...)
325357
# TODO: allow unquoted symbols
326358
ttype = nothing
327359
end
328-
if ttype === :interactive
329-
tpid = Int8(1)
330-
elseif ttype !== :default
360+
if ttype === :interactive || ttype === :default
361+
tp = ttype
362+
else
331363
throw(ArgumentError("unsupported threadpool in @spawn: $ttype"))
332364
end
333365
elseif na == 1
@@ -344,11 +376,7 @@ macro spawn(args...)
344376
let $(letargs...)
345377
local task = Task($thunk)
346378
task.sticky = false
347-
local tpid_actual = $tpid
348-
if _nthreads_in_pool(tpid_actual) == 0
349-
tpid_actual = Int8(0)
350-
end
351-
ccall(:jl_set_task_threadpoolid, Cint, (Any, Int8), task, tpid_actual)
379+
_spawn_set_thrpool(task, $(QuoteNode(tp)))
352380
if $(Expr(:islocal, var))
353381
put!($var, task)
354382
end

src/threading.c

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -600,17 +600,16 @@ void jl_init_threading(void)
600600
// specified on the command line (and so are in `jl_options`) or by the
601601
// environment variable. Set the globals `jl_n_threadpools`, `jl_n_threads`
602602
// and `jl_n_threads_per_pool`.
603-
jl_n_threadpools = 1;
603+
jl_n_threadpools = 2;
604604
int16_t nthreads = JULIA_NUM_THREADS;
605605
int16_t nthreadsi = 0;
606606
char *endptr, *endptri;
607607

608608
if (jl_options.nthreads != 0) { // --threads specified
609-
jl_n_threadpools = jl_options.nthreadpools;
610609
nthreads = jl_options.nthreads_per_pool[0];
611610
if (nthreads < 0)
612611
nthreads = jl_effective_threads();
613-
if (jl_n_threadpools == 2)
612+
if (jl_options.nthreadpools == 2)
614613
nthreadsi = jl_options.nthreads_per_pool[1];
615614
}
616615
else if ((cp = getenv(NUM_THREADS_NAME))) { // ENV[NUM_THREADS_NAME] specified
@@ -635,15 +634,13 @@ void jl_init_threading(void)
635634
if (errno != 0 || endptri == cp || nthreadsi < 0)
636635
nthreadsi = 0;
637636
}
638-
if (nthreadsi > 0)
639-
jl_n_threadpools++;
640637
}
641638
}
642639

643640
jl_all_tls_states_size = nthreads + nthreadsi;
644641
jl_n_threads_per_pool = (int*)malloc_s(2 * sizeof(int));
645-
jl_n_threads_per_pool[0] = nthreads;
646-
jl_n_threads_per_pool[1] = nthreadsi;
642+
jl_n_threads_per_pool[0] = nthreadsi;
643+
jl_n_threads_per_pool[1] = nthreads;
647644

648645
jl_atomic_store_release(&jl_all_tls_states, (jl_ptls_t*)calloc(jl_all_tls_states_size, sizeof(jl_ptls_t)));
649646
jl_atomic_store_release(&jl_n_threads, jl_all_tls_states_size);

test/threadpool_use.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@ using Test
44
using Base.Threads
55

66
@test nthreadpools() == 2
7-
@test threadpool() === :default
8-
@test threadpool(2) === :interactive
7+
@test threadpool() === :interactive
8+
@test threadpool(2) === :default
99
@test fetch(Threads.@spawn Threads.threadpool()) === :default
1010
@test fetch(Threads.@spawn :default Threads.threadpool()) === :default
1111
@test fetch(Threads.@spawn :interactive Threads.threadpool()) === :interactive
12+
@test Threads.threadpooltids(:interactive) == [1]
13+
@test Threads.threadpooltids(:default) == [2]

0 commit comments

Comments
 (0)