From 8a9236c494f680e550c536eeb6aa7102ecfad2ca Mon Sep 17 00:00:00 2001 From: "Steven G. Johnson" Date: Mon, 16 Sep 2019 15:48:40 -0400 Subject: [PATCH] export threads_set_callback function properly, support in openmp as well as threads backend --- api/fftw3.h | 5 +++++ threads/api.c | 10 +++++++++- threads/openmp.c | 16 ++++++++++++++++ threads/threads.c | 13 ++----------- threads/threads.h | 8 ++++++-- 5 files changed, 38 insertions(+), 14 deletions(-) diff --git a/api/fftw3.h b/api/fftw3.h index 7bd4c6e55..229c475ad 100644 --- a/api/fftw3.h +++ b/api/fftw3.h @@ -386,6 +386,11 @@ FFTW_EXTERN void \ FFTW_CDECL X(cleanup_threads)(void); \ \ FFTW_EXTERN void \ +FFTW_CDECL X(threads_set_callback)( \ + void (*spawnloop)(void *(*work)(void *), \ + void *jobdata, size_t elsize, int njobs, void *data), void *data); \ + \ +FFTW_EXTERN void \ FFTW_CDECL X(make_planner_thread_safe)(void); \ \ FFTW_EXTERN int \ diff --git a/threads/api.c b/threads/api.c index eae2cd4b7..e0da41a0c 100644 --- a/threads/api.c +++ b/threads/api.c @@ -50,7 +50,7 @@ int X(init_threads)(void) and hence the time it is configured */ plnr = X(the_planner)(); X(threads_conf_standard)(plnr); - + threads_inited = 1; } return 1; @@ -84,3 +84,11 @@ void X(make_planner_thread_safe)(void) { X(threads_register_planner_hooks)(); } + +spawnloop_function X(spawnloop_callback) = (spawnloop_function) 0; +void *X(spawnloop_callback_data) = (void *) 0; +void X(threads_set_callback)(void (*spawnloop)(void *(*work)(void *), void *, size_t, int, void *), void *data) +{ + X(spawnloop_callback) = (spawnloop_function) spawnloop; + X(spawnloop_callback_data) = data; +} diff --git a/threads/openmp.c b/threads/openmp.c index 1b384ece5..05f4f02c8 100644 --- a/threads/openmp.c +++ b/threads/openmp.c @@ -58,6 +58,22 @@ void X(spawn_loop)(int loopmax, int nthr, spawn_function proc, void *data) block_size = (loopmax + nthr - 1) / nthr; nthr = (loopmax + block_size - 1) / block_size; + if (X(spawnloop_callback)) { /* user-defined spawnloop backend */ + spawn_data *sdata; + STACK_MALLOC(spawn_data *, sdata, sizeof(spawn_data) * nthr); + for (i = 0; i < nthr; ++i) { + spawn_data *d = &sdata[i]; + d->max = (d->min = i * block_size) + block_size; + if (d->max > loopmax) + d->max = loopmax; + d->thr_num = i; + d->data = data; + } + X(spawnloop_callback)(proc, sdata, sizeof(spawn_data), nthr, X(spawnloop_callback_data)); + STACK_FREE(sdata); + return; + } + #pragma omp parallel for private(d) for (i = 0; i < nthr; ++i) { d.max = (d.min = i * block_size) + block_size; diff --git a/threads/threads.c b/threads/threads.c index c45c20782..1d1371d89 100644 --- a/threads/threads.c +++ b/threads/threads.c @@ -389,15 +389,6 @@ int X(ithreads_init)(void) return 0; /* no error */ } -typedef void (*spawnloop_function)(spawn_function, spawn_data *, size_t, int, void *); -static spawnloop_function spawnloop_callback = (spawnloop_function) 0; -void *spawnloop_callback_data = (void *) 0; -void X(threads_set_callback)(spawnloop_function spawnloop, void *data) -{ - spawnloop_callback = spawnloop; - spawnloop_callback_data = data; -} - /* Distribute a loop from 0 to loopmax-1 over nthreads threads. proc(d) is called to execute a block of iterations from d->min to d->max-1. d->thr_num indicate the number of the thread @@ -424,7 +415,7 @@ void X(spawn_loop)(int loopmax, int nthr, spawn_function proc, void *data) block_size = (loopmax + nthr - 1) / nthr; nthr = (loopmax + block_size - 1) / block_size; - if (spawnloop_callback) { /* user-defined spawnloop backend */ + if (X(spawnloop_callback)) { /* user-defined spawnloop backend */ spawn_data *sdata; STACK_MALLOC(spawn_data *, sdata, sizeof(spawn_data) * nthr); for (i = 0; i < nthr; ++i) { @@ -435,7 +426,7 @@ void X(spawn_loop)(int loopmax, int nthr, spawn_function proc, void *data) d->thr_num = i; d->data = data; } - spawnloop_callback(proc, sdata, sizeof(spawn_data), nthr, spawnloop_callback_data); + X(spawnloop_callback)(proc, sdata, sizeof(spawn_data), nthr, X(spawnloop_callback_data)); STACK_FREE(sdata); } else { diff --git a/threads/threads.h b/threads/threads.h index 8c5072d0b..e48db3fbc 100644 --- a/threads/threads.h +++ b/threads/threads.h @@ -37,13 +37,17 @@ void X(spawn_loop)(int loopmax, int nthreads, int X(ithreads_init)(void); void X(threads_cleanup)(void); +typedef void (*spawnloop_function)(spawn_function, spawn_data *, size_t, int, void *); +extern spawnloop_function X(spawnloop_callback); +extern void *X(spawnloop_callback_data); + /* configurations */ void X(dft_thr_vrank_geq1_register)(planner *p); void X(rdft_thr_vrank_geq1_register)(planner *p); void X(rdft2_thr_vrank_geq1_register)(planner *p); -ct_solver *X(mksolver_ct_threads)(size_t size, INT r, int dec, +ct_solver *X(mksolver_ct_threads)(size_t size, INT r, int dec, ct_mkinferior mkcldw, ct_force_vrecursion force_vrecursionp); hc2hc_solver *X(mksolver_hc2hc_threads)(size_t size, INT r, hc2hc_mkinferior mkcldw); @@ -52,5 +56,5 @@ void X(threads_conf_standard)(planner *p); void X(threads_register_hooks)(void); void X(threads_unregister_hooks)(void); void X(threads_register_planner_hooks)(void); - + #endif /* __THREADS_H__ */