Skip to content

Commit 1fdc6a6

Browse files
authored
NFC: create an actual set of functions to manipulate GC thread ids (#54984)
Also adds a bunch of integrity constraint checks to ensure we don't repeat the bug from #54645.
1 parent 6139779 commit 1fdc6a6

File tree

4 files changed

+97
-23
lines changed

4 files changed

+97
-23
lines changed

src/gc.c

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1652,9 +1652,11 @@ void gc_sweep_wake_all(jl_ptls_t ptls, jl_gc_padded_page_stack_t *new_gc_allocd_
16521652
if (parallel_sweep_worthwhile && !page_profile_enabled) {
16531653
jl_atomic_store(&gc_allocd_scratch, new_gc_allocd_scratch);
16541654
uv_mutex_lock(&gc_threads_lock);
1655-
for (int i = gc_first_tid; i < gc_first_tid + jl_n_markthreads; i++) {
1655+
int first = gc_first_parallel_collector_thread_id();
1656+
int last = gc_last_parallel_collector_thread_id();
1657+
for (int i = first; i <= last; i++) {
16561658
jl_ptls_t ptls2 = gc_all_tls_states[i];
1657-
assert(ptls2 != NULL); // should be a GC thread
1659+
gc_check_ptls_of_parallel_collector_thread(ptls2);
16581660
jl_atomic_fetch_add(&ptls2->gc_sweeps_requested, 1);
16591661
}
16601662
uv_cond_broadcast(&gc_threads_cond);
@@ -1666,9 +1668,11 @@ void gc_sweep_wake_all(jl_ptls_t ptls, jl_gc_padded_page_stack_t *new_gc_allocd_
16661668
// collecting a page profile.
16671669
// wait for all to leave in order to ensure that a straggler doesn't
16681670
// try to enter sweeping after we set `gc_allocd_scratch` below.
1669-
for (int i = gc_first_tid; i < gc_first_tid + jl_n_markthreads; i++) {
1671+
int first = gc_first_parallel_collector_thread_id();
1672+
int last = gc_last_parallel_collector_thread_id();
1673+
for (int i = first; i <= last; i++) {
16701674
jl_ptls_t ptls2 = gc_all_tls_states[i];
1671-
assert(ptls2 != NULL); // should be a GC thread
1675+
gc_check_ptls_of_parallel_collector_thread(ptls2);
16721676
while (jl_atomic_load_acquire(&ptls2->gc_sweeps_requested) != 0) {
16731677
jl_cpu_pause();
16741678
}
@@ -3009,19 +3013,25 @@ void gc_mark_and_steal(jl_ptls_t ptls)
30093013
// since we know chunks will likely expand into a lot
30103014
// of work for the mark loop
30113015
steal : {
3016+
int first = gc_first_parallel_collector_thread_id();
3017+
int last = gc_last_parallel_collector_thread_id();
30123018
// Try to steal chunk from random GC thread
30133019
for (int i = 0; i < 4 * jl_n_markthreads; i++) {
3014-
uint32_t v = gc_first_tid + cong(jl_n_markthreads, &ptls->rngseed);
3015-
jl_gc_markqueue_t *mq2 = &gc_all_tls_states[v]->mark_queue;
3020+
int v = gc_random_parallel_collector_thread_id(ptls);
3021+
jl_ptls_t ptls2 = gc_all_tls_states[v];
3022+
gc_check_ptls_of_parallel_collector_thread(ptls2);
3023+
jl_gc_markqueue_t *mq2 = &ptls2->mark_queue;
30163024
c = gc_chunkqueue_steal_from(mq2);
30173025
if (c.cid != GC_empty_chunk) {
30183026
gc_mark_chunk(ptls, mq, &c);
30193027
goto pop;
30203028
}
30213029
}
30223030
// Sequentially walk GC threads to try to steal chunk
3023-
for (int i = gc_first_tid; i < gc_first_tid + jl_n_markthreads; i++) {
3024-
jl_gc_markqueue_t *mq2 = &gc_all_tls_states[i]->mark_queue;
3031+
for (int i = first; i <= last; i++) {
3032+
jl_ptls_t ptls2 = gc_all_tls_states[i];
3033+
gc_check_ptls_of_parallel_collector_thread(ptls2);
3034+
jl_gc_markqueue_t *mq2 = &ptls2->mark_queue;
30253035
c = gc_chunkqueue_steal_from(mq2);
30263036
if (c.cid != GC_empty_chunk) {
30273037
gc_mark_chunk(ptls, mq, &c);
@@ -3036,15 +3046,19 @@ void gc_mark_and_steal(jl_ptls_t ptls)
30363046
}
30373047
// Try to steal pointer from random GC thread
30383048
for (int i = 0; i < 4 * jl_n_markthreads; i++) {
3039-
uint32_t v = gc_first_tid + cong(jl_n_markthreads, &ptls->rngseed);
3040-
jl_gc_markqueue_t *mq2 = &gc_all_tls_states[v]->mark_queue;
3049+
int v = gc_random_parallel_collector_thread_id(ptls);
3050+
jl_ptls_t ptls2 = gc_all_tls_states[v];
3051+
gc_check_ptls_of_parallel_collector_thread(ptls2);
3052+
jl_gc_markqueue_t *mq2 = &ptls2->mark_queue;
30413053
new_obj = gc_ptr_queue_steal_from(mq2);
30423054
if (new_obj != NULL)
30433055
goto mark;
30443056
}
30453057
// Sequentially walk GC threads to try to steal pointer
3046-
for (int i = gc_first_tid; i < gc_first_tid + jl_n_markthreads; i++) {
3047-
jl_gc_markqueue_t *mq2 = &gc_all_tls_states[i]->mark_queue;
3058+
for (int i = first; i <= last; i++) {
3059+
jl_ptls_t ptls2 = gc_all_tls_states[i];
3060+
gc_check_ptls_of_parallel_collector_thread(ptls2);
3061+
jl_gc_markqueue_t *mq2 = &ptls2->mark_queue;
30483062
new_obj = gc_ptr_queue_steal_from(mq2);
30493063
if (new_obj != NULL)
30503064
goto mark;
@@ -3103,12 +3117,13 @@ int gc_should_mark(void)
31033117
}
31043118
int tid = jl_atomic_load_relaxed(&gc_master_tid);
31053119
assert(tid != -1);
3120+
assert(gc_all_tls_states != NULL);
31063121
size_t work = gc_count_work_in_queue(gc_all_tls_states[tid]);
3107-
for (tid = gc_first_tid; tid < gc_first_tid + jl_n_markthreads; tid++) {
3108-
jl_ptls_t ptls2 = gc_all_tls_states[tid];
3109-
if (ptls2 == NULL) {
3110-
continue;
3111-
}
3122+
int first = gc_first_parallel_collector_thread_id();
3123+
int last = gc_last_parallel_collector_thread_id();
3124+
for (int i = first; i <= last; i++) {
3125+
jl_ptls_t ptls2 = gc_all_tls_states[i];
3126+
gc_check_ptls_of_parallel_collector_thread(ptls2);
31123127
work += gc_count_work_in_queue(ptls2);
31133128
}
31143129
// if there is a lot of work left, enter the mark loop
@@ -3522,7 +3537,8 @@ static int _jl_gc_collect(jl_ptls_t ptls, jl_gc_collection_t collection)
35223537
jl_ptls_t ptls_dest = ptls;
35233538
jl_gc_markqueue_t *mq_dest = mq;
35243539
if (!single_threaded_mark) {
3525-
ptls_dest = gc_all_tls_states[gc_first_tid + t_i % jl_n_markthreads];
3540+
int dest_tid = gc_ith_parallel_collector_thread_id(t_i % jl_n_markthreads);
3541+
ptls_dest = gc_all_tls_states[dest_tid];
35263542
mq_dest = &ptls_dest->mark_queue;
35273543
}
35283544
if (ptls2 != NULL) {
@@ -3787,8 +3803,9 @@ static int _jl_gc_collect(jl_ptls_t ptls, jl_gc_collection_t collection)
37873803
ptls2->heap.remset->len = 0;
37883804
}
37893805
// free empty GC state for threads that have exited
3790-
if (jl_atomic_load_relaxed(&ptls2->current_task) == NULL &&
3791-
(ptls->tid < gc_first_tid || ptls2->tid >= gc_first_tid + jl_n_gcthreads)) {
3806+
if (jl_atomic_load_relaxed(&ptls2->current_task) == NULL) {
3807+
if (gc_is_parallel_collector_thread(t_i))
3808+
continue;
37923809
jl_thread_heap_t *heap = &ptls2->heap;
37933810
if (heap->weak_refs.len == 0)
37943811
small_arraylist_free(&heap->weak_refs);

src/gc.h

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,54 @@ extern int gc_n_threads;
449449
extern jl_ptls_t* gc_all_tls_states;
450450
extern gc_heapstatus_t gc_heap_stats;
451451

452+
STATIC_INLINE int gc_first_parallel_collector_thread_id(void) JL_NOTSAFEPOINT
453+
{
454+
if (jl_n_markthreads == 0) {
455+
return 0;
456+
}
457+
return gc_first_tid;
458+
}
459+
460+
STATIC_INLINE int gc_last_parallel_collector_thread_id(void) JL_NOTSAFEPOINT
461+
{
462+
if (jl_n_markthreads == 0) {
463+
return -1;
464+
}
465+
return gc_first_tid + jl_n_markthreads - 1;
466+
}
467+
468+
STATIC_INLINE int gc_ith_parallel_collector_thread_id(int i) JL_NOTSAFEPOINT
469+
{
470+
assert(i >= 0 && i < jl_n_markthreads);
471+
return gc_first_tid + i;
472+
}
473+
474+
STATIC_INLINE int gc_is_parallel_collector_thread(int tid) JL_NOTSAFEPOINT
475+
{
476+
return tid >= gc_first_tid && tid <= gc_last_parallel_collector_thread_id();
477+
}
478+
479+
STATIC_INLINE int gc_random_parallel_collector_thread_id(jl_ptls_t ptls) JL_NOTSAFEPOINT
480+
{
481+
assert(jl_n_markthreads > 0);
482+
int v = gc_first_tid + (int)cong(jl_n_markthreads - 1, &ptls->rngseed);
483+
assert(v >= gc_first_tid && v <= gc_last_parallel_collector_thread_id());
484+
return v;
485+
}
486+
487+
STATIC_INLINE int gc_parallel_collector_threads_enabled(void) JL_NOTSAFEPOINT
488+
{
489+
return jl_n_markthreads > 0;
490+
}
491+
492+
STATIC_INLINE void gc_check_ptls_of_parallel_collector_thread(jl_ptls_t ptls) JL_NOTSAFEPOINT
493+
{
494+
(void)ptls;
495+
assert(gc_parallel_collector_threads_enabled());
496+
assert(ptls != NULL);
497+
assert(jl_atomic_load_relaxed(&ptls->gc_state) == JL_GC_PARALLEL_COLLECTOR_THREAD);
498+
}
499+
452500
STATIC_INLINE bigval_t *bigval_header(jl_taggedvalue_t *o) JL_NOTSAFEPOINT
453501
{
454502
return container_of(o, bigval_t, header);

src/julia_threads.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,10 @@ typedef struct _jl_tls_states_t {
209209
#define JL_GC_STATE_SAFE 2
210210
// gc_state = 2 means the thread is running unmanaged code that can
211211
// execute at the same time as the GC.
212+
#define JL_GC_PARALLEL_COLLECTOR_THREAD 3
213+
// gc_state = 3 means the thread is a parallel collector thread (i.e. never runs Julia code)
214+
#define JL_GC_CONCURRENT_COLLECTOR_THREAD 4
215+
// gc_state = 4 means the thread is a concurrent collector thread (background sweeper thread that never runs Julia code)
212216
_Atomic(int8_t) gc_state; // read from foreign threads
213217
// execution of certain impure
214218
// statements is prohibited from certain
@@ -340,6 +344,8 @@ void jl_sigint_safepoint(jl_ptls_t tls);
340344
STATIC_INLINE int8_t jl_gc_state_set(jl_ptls_t ptls, int8_t state,
341345
int8_t old_state)
342346
{
347+
assert(old_state != JL_GC_PARALLEL_COLLECTOR_THREAD);
348+
assert(old_state != JL_GC_CONCURRENT_COLLECTOR_THREAD);
343349
jl_atomic_store_release(&ptls->gc_state, state);
344350
if (state == JL_GC_STATE_UNSAFE || old_state == JL_GC_STATE_UNSAFE)
345351
jl_gc_safepoint_(ptls);

src/scheduler.c

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ void jl_parallel_gc_threadfun(void *arg)
136136
JL_GC_PROMISE_ROOTED(ct);
137137
(void)jl_atomic_fetch_add_relaxed(&nrunning, -1);
138138
// wait for all threads
139-
jl_gc_state_set(ptls, JL_GC_STATE_WAITING, JL_GC_STATE_UNSAFE);
139+
jl_gc_state_set(ptls, JL_GC_PARALLEL_COLLECTOR_THREAD, JL_GC_STATE_UNSAFE);
140140
uv_barrier_wait(targ->barrier);
141141

142142
// free the thread argument here
@@ -148,8 +148,10 @@ void jl_parallel_gc_threadfun(void *arg)
148148
uv_cond_wait(&gc_threads_cond, &gc_threads_lock);
149149
}
150150
uv_mutex_unlock(&gc_threads_lock);
151+
assert(jl_atomic_load_relaxed(&ptls->gc_state) == JL_GC_PARALLEL_COLLECTOR_THREAD);
151152
gc_mark_loop_parallel(ptls, 0);
152-
if (may_sweep(ptls)) { // not an else!
153+
if (may_sweep(ptls)) {
154+
assert(jl_atomic_load_relaxed(&ptls->gc_state) == JL_GC_PARALLEL_COLLECTOR_THREAD);
153155
gc_sweep_pool_parallel(ptls);
154156
jl_atomic_fetch_add(&ptls->gc_sweeps_requested, -1);
155157
}
@@ -170,13 +172,14 @@ void jl_concurrent_gc_threadfun(void *arg)
170172
JL_GC_PROMISE_ROOTED(ct);
171173
(void)jl_atomic_fetch_add_relaxed(&nrunning, -1);
172174
// wait for all threads
173-
jl_gc_state_set(ptls, JL_GC_STATE_WAITING, JL_GC_STATE_UNSAFE);
175+
jl_gc_state_set(ptls, JL_GC_CONCURRENT_COLLECTOR_THREAD, JL_GC_STATE_UNSAFE);
174176
uv_barrier_wait(targ->barrier);
175177

176178
// free the thread argument here
177179
free(targ);
178180

179181
while (1) {
182+
assert(jl_atomic_load_relaxed(&ptls->gc_state) == JL_GC_CONCURRENT_COLLECTOR_THREAD);
180183
uv_sem_wait(&gc_sweep_assists_needed);
181184
gc_free_pages();
182185
}

0 commit comments

Comments
 (0)