@@ -39,14 +39,22 @@ function _nthreads_in_pool(tpid::Int8)
39
39
return Int (unsafe_load (p, tpid + 1 ))
40
40
end
41
41
42
+ function _tpid_to_sym (tpid:: Int8 )
43
+ return tpid == 0 ? :interactive : :default
44
+ end
45
+
46
+ function _sym_to_tpid (tp:: Symbol )
47
+ return tp === :interactive ? Int8 (0 ) : Int8 (1 )
48
+ end
49
+
42
50
"""
43
51
Threads.threadpool(tid = threadid()) -> Symbol
44
52
45
53
Returns the specified thread's threadpool; either `:default` or `:interactive`.
46
54
"""
47
55
function threadpool (tid = threadid ())
48
56
tpid = ccall (:jl_threadpoolid , Int8, (Int16,), tid- 1 )
49
- return tpid == 0 ? :default : :interactive
57
+ return _tpid_to_sym ( tpid)
50
58
end
51
59
52
60
"""
@@ -67,24 +75,39 @@ See also: `BLAS.get_num_threads` and `BLAS.set_num_threads` in the
67
75
[`Distributed`](@ref man-distributed) standard library.
68
76
"""
69
77
function threadpoolsize (pool:: Symbol = :default )
70
- if pool === :default
71
- tpid = Int8 (0 )
72
- elseif pool === :interactive
73
- tpid = Int8 (1 )
78
+ if pool === :default || pool === :interactive
79
+ tpid = _sym_to_tpid (pool)
74
80
else
75
81
error (" invalid threadpool specified" )
76
82
end
77
83
return _nthreads_in_pool (tpid)
78
84
end
79
85
86
+ """
87
+ threadpooltids(pool::Symbol)
88
+
89
+ Returns a vector of IDs of threads in the given pool.
90
+ """
91
+ function threadpooltids (pool:: Symbol )
92
+ ni = _nthreads_in_pool (Int8 (0 ))
93
+ if pool === :interactive
94
+ return collect (1 : ni)
95
+ elseif pool === :default
96
+ return collect (ni+ 1 : ni+ _nthreads_in_pool (Int8 (1 )))
97
+ else
98
+ error (" invalid threadpool specified" )
99
+ end
100
+ end
101
+
80
102
function threading_run (fun, static)
81
103
ccall (:jl_enter_threaded_region , Cvoid, ())
82
104
n = threadpoolsize ()
105
+ tid_offset = threadpoolsize (:interactive )
83
106
tasks = Vector {Task} (undef, n)
84
107
for i = 1 : n
85
108
t = Task (() -> fun (i)) # pass in tid
86
109
t. sticky = static
87
- static && ccall (:jl_set_task_tid , Cint, (Any, Cint), t, i- 1 )
110
+ static && ccall (:jl_set_task_tid , Cint, (Any, Cint), t, tid_offset + i- 1 )
88
111
tasks[i] = t
89
112
schedule (t)
90
113
end
@@ -287,6 +310,15 @@ macro threads(args...)
287
310
return _threadsfor (ex. args[1 ], ex. args[2 ], sched)
288
311
end
289
312
313
+ function _spawn_set_thrpool (t:: Task , tp:: Symbol )
314
+ tpid = _sym_to_tpid (tp)
315
+ if _nthreads_in_pool (tpid) == 0
316
+ tpid = _sym_to_tpid (:default )
317
+ end
318
+ ccall (:jl_set_task_threadpoolid , Cint, (Any, Int8), t, tpid)
319
+ nothing
320
+ end
321
+
290
322
"""
291
323
Threads.@spawn [:default|:interactive] expr
292
324
@@ -315,7 +347,7 @@ the variable's value in the current task.
315
347
A threadpool may be specified as of Julia 1.9.
316
348
"""
317
349
macro spawn (args... )
318
- tpid = Int8 ( 0 )
350
+ tp = :default
319
351
na = length (args)
320
352
if na == 2
321
353
ttype, ex = args
@@ -325,9 +357,9 @@ macro spawn(args...)
325
357
# TODO : allow unquoted symbols
326
358
ttype = nothing
327
359
end
328
- if ttype === :interactive
329
- tpid = Int8 ( 1 )
330
- elseif ttype != = :default
360
+ if ttype === :interactive || ttype === :default
361
+ tp = ttype
362
+ else
331
363
throw (ArgumentError (" unsupported threadpool in @spawn: $ttype " ))
332
364
end
333
365
elseif na == 1
@@ -344,11 +376,7 @@ macro spawn(args...)
344
376
let $ (letargs... )
345
377
local task = Task ($ thunk)
346
378
task. sticky = false
347
- local tpid_actual = $ tpid
348
- if _nthreads_in_pool (tpid_actual) == 0
349
- tpid_actual = Int8 (0 )
350
- end
351
- ccall (:jl_set_task_threadpoolid , Cint, (Any, Int8), task, tpid_actual)
379
+ _spawn_set_thrpool (task, $ (QuoteNode (tp)))
352
380
if $ (Expr (:islocal , var))
353
381
put! ($ var, task)
354
382
end
0 commit comments