Skip to content

Commit ce97843

Browse files
committed
add precompile support for recording fields to change
Somewhat generalizes our support for changing Ptr to C_NULL. Not particularly fast, since it is just using the builtins implementation of setfield, and delaying the actual stores, but it should suffice.
1 parent 4e7baae commit ce97843

File tree

7 files changed

+167
-17
lines changed

7 files changed

+167
-17
lines changed

base/lock.jl

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@
22

33
const ThreadSynchronizer = GenericCondition{Threads.SpinLock}
44

5+
"""
6+
current_task()
7+
8+
Get the currently running [`Task`](@ref).
9+
"""
10+
current_task() = ccall(:jl_get_current_task, Ref{Task}, ())
11+
512
# Advisory reentrant lock
613
"""
714
ReentrantLock()
@@ -606,16 +613,23 @@ mutable struct PerProcess{T, F}
606613
const initializer::F
607614
const lock::ReentrantLock
608615

609-
PerProcess{T}(initializer::F) where {T, F} = new{T,F}(nothing, 0x00, true, initializer, ReentrantLock())
610-
PerProcess{T,F}(initializer::F) where {T, F} = new{T,F}(nothing, 0x00, true, initializer, ReentrantLock())
611-
PerProcess(initializer) = new{Base.promote_op(initializer), typeof(initializer)}(nothing, 0x00, true, initializer, ReentrantLock())
616+
function PerProcess{T,F}(initializer::F) where {T, F}
617+
once = new{T,F}(nothing, 0x00, true, initializer, ReentrantLock())
618+
ccall(:jl_set_precompile_field_replace, Cvoid, (Any, Any, Any),
619+
once, :x, nothing)
620+
ccall(:jl_set_precompile_field_replace, Cvoid, (Any, Any, Any),
621+
once, :state, 0x00)
622+
return once
623+
end
612624
end
625+
PerProcess{T}(initializer::F) where {T, F} = PerProcess{T, F}(initializer)
626+
PerProcess(initializer) = PerProcess{Base.promote_op(initializer), typeof(initializer)}(initializer)
613627
@inline function (once::PerProcess{T})() where T
614628
state = (@atomic :acquire once.state)
615629
if state != 0x01
616630
(@noinline function init_perprocesss(once, state)
617631
state == 0x02 && error("PerProcess initializer failed previously")
618-
Base.__precompile__(once.allow_compile_time)
632+
once.allow_compile_time || __precompile__(false)
619633
lock(once.lock)
620634
try
621635
state = @atomic :monotonic once.state
@@ -644,6 +658,8 @@ function copyto_monotonic!(dest::AtomicMemory, src)
644658
for j in eachindex(src)
645659
if isassigned(src, j)
646660
@atomic :monotonic dest[i] = src[j]
661+
#else
662+
# _unsafeindex_atomic!(dest, i, src[j], :monotonic)
647663
end
648664
i += 1
649665
end
@@ -701,10 +717,18 @@ mutable struct PerThread{T, F}
701717
@atomic ss::AtomicMemory{UInt8} # states: 0=initial, 1=hasrun, 2=error, 3==concurrent
702718
const initializer::F
703719

704-
PerThread{T}(initializer::F) where {T, F} = new{T,F}(AtomicMemory{T}(), AtomicMemory{UInt8}(), initializer)
705-
PerThread{T,F}(initializer::F) where {T, F} = new{T,F}(AtomicMemory{T}(), AtomicMemory{UInt8}(), initializer)
706-
PerThread(initializer) = (T = Base.promote_op(initializer); new{T, typeof(initializer)}(AtomicMemory{T}(), AtomicMemory{UInt8}(), initializer))
720+
function PerThread{T,F}(initializer::F) where {T, F}
721+
xs, ss = AtomicMemory{T}(), AtomicMemory{UInt8}()
722+
once = new{T,F}(xs, ss, initializer)
723+
ccall(:jl_set_precompile_field_replace, Cvoid, (Any, Any, Any),
724+
once, :xs, xs)
725+
ccall(:jl_set_precompile_field_replace, Cvoid, (Any, Any, Any),
726+
once, :ss, ss)
727+
return once
728+
end
707729
end
730+
PerThread{T}(initializer::F) where {T, F} = PerThread{T,F}(initializer)
731+
PerThread(initializer) = PerThread{Base.promote_op(initializer), typeof(initializer)}(initializer)
708732
@inline function getindex(once::PerThread, tid::Integer)
709733
tid = Int(tid)
710734
ss = @atomic :acquire once.ss

base/task.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -143,13 +143,6 @@ macro task(ex)
143143
:(Task($thunk))
144144
end
145145

146-
"""
147-
current_task()
148-
149-
Get the currently running [`Task`](@ref).
150-
"""
151-
current_task() = ccall(:jl_get_current_task, Ref{Task}, ())
152-
153146
# task states
154147

155148
const task_state_runnable = UInt8(0)

src/builtins.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1008,7 +1008,7 @@ static inline size_t get_checked_fieldindex(const char *name, jl_datatype_t *st,
10081008
else {
10091009
jl_value_t *ts[2] = {(jl_value_t*)jl_long_type, (jl_value_t*)jl_symbol_type};
10101010
jl_value_t *t = jl_type_union(ts, 2);
1011-
jl_type_error("getfield", t, arg);
1011+
jl_type_error(name, t, arg);
10121012
}
10131013
if (mutabl && jl_field_isconst(st, idx)) {
10141014
jl_errorf("%s: const field .%s of type %s cannot be changed", name,

src/gc-stock.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2739,6 +2739,8 @@ static void gc_mark_roots(jl_gc_markqueue_t *mq)
27392739
gc_heap_snapshot_record_gc_roots((jl_value_t*)jl_global_roots_list, "global_roots_list");
27402740
gc_try_claim_and_push(mq, jl_global_roots_keyset, NULL);
27412741
gc_heap_snapshot_record_gc_roots((jl_value_t*)jl_global_roots_keyset, "global_roots_keyset");
2742+
gc_try_claim_and_push(mq, precompile_field_replace, NULL);
2743+
gc_heap_snapshot_record_gc_roots((jl_value_t*)precompile_field_replace, "precompile_field_replace");
27422744
}
27432745

27442746
// find unmarked objects that need to be finalized from the finalizer list "list".

src/julia_internal.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -855,6 +855,8 @@ extern jl_genericmemory_t *jl_global_roots_list JL_GLOBALLY_ROOTED;
855855
extern jl_genericmemory_t *jl_global_roots_keyset JL_GLOBALLY_ROOTED;
856856
JL_DLLEXPORT int jl_is_globally_rooted(jl_value_t *val JL_MAYBE_UNROOTED) JL_NOTSAFEPOINT;
857857
JL_DLLEXPORT jl_value_t *jl_as_global_root(jl_value_t *val, int insert) JL_GLOBALLY_ROOTED;
858+
extern jl_svec_t *precompile_field_replace JL_GLOBALLY_ROOTED;
859+
JL_DLLEXPORT void jl_set_precompile_field_replace(jl_value_t *val, jl_value_t *field, jl_value_t *newval) JL_GLOBALLY_ROOTED;
858860

859861
jl_opaque_closure_t *jl_new_opaque_closure(jl_tupletype_t *argt, jl_value_t *rt_lb, jl_value_t *rt_ub,
860862
jl_value_t *source, jl_value_t **env, size_t nenv, int do_compile);

src/staticdata.c

Lines changed: 110 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,7 @@ void *native_functions; // opaque jl_native_code_desc_t blob used for fetching
491491

492492
// table of struct field addresses to rewrite during saving
493493
static htable_t field_replace;
494+
static htable_t bits_replace;
494495
static htable_t relocatable_ext_cis;
495496

496497
// array of definitions for the predefined function pointers
@@ -1605,7 +1606,23 @@ static void jl_write_values(jl_serializer_state *s) JL_GC_DISABLED
16051606
write_padding(f, offset - tot);
16061607
tot = offset;
16071608
size_t fsz = jl_field_size(t, i);
1608-
if (t->name->mutabl && jl_is_cpointer_type(jl_field_type_concrete(t, i)) && *(intptr_t*)slot != -1) {
1609+
jl_value_t *replace = (jl_value_t*)ptrhash_get(&bits_replace, (void*)slot);
1610+
if (replace != HT_NOTFOUND) {
1611+
assert(t->name->mutabl && !jl_field_isptr(t, i));
1612+
jl_value_t *rty = jl_typeof(replace);
1613+
size_t sz = jl_datatype_size(rty);
1614+
ios_write(f, (const char*)replace, sz);
1615+
jl_value_t *ft = jl_field_type_concrete(t, i);
1616+
int isunion = jl_is_uniontype(ft);
1617+
unsigned nth = 0;
1618+
if (!jl_find_union_component(ft, rty, &nth))
1619+
assert(0 && "invalid field assignment to isbits union");
1620+
assert(sz <= fsz - isunion);
1621+
write_padding(f, fsz - sz - isunion);
1622+
if (isunion)
1623+
write_uint8(f, nth);
1624+
}
1625+
else if (t->name->mutabl && jl_is_cpointer_type(jl_field_type_concrete(t, i)) && *(intptr_t*)slot != -1) {
16091626
// reset Ptr fields to C_NULL (but keep MAP_FAILED / INVALID_HANDLE)
16101627
assert(!jl_field_isptr(t, i));
16111628
write_pointer(f);
@@ -2552,6 +2569,65 @@ jl_mutex_t global_roots_lock;
25522569
extern jl_mutex_t world_counter_lock;
25532570
extern size_t jl_require_world;
25542571

2572+
jl_mutex_t precompile_field_replace_lock;
2573+
jl_svec_t *precompile_field_replace JL_GLOBALLY_ROOTED;
2574+
2575+
static inline jl_value_t *get_checked_fieldindex(const char *name, jl_datatype_t *st, jl_value_t *v, jl_value_t *arg, int mutabl)
2576+
{
2577+
if (mutabl) {
2578+
if (st == jl_module_type)
2579+
jl_error("cannot assign variables in other modules");
2580+
if (!st->name->mutabl)
2581+
jl_errorf("%s: immutable struct of type %s cannot be changed", name, jl_symbol_name(st->name->name));
2582+
}
2583+
size_t idx;
2584+
if (jl_is_long(arg)) {
2585+
idx = jl_unbox_long(arg) - 1;
2586+
if (idx >= jl_datatype_nfields(st))
2587+
jl_bounds_error(v, arg);
2588+
}
2589+
else if (jl_is_symbol(arg)) {
2590+
idx = jl_field_index(st, (jl_sym_t*)arg, 1);
2591+
arg = jl_box_long(idx);
2592+
}
2593+
else {
2594+
jl_value_t *ts[2] = {(jl_value_t*)jl_long_type, (jl_value_t*)jl_symbol_type};
2595+
jl_value_t *t = jl_type_union(ts, 2);
2596+
jl_type_error(name, t, arg);
2597+
}
2598+
if (mutabl && jl_field_isconst(st, idx)) {
2599+
jl_errorf("%s: const field .%s of type %s cannot be changed", name,
2600+
jl_symbol_name((jl_sym_t*)jl_svecref(jl_field_names(st), idx)), jl_symbol_name(st->name->name));
2601+
}
2602+
return arg;
2603+
}
2604+
2605+
JL_DLLEXPORT void jl_set_precompile_field_replace(jl_value_t *val, jl_value_t *field, jl_value_t *newval)
2606+
{
2607+
if (!jl_generating_output())
2608+
return;
2609+
jl_datatype_t *st = (jl_datatype_t*)jl_typeof(val);
2610+
jl_value_t *idx = get_checked_fieldindex("setfield!", st, val, field, 1);
2611+
JL_GC_PUSH1(&idx);
2612+
size_t idxval = jl_unbox_long(idx);
2613+
jl_value_t *ft = jl_field_type_concrete(st, idxval);
2614+
if (!jl_isa(newval, ft))
2615+
jl_type_error("setfield!", ft, newval);
2616+
JL_LOCK(&precompile_field_replace_lock);
2617+
if (precompile_field_replace == NULL) {
2618+
precompile_field_replace = jl_alloc_svec(3);
2619+
jl_svecset(precompile_field_replace, 0, jl_alloc_vec_any(0));
2620+
jl_svecset(precompile_field_replace, 1, jl_alloc_vec_any(0));
2621+
jl_svecset(precompile_field_replace, 2, jl_alloc_vec_any(0));
2622+
}
2623+
jl_array_ptr_1d_push((jl_array_t*)jl_svecref(precompile_field_replace, 0), val);
2624+
jl_array_ptr_1d_push((jl_array_t*)jl_svecref(precompile_field_replace, 1), idx);
2625+
jl_array_ptr_1d_push((jl_array_t*)jl_svecref(precompile_field_replace, 2), newval);
2626+
JL_GC_POP();
2627+
JL_UNLOCK(&precompile_field_replace_lock);
2628+
}
2629+
2630+
25552631
JL_DLLEXPORT int jl_is_globally_rooted(jl_value_t *val JL_MAYBE_UNROOTED) JL_NOTSAFEPOINT
25562632
{
25572633
if (jl_is_datatype(val)) {
@@ -2671,9 +2747,41 @@ static void jl_save_system_image_to_stream(ios_t *f, jl_array_t *mod_array,
26712747
jl_array_t *ext_targets, jl_array_t *edges) JL_GC_DISABLED
26722748
{
26732749
htable_new(&field_replace, 0);
2750+
htable_new(&bits_replace, 0);
26742751
// strip metadata and IR when requested
26752752
if (jl_options.strip_metadata || jl_options.strip_ir)
26762753
jl_strip_all_codeinfos();
2754+
// prepare hash table with any fields the user wanted us to rewrite during serialization
2755+
if (precompile_field_replace) {
2756+
jl_array_t *vals = (jl_array_t*)jl_svecref(precompile_field_replace, 0);
2757+
jl_array_t *fields = (jl_array_t*)jl_svecref(precompile_field_replace, 1);
2758+
jl_array_t *newvals = (jl_array_t*)jl_svecref(precompile_field_replace, 2);
2759+
size_t i, l = jl_array_nrows(vals);
2760+
assert(jl_array_nrows(fields) == l && jl_array_nrows(newvals) == l);
2761+
for (i = 0; i < l; i++) {
2762+
jl_value_t *val = jl_array_ptr_ref(vals, i);
2763+
size_t field = jl_unbox_long(jl_array_ptr_ref(fields, i));
2764+
jl_value_t *newval = jl_array_ptr_ref(newvals, i);
2765+
jl_datatype_t *st = (jl_datatype_t*)jl_typeof(val);
2766+
size_t offs = jl_field_offset(st, field);
2767+
char *fldaddr = (char*)val + offs;
2768+
if (jl_field_isptr(st, field)) {
2769+
record_field_change((jl_value_t**)fldaddr, newval);
2770+
}
2771+
else {
2772+
// replace the bits
2773+
ptrhash_put(&bits_replace, (void*)fldaddr, newval);
2774+
// and any pointers inside
2775+
jl_datatype_t *rty = (jl_datatype_t*)jl_typeof(newval);
2776+
const jl_datatype_layout_t *layout = rty->layout;
2777+
size_t j, np = layout->npointers;
2778+
for (j = 0; j < np; j++) {
2779+
uint32_t ptr = jl_ptr_offset(rty, j);
2780+
record_field_change((jl_value_t**)fldaddr + ptr, *(((jl_value_t**)newval) + ptr));
2781+
}
2782+
}
2783+
}
2784+
}
26772785

26782786
int en = jl_gc_enable(0);
26792787
nsym_tag = 0;
@@ -2966,6 +3074,7 @@ static void jl_save_system_image_to_stream(ios_t *f, jl_array_t *mod_array,
29663074
arraylist_free(&gvars);
29673075
arraylist_free(&external_fns);
29683076
htable_free(&field_replace);
3077+
htable_free(&bits_replace);
29693078
htable_free(&serialization_order);
29703079
htable_free(&nullptrs);
29713080
htable_free(&symbol_table);

test/threads.jl

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -390,10 +390,19 @@ let once = PerProcess{Int}(() -> error("expected"))
390390
@test_throws ErrorException("PerProcess initializer failed previously") once()
391391
end
392392

393-
let once = PerThread(() -> return [nothing])
393+
let e = Base.Event(true),
394+
started = Channel{Int16}(Inf),
395+
once = PerThread() do
396+
push!(started, threadid())
397+
wait(e)
398+
return [nothing]
399+
end
394400
@test typeof(once) <: PerThread{Vector{Nothing}}
401+
notify(e)
395402
x = once()
396403
@test x === once() === fetch(@async once())
404+
@test take!(started) == threadid()
405+
@test isempty(started)
397406
tids = zeros(UInt, 50)
398407
onces = Vector{Vector{Nothing}}(undef, length(tids))
399408
for i = 1:length(tids)
@@ -420,7 +429,18 @@ let once = PerThread(() -> return [nothing])
420429
err == 0 || Base.uv_error("uv_thread_join", err)
421430
end
422431
end
432+
# let them finish in 5 batches of 10
433+
for i = 1:length(tids) ÷ 10
434+
for i = 1:10
435+
@test take!(started) != threadid()
436+
end
437+
for i = 1:10
438+
notify(e)
439+
end
440+
end
441+
@test isempty(started)
423442
waitallthreads(tids)
443+
@test isempty(started)
424444
@test length(IdSet{eltype(onces)}(onces)) == length(onces) # make sure every object is unique
425445

426446
end

0 commit comments

Comments
 (0)