Skip to content

Commit c82e339

Browse files
committed
RFC: Add a hook for detecting task switches.
Certain libraries are configured using global or thread-local state instead of passing handles to every function. CUDA, for example, has a `cudaSetDevice` function that binds a device to the current thread for all future API calls. This is at odds with Julia's task-based concurrency, which presents an execution environment that's local to the current task (e.g., in the case of CUDA, using a different device). This PR adds a hook mechanism that can be used to detect task switches, and synchronize Julia's task-local environment with the library's global or thread-local state.
1 parent d234931 commit c82e339

File tree

4 files changed

+31
-1
lines changed

4 files changed

+31
-1
lines changed

src/gc.c

+2
Original file line numberDiff line numberDiff line change
@@ -2799,6 +2799,8 @@ static void mark_roots(jl_gc_mark_cache_t *gc_cache, jl_gc_mark_sp_t *sp)
27992799
gc_mark_queue_obj(gc_cache, sp, jl_an_empty_vec_any);
28002800
if (jl_module_init_order != NULL)
28012801
gc_mark_queue_obj(gc_cache, sp, jl_module_init_order);
2802+
if (jl_task_switch_hooks != NULL)
2803+
gc_mark_queue_obj(gc_cache, sp, jl_task_switch_hooks);
28022804
for (size_t i = 0; i < jl_current_modules.size; i += 2) {
28032805
if (jl_current_modules.table[i + 1] != HT_NOTFOUND) {
28042806
gc_mark_queue_obj(gc_cache, sp, jl_current_modules.table[i]);

src/julia.h

+3
Original file line numberDiff line numberDiff line change
@@ -1827,6 +1827,9 @@ JL_DLLEXPORT void JL_NORETURN jl_sig_throw(void);
18271827
JL_DLLEXPORT void JL_NORETURN jl_rethrow_other(jl_value_t *e JL_MAYBE_UNROOTED);
18281828
JL_DLLEXPORT void JL_NORETURN jl_no_exc_handler(jl_value_t *e);
18291829

1830+
typedef void *(*jl_task_switch_hook_t)(jl_task_t *t JL_PROPAGATES_ROOT);
1831+
JL_DLLEXPORT void jl_hook_task_switch(jl_task_switch_hook_t hook);
1832+
18301833
#include "locks.h" // requires jl_task_t definition
18311834

18321835
JL_DLLEXPORT void jl_enter_handler(jl_handler_t *eh);

src/julia_internal.h

+1
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,7 @@ void jl_gc_track_malloced_array(jl_ptls_t ptls, jl_array_t *a) JL_NOTSAFEPOINT;
368368
void jl_gc_count_allocd(size_t sz) JL_NOTSAFEPOINT;
369369
void jl_gc_run_all_finalizers(jl_ptls_t ptls);
370370
void jl_release_task_stack(jl_ptls_t ptls, jl_task_t *task);
371+
extern jl_array_t *jl_task_switch_hooks JL_GLOBALLY_ROOTED;
371372

372373
void gc_queue_binding(jl_binding_t *bnd) JL_NOTSAFEPOINT;
373374
void gc_setmark_buf(jl_ptls_t ptls, void *buf, uint8_t, size_t) JL_NOTSAFEPOINT;

src/task.c

+25-1
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,19 @@ static jl_ptls_t NOINLINE refetch_ptls(void)
480480
return jl_get_ptls_states();
481481
}
482482

483+
jl_array_t *jl_task_switch_hooks = JL_GLOBALLY_ROOTED = NULL;
484+
JL_DLLEXPORT void jl_hook_task_switch(jl_task_switch_hook_t hook)
485+
{
486+
if (jl_task_switch_hooks == NULL) {
487+
jl_value_t *array_ptr_void_type = jl_apply_type2(
488+
(jl_value_t *)jl_array_type, (jl_value_t *)jl_voidpointer_type, jl_box_long(1));
489+
jl_task_switch_hooks = jl_alloc_array_1d(array_ptr_void_type, 0);
490+
}
491+
jl_array_grow_end(jl_task_switch_hooks, 1);
492+
((jl_task_switch_hook_t *)jl_array_data(
493+
jl_task_switch_hooks))[jl_array_len(jl_task_switch_hooks) - 1] = hook;
494+
}
495+
483496
JL_DLLEXPORT void jl_switch(void)
484497
{
485498
jl_ptls_t ptls = jl_get_ptls_states();
@@ -497,7 +510,7 @@ JL_DLLEXPORT void jl_switch(void)
497510
if (ptls->in_finalizer)
498511
jl_error("task switch not allowed from inside gc finalizer");
499512
if (ptls->in_pure_callback)
500-
jl_error("task switch not allowed from inside staged nor pure functions");
513+
jl_error("task switch not allowed from inside staged nor pure functions or callbacks");
501514
if (t->sticky && jl_atomic_load_acquire(&t->tid) == -1) {
502515
// manually yielding to a task
503516
if (jl_atomic_compare_exchange(&t->tid, -1, ptls->tid) != -1)
@@ -507,6 +520,17 @@ JL_DLLEXPORT void jl_switch(void)
507520
jl_error("cannot switch to task running on another thread");
508521
}
509522

523+
if (jl_task_switch_hooks) {
524+
int last_in = ptls->in_pure_callback;
525+
ptls->in_pure_callback = 1;
526+
for (int i = 0; i < jl_array_len(jl_task_switch_hooks); i++) {
527+
jl_task_switch_hook_t hook =
528+
((jl_task_switch_hook_t *)jl_array_data(jl_task_switch_hooks))[i];
529+
hook(t);
530+
}
531+
ptls->in_pure_callback = last_in;
532+
}
533+
510534
// Store old values on the stack and reset
511535
sig_atomic_t defer_signal = ptls->defer_signal;
512536
int8_t gc_state = jl_gc_unsafe_enter(ptls);

0 commit comments

Comments
 (0)