Skip to content

Refactor scheduler and implement spinner thread for Partr. #56475

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion base/Base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ const liblapack_name = libblas_name
# Note that `atomics.jl` here should be deprecated
Core.eval(Threads, :(include("atomics.jl")))
include("channels.jl")
include("partr.jl")
include("scheduler/scheduler.jl")
include("task.jl")
include("threads_overloads.jl")
include("weakkeydict.jl")
Expand Down
73 changes: 12 additions & 61 deletions base/partr.jl → base/scheduler/partr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,63 +19,6 @@ const heap_d = UInt32(8)
const heaps = [Vector{taskheap}(undef, 0), Vector{taskheap}(undef, 0)]
const heaps_lock = [SpinLock(), SpinLock()]


"""
cong(max::UInt32)

Return a random UInt32 in the range `1:max` except if max is 0, in that case return 0.
"""
cong(max::UInt32) = iszero(max) ? UInt32(0) : rand_ptls(max) + UInt32(1) #TODO: make sure users don't use 0 and remove this check

get_ptls_rng() = ccall(:jl_get_ptls_rng, UInt64, ())

set_ptls_rng(seed::UInt64) = ccall(:jl_set_ptls_rng, Cvoid, (UInt64,), seed)

"""
rand_ptls(max::UInt32)

Return a random UInt32 in the range `0:max-1` using the thread-local RNG
state. Max must be greater than 0.
"""
Base.@assume_effects :removable :inaccessiblememonly :notaskstate function rand_ptls(max::UInt32)
rngseed = get_ptls_rng()
val, seed = rand_uniform_max_int32(max, rngseed)
set_ptls_rng(seed)
return val % UInt32
end

# This implementation is based on OpenSSLs implementation of rand_uniform
# https://github.com/openssl/openssl/blob/1d2cbd9b5a126189d5e9bc78a3bdb9709427d02b/crypto/rand/rand_uniform.c#L13-L99
# Comments are vendored from their implementation as well.
# For the original developer check the PR to swift https://github.com/apple/swift/pull/39143.

# Essentially it boils down to incrementally generating a fixed point
# number on the interval [0, 1) and multiplying this number by the upper
# range limit. Once it is certain what the fractional part contributes to
# the integral part of the product, the algorithm has produced a definitive
# result.
"""
rand_uniform_max_int32(max::UInt32, seed::UInt64)

Return a random UInt32 in the range `0:max-1` using the given seed.
Max must be greater than 0.
"""
Base.@assume_effects :total function rand_uniform_max_int32(max::UInt32, seed::UInt64)
if max == UInt32(1)
return UInt32(0), seed
end
# We are generating a fixed point number on the interval [0, 1).
# Multiplying this by the range gives us a number on [0, upper).
# The high word of the multiplication result represents the integral part
# This is not completely unbiased as it's missing the fractional part of the original implementation but it's good enough for our purposes
seed = UInt64(69069) * seed + UInt64(362437)
prod = (UInt64(max)) * (seed % UInt32) # 64 bit product
i = prod >> 32 % UInt32 # integral part
return i % UInt32, seed
end



function multiq_sift_up(heap::taskheap, idx::Int32)
while idx > Int32(1)
parent = (idx - Int32(2)) ÷ heap_d + Int32(1)
Expand Down Expand Up @@ -147,10 +90,10 @@ function multiq_insert(task::Task, priority::UInt16)

task.priority = priority

rn = cong(heap_p)
rn = Base.Scheduler.cong(heap_p)
tpheaps = heaps[tp]
while !trylock(tpheaps[rn].lock)
rn = cong(heap_p)
rn = Base.Scheduler.cong(heap_p)
end

heap = tpheaps[rn]
Expand Down Expand Up @@ -190,8 +133,8 @@ function multiq_deletemin()
if i == heap_p
return nothing
end
rn1 = cong(heap_p)
rn2 = cong(heap_p)
rn1 = Base.Scheduler.cong(heap_p)
rn2 = Base.Scheduler.cong(heap_p)
prio1 = tpheaps[rn1].priority
prio2 = tpheaps[rn2].priority
if prio1 > prio2
Expand Down Expand Up @@ -235,6 +178,9 @@ function multiq_check_empty()
if tp == 0 # Foreign thread
return true
end
if !isempty(Base.workqueue_for(tid))
return false
end
for i = UInt32(1):length(heaps[tp])
if heaps[tp][i].ntasks != 0
return false
Expand All @@ -243,4 +189,9 @@ function multiq_check_empty()
return true
end


enqueue!(t::Task) = multiq_insert(t, t.priority)
dequeue!() = multiq_deletemin()
checktaskempty() = multiq_check_empty()

end
74 changes: 74 additions & 0 deletions base/scheduler/scheduler.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

module Scheduler

"""
cong(max::UInt32)

Return a random UInt32 in the range `1:max` except if max is 0, in that case return 0.
"""
cong(max::UInt32) = iszero(max) ? UInt32(0) : rand_ptls(max) + UInt32(1) #TODO: make sure users don't use 0 and remove this check

get_ptls_rng() = ccall(:jl_get_ptls_rng, UInt64, ())

set_ptls_rng(seed::UInt64) = ccall(:jl_set_ptls_rng, Cvoid, (UInt64,), seed)

