Skip to content

Commit e3255ef

Browse files
JeffBezansonvtjnash
authored andcommitted
fix #41546, make using thread-safe (#41602)
use more precision when handling loading lock, merge with TOML lock (since we typically are needing both, sometimes in unpredictable orders), and unlock before call most user code Co-authored-by: Jameson Nash <vtjnash@gmail.com> (cherry picked from commit 3d4b213)
1 parent 532b3f8 commit e3255ef

File tree

4 files changed

+84
-27
lines changed

4 files changed

+84
-27
lines changed

base/loading.jl

Lines changed: 52 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# This file is a part of Julia. License is MIT: https://julialang.org/license
22

33
# Base.require is the implementation for the `import` statement
4+
const require_lock = ReentrantLock()
45

56
# Cross-platform case-sensitive path canonicalization
67

@@ -129,6 +130,7 @@ end
129130
const ns_dummy_uuid = UUID("fe0723d6-3a44-4c41-8065-ee0f42c8ceab")
130131

131132
function dummy_uuid(project_file::String)
133+
@lock require_lock begin
132134
cache = LOADING_CACHE[]
133135
if cache !== nothing
134136
uuid = get(cache.dummy_uuid, project_file, nothing)
@@ -144,6 +146,7 @@ function dummy_uuid(project_file::String)
144146
cache.dummy_uuid[project_file] = uuid
145147
end
146148
return uuid
149+
end
147150
end
148151

149152
## package path slugs: turning UUID + SHA1 into a pair of 4-byte "slugs" ##
@@ -236,8 +239,7 @@ struct TOMLCache
236239
end
237240
const TOML_CACHE = TOMLCache(TOML.Parser(), Dict{String, Dict{String, Any}}())
238241

239-
const TOML_LOCK = ReentrantLock()
240-
parsed_toml(project_file::AbstractString) = parsed_toml(project_file, TOML_CACHE, TOML_LOCK)
242+
parsed_toml(project_file::AbstractString) = parsed_toml(project_file, TOML_CACHE, require_lock)
241243
function parsed_toml(project_file::AbstractString, toml_cache::TOMLCache, toml_lock::ReentrantLock)
242244
lock(toml_lock) do
243245
cache = LOADING_CACHE[]
@@ -337,13 +339,15 @@ Use [`dirname`](@ref) to get the directory part and [`basename`](@ref)
337339
to get the file name part of the path.
338340
"""
339341
function pathof(m::Module)
340-
pkgid = get(Base.module_keys, m, nothing)
342+
@lock require_lock begin
343+
pkgid = get(module_keys, m, nothing)
341344
pkgid === nothing && return nothing
342-
origin = get(Base.pkgorigins, pkgid, nothing)
345+
origin = get(pkgorigins, pkgid, nothing)
343346
origin === nothing && return nothing
344347
path = origin.path
345348
path === nothing && return nothing
346349
return fixup_stdlib_path(path)
350+
end
347351
end
348352

349353
"""
@@ -366,7 +370,7 @@ julia> pkgdir(Foo, "src", "file.jl")
366370
The optional argument `paths` requires at least Julia 1.7.
367371
"""
368372
function pkgdir(m::Module, paths::String...)
369-
rootmodule = Base.moduleroot(m)
373+
rootmodule = moduleroot(m)
370374
path = pathof(rootmodule)
371375
path === nothing && return nothing
372376
return joinpath(dirname(dirname(path)), paths...)
@@ -383,6 +387,7 @@ const preferences_names = ("JuliaLocalPreferences.toml", "LocalPreferences.toml"
383387
# - `true`: `env` is an implicit environment
384388
# - `path`: the path of an explicit project file
385389
function env_project_file(env::String)::Union{Bool,String}
390+
@lock require_lock begin
386391
cache = LOADING_CACHE[]
387392
if cache !== nothing
388393
project_file = get(cache.env_project_file, env, nothing)
@@ -406,6 +411,7 @@ function env_project_file(env::String)::Union{Bool,String}
406411
cache.env_project_file[env] = project_file
407412
end
408413
return project_file
414+
end
409415
end
410416

411417
function project_deps_get(env::String, name::String)::Union{Nothing,PkgId}
@@ -473,6 +479,7 @@ end
473479

474480
# find project file's corresponding manifest file
475481
function project_file_manifest_path(project_file::String)::Union{Nothing,String}
482+
@lock require_lock begin
476483
cache = LOADING_CACHE[]
477484
if cache !== nothing
478485
manifest_path = get(cache.project_file_manifest_path, project_file, missing)
@@ -501,6 +508,7 @@ function project_file_manifest_path(project_file::String)::Union{Nothing,String}
501508
cache.project_file_manifest_path[project_file] = manifest_path
502509
end
503510
return manifest_path
511+
end
504512
end
505513

506514
# given a directory (implicit env from LOAD_PATH) and a name,
@@ -688,7 +696,7 @@ function implicit_manifest_deps_get(dir::String, where::PkgId, name::String)::Un
688696
@assert where.uuid !== nothing
689697
project_file = entry_point_and_project_file(dir, where.name)[2]
690698
project_file === nothing && return nothing # a project file is mandatory for a package with a uuid
691-
proj = project_file_name_uuid(project_file, where.name, )
699+
proj = project_file_name_uuid(project_file, where.name)
692700
proj == where || return nothing # verify that this is the correct project file
693701
# this is the correct project, so stop searching here
694702
pkg_uuid = explicit_project_deps_get(project_file, name)
@@ -753,19 +761,26 @@ function _include_from_serialized(path::String, depmods::Vector{Any})
753761
if isa(sv, Exception)
754762
return sv
755763
end
756-
restored = sv[1]
757-
if !isa(restored, Exception)
758-
for M in restored::Vector{Any}
759-
M = M::Module
760-
if isdefined(M, Base.Docs.META)
761-
push!(Base.Docs.modules, M)
762-
end
763-
if parentmodule(M) === M
764-
register_root_module(M)
765-
end
764+
sv = sv::SimpleVector
765+
restored = sv[1]::Vector{Any}
766+
for M in restored
767+
M = M::Module
768+
if isdefined(M, Base.Docs.META)
769+
push!(Base.Docs.modules, M)
770+
end
771+
if parentmodule(M) === M
772+
register_root_module(M)
773+
end
774+
end
775+
inits = sv[2]::Vector{Any}
776+
if !isempty(inits)
777+
unlock(require_lock) # temporarily _unlock_ during these callbacks
778+
try
779+
ccall(:jl_init_restored_modules, Cvoid, (Any,), inits)
780+
finally
781+
lock(require_lock)
766782
end
767783
end
768-
isassigned(sv, 2) && ccall(:jl_init_restored_modules, Cvoid, (Any,), sv[2])
769784
return restored
770785
end
771786

@@ -862,7 +877,7 @@ function _require_search_from_serialized(pkg::PkgId, sourcepath::String)
862877
end
863878

864879
# to synchronize multiple tasks trying to import/using something
865-
const package_locks = Dict{PkgId,Condition}()
880+
const package_locks = Dict{PkgId,Threads.Condition}()
866881

867882
# to notify downstream consumers that a module was successfully loaded
868883
# Callbacks take the form (mod::Base.PkgId) -> nothing.
@@ -885,7 +900,9 @@ function _include_dependency(mod::Module, _path::AbstractString)
885900
path = normpath(joinpath(dirname(prev), _path))
886901
end
887902
if _track_dependencies[]
903+
@lock require_lock begin
888904
push!(_require_dependencies, (mod, path, mtime(path)))
905+
end
889906
end
890907
return path, prev
891908
end
@@ -957,6 +974,7 @@ For more details regarding code loading, see the manual sections on [modules](@r
957974
[parallel computing](@ref code-availability).
958975
"""
959976
function require(into::Module, mod::Symbol)
977+
@lock require_lock begin
960978
LOADING_CACHE[] = LoadingCache()
961979
try
962980
uuidkey = identify_package(into, String(mod))
@@ -998,6 +1016,7 @@ function require(into::Module, mod::Symbol)
9981016
finally
9991017
LOADING_CACHE[] = nothing
10001018
end
1019+
end
10011020
end
10021021

10031022
mutable struct PkgOrigin
@@ -1009,6 +1028,7 @@ PkgOrigin() = PkgOrigin(nothing, nothing)
10091028
const pkgorigins = Dict{PkgId,PkgOrigin}()
10101029

10111030
function require(uuidkey::PkgId)
1031+
@lock require_lock begin
10121032
if !root_module_exists(uuidkey)
10131033
cachefile = _require(uuidkey)
10141034
if cachefile !== nothing
@@ -1020,15 +1040,19 @@ function require(uuidkey::PkgId)
10201040
end
10211041
end
10221042
return root_module(uuidkey)
1043+
end
10231044
end
10241045

10251046
const loaded_modules = Dict{PkgId,Module}()
10261047
const module_keys = IdDict{Module,PkgId}() # the reverse
10271048

1028-
is_root_module(m::Module) = haskey(module_keys, m)
1029-
root_module_key(m::Module) = module_keys[m]
1049+
is_root_module(m::Module) = @lock require_lock haskey(module_keys, m)
1050+
root_module_key(m::Module) = @lock require_lock module_keys[m]
10301051

10311052
function register_root_module(m::Module)
1053+
# n.b. This is called from C after creating a new module in `Base.__toplevel__`,
1054+
# instead of adding them to the binding table there.
1055+
@lock require_lock begin
10321056
key = PkgId(m, String(nameof(m)))
10331057
if haskey(loaded_modules, key)
10341058
oldm = loaded_modules[key]
@@ -1038,6 +1062,7 @@ function register_root_module(m::Module)
10381062
end
10391063
loaded_modules[key] = m
10401064
module_keys[m] = key
1065+
end
10411066
nothing
10421067
end
10431068