"""
rand_ptls(max::UInt32)

Return a random UInt32 in the range `0:max-1` using the thread-local RNG
state. Max must be greater than 0.
"""
Base.@assume_effects :removable :inaccessiblememonly :notaskstate function rand_ptls(max::UInt32)
rngseed = get_ptls_rng()
val, seed = rand_uniform_max_int32(max, rngseed)
set_ptls_rng(seed)
return val % UInt32
end

# This implementation is based on OpenSSLs implementation of rand_uniform
# https://github.com/openssl/openssl/blob/1d2cbd9b5a126189d5e9bc78a3bdb9709427d02b/crypto/rand/rand_uniform.c#L13-L99
# Comments are vendored from their implementation as well.
# For the original developer check the PR to swift https://github.com/apple/swift/pull/39143.

# Essentially it boils down to incrementally generating a fixed point
# number on the interval [0, 1) and multiplying this number by the upper
# range limit. Once it is certain what the fractional part contributes to
# the integral part of the product, the algorithm has produced a definitive
# result.
"""
rand_uniform_max_int32(max::UInt32, seed::UInt64)

Return a random UInt32 in the range `0:max-1` using the given seed.
Max must be greater than 0.
"""
Base.@assume_effects :total function rand_uniform_max_int32(max::UInt32, seed::UInt64)
if max == UInt32(1)
return UInt32(0), seed
end
# We are generating a fixed point number on the interval [0, 1).
# Multiplying this by the range gives us a number on [0, upper).
# The high word of the multiplication result represents the integral part
# This is not completely unbiased as it's missing the fractional part of the original implementation but it's good enough for our purposes
seed = UInt64(69069) * seed + UInt64(362437)
prod = (UInt64(max)) * (seed % UInt32) # 64 bit product
i = prod >> 32 % UInt32 # integral part
return i % UInt32, seed
end

include("scheduler/partr.jl")

const ChosenScheduler = Partr



# Scheduler interface:
# enqueue! which pushes a runnable Task into it
# dequeue! which pops a runnable Task from it
# checktaskempty which returns true if the scheduler has no available Tasks

enqueue!(t::Task) = ChosenScheduler.enqueue!(t)
dequeue!() = ChosenScheduler.dequeue!()
checktaskempty() = ChosenScheduler.checktaskempty()

end
34 changes: 27 additions & 7 deletions base/task.jl
Original file line number Diff line number Diff line change
Expand Up @@ -937,7 +937,6 @@ end

function enq_work(t::Task)
(t._state === task_state_runnable && t.queue === nothing) || error("schedule: Task not runnable")

# Sticky tasks go into their thread's work queue.
if t.sticky
tid = Threads.threadid(t)
Expand Down Expand Up @@ -968,19 +967,40 @@ function enq_work(t::Task)
ccall(:jl_set_task_tid, Cint, (Any, Cint), t, tid-1)
push!(workqueue_for(tid), t)
else
# Otherwise, put the task in the multiqueue.
Partr.multiq_insert(t, t.priority)
# Otherwise, push the task to the scheduler
Scheduler.enqueue!(t)
tid = 0
end
end
ccall(:jl_wakeup_thread, Cvoid, (Int16,), (tid - 1) % Int16)

if (tid == 0)
ccall(:jl_wake_any_thread, Cvoid, (Any,), current_task())
else
ccall(:jl_wakeup_thread, Cvoid, (Int16,), (tid - 1) % Int16)
end
return t
end

const ChildFirst = false
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we for now not add ChildFirst?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it's not even correct for now.


function schedule(t::Task)
# [task] created -scheduled-> wait_time
maybe_record_enqueued!(t)
enq_work(t)
if ChildFirst
ct = current_task()
if ct.sticky || t.sticky
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should actually check if set_task_tid succeeded so that this isn't a data race here (even though this is dead code right now)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually any use of yieldto seems to have this problem, so maybe it deserves another look

maybe_record_enqueued!(t)
enq_work(t)
else
maybe_record_enqueued!(t)
enq_work(ct)
yieldto(t)
end
else
maybe_record_enqueued!(t)
enq_work(t)
end
return t
end

"""
Expand Down Expand Up @@ -1186,10 +1206,10 @@ function trypoptask(W::StickyWorkqueue)
end
return t
end
return Partr.multiq_deletemin()
return Scheduler.dequeue!()
end

checktaskempty = Partr.multiq_check_empty
checktaskempty = Scheduler.checktaskempty

function wait()
ct = current_task()
Expand Down
1 change: 1 addition & 0 deletions src/jl_exported_funcs.inc
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,7 @@
XX(jl_tagged_gensym) \
XX(jl_take_buffer) \
XX(jl_task_get_next) \
XX(jl_wake_any_thread) \
XX(jl_termios_size) \
XX(jl_test_cpu_feature) \
XX(jl_threadid) \
Expand Down
1 change: 1 addition & 0 deletions src/julia_threads.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ typedef struct _jl_tls_states_t {
uint64_t uv_run_leave;
uint64_t sleep_enter;
uint64_t sleep_leave;
uint64_t woken_up;
)

// some hidden state (usually just because we don't have the type's size declaration)
Expand Down
Loading