@@ -1053,12 +1078,13 @@ using Base
10531078
end
10541079

10551080
# get a top-level Module from the given key
1056-
root_module(key::PkgId) = loaded_modules[key]
1081+
root_module(key::PkgId) = @lock require_lock loaded_modules[key]
10571082
root_module(where::Module, name::Symbol) =
10581083
root_module(identify_package(where, String(name)))
1084+
maybe_root_module(key::PkgId) = @lock require_lock get(loaded_modules, key, nothing)
10591085

1060-
root_module_exists(key::PkgId) = haskey(loaded_modules, key)
1061-
loaded_modules_array() = collect(values(loaded_modules))
1086+
root_module_exists(key::PkgId) = @lock require_lock haskey(loaded_modules, key)
1087+
loaded_modules_array() = @lock require_lock collect(values(loaded_modules))
10621088

10631089
function unreference_module(key::PkgId)
10641090
if haskey(loaded_modules, key)
@@ -1077,7 +1103,7 @@ function _require(pkg::PkgId)
10771103
wait(loading)
10781104
return
10791105
end
1080-
package_locks[pkg] = Condition()
1106+
package_locks[pkg] = Threads.Condition(require_lock)
10811107

10821108
last = toplevel_load[]
10831109
try
@@ -1145,10 +1171,12 @@ function _require(pkg::PkgId)
11451171
if uuid !== old_uuid
11461172
ccall(:jl_set_module_uuid, Cvoid, (Any, NTuple{2, UInt64}), __toplevel__, uuid)
11471173
end
1174+
unlock(require_lock)
11481175
try
11491176
include(__toplevel__, path)
11501177
return
11511178
finally
1179+
lock(require_lock)
11521180
if uuid !== old_uuid
11531181
ccall(:jl_set_module_uuid, Cvoid, (Any, NTuple{2, UInt64}), __toplevel__, old_uuid)
11541182
end

base/toml_parser.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ function Parser(str::String; filepath=nothing)
104104
IdSet{TOMLDict}(), # defined_tables
105105
root,
106106
filepath,
107-
isdefined(Base, :loaded_modules) ? get(Base.loaded_modules, DATES_PKGID, nothing) : nothing,
107+
isdefined(Base, :maybe_root_module) ? Base.maybe_root_module(DATES_PKGID) : nothing,
108108
)
109109
startup(l)
110110
return l

src/dump.c

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2203,8 +2203,6 @@ static jl_array_t *jl_finalize_deserializer(jl_serializer_state *s, arraylist_t
22032203

22042204
JL_DLLEXPORT void jl_init_restored_modules(jl_array_t *init_order)
22052205
{
2206-
if (!init_order)
2207-
return;
22082206
int i, l = jl_array_len(init_order);
22092207
for (i = 0; i < l; i++) {
22102208
jl_value_t *mod = jl_array_ptr_ref(init_order, i);
@@ -2657,6 +2655,9 @@ static jl_value_t *_jl_restore_incremental(ios_t *f, jl_array_t *mod_array)
26572655
jl_recache_other(); // make all of the other objects identities correct (needs to be after insert methods)
26582656
htable_free(&uniquing_table);
26592657
jl_array_t *init_order = jl_finalize_deserializer(&s, tracee_list); // done with f and s (needs to be after recache)
2658+
if (init_order == NULL)
2659+
init_order = (jl_array_t*)jl_an_empty_vec_any;
2660+
assert(jl_isa((jl_value_t*)init_order, jl_array_any_type));
26602661

26612662
JL_GC_PUSH4(&init_order, &restored, &external_backedges, &external_edges);
26622663
jl_gc_enable(en); // subtyping can allocate a lot, not valid before recache-other

test/threads_exec.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -912,3 +912,31 @@ end
912912
@test reproducible_rand(r, 10) == val
913913
end
914914
end
915+
916+
# issue #41546, thread-safe package loading
917+
@testset "package loading" begin
918+
ch = Channel{Bool}(nthreads())
919+
barrier = Base.Event()
920+
old_act_proj = Base.ACTIVE_PROJECT[]
921+
try
922+
pushfirst!(LOAD_PATH, "@")
923+
Base.ACTIVE_PROJECT[] = joinpath(@__DIR__, "TestPkg")
924+
@sync begin
925+
for _ in 1:nthreads()
926+
Threads.@spawn begin
927+
put!(ch, true)
928+
wait(barrier)
929+
@eval using TestPkg
930+
end
931+
end
932+
for _ in 1:nthreads()
933+
take!(ch)
934+
end
935+
notify(barrier)
936+
end
937+
@test Base.root_module(@__MODULE__, :TestPkg) isa Module
938+
finally
939+
Base.ACTIVE_PROJECT[] = old_act_proj
940+
popfirst!(LOAD_PATH)
941+
end
942+
end

0 commit comments

Comments
 (0)