Skip to content

Commit

Permalink
export threads_set_callback function properly, support in openmp as w…
Browse files Browse the repository at this point in the history
…ell as threads backend
  • Loading branch information
stevengj committed Sep 16, 2019
1 parent ffd28fd commit 8a9236c
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 14 deletions.
5 changes: 5 additions & 0 deletions api/fftw3.h
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,11 @@ FFTW_EXTERN void \
FFTW_CDECL X(cleanup_threads)(void); \
\
FFTW_EXTERN void \
FFTW_CDECL X(threads_set_callback)( \
void (*spawnloop)(void *(*work)(void *), \
void *jobdata, size_t elsize, int njobs, void *data), void *data); \
\
FFTW_EXTERN void \
FFTW_CDECL X(make_planner_thread_safe)(void); \
\
FFTW_EXTERN int \
Expand Down
10 changes: 9 additions & 1 deletion threads/api.c
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ int X(init_threads)(void)
and hence the time it is configured */
plnr = X(the_planner)();
X(threads_conf_standard)(plnr);

threads_inited = 1;
}
return 1;
Expand Down Expand Up @@ -84,3 +84,11 @@ void X(make_planner_thread_safe)(void)
{
X(threads_register_planner_hooks)();
}

spawnloop_function X(spawnloop_callback) = (spawnloop_function) 0;
void *X(spawnloop_callback_data) = (void *) 0;
void X(threads_set_callback)(void (*spawnloop)(void *(*work)(void *), void *, size_t, int, void *), void *data)
{
X(spawnloop_callback) = (spawnloop_function) spawnloop;
X(spawnloop_callback_data) = data;
}
16 changes: 16 additions & 0 deletions threads/openmp.c
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,22 @@ void X(spawn_loop)(int loopmax, int nthr, spawn_function proc, void *data)
block_size = (loopmax + nthr - 1) / nthr;
nthr = (loopmax + block_size - 1) / block_size;

if (X(spawnloop_callback)) { /* user-defined spawnloop backend */
spawn_data *sdata;
STACK_MALLOC(spawn_data *, sdata, sizeof(spawn_data) * nthr);
for (i = 0; i < nthr; ++i) {
spawn_data *d = &sdata[i];
d->max = (d->min = i * block_size) + block_size;
if (d->max > loopmax)
d->max = loopmax;
d->thr_num = i;
d->data = data;
}
X(spawnloop_callback)(proc, sdata, sizeof(spawn_data), nthr, X(spawnloop_callback_data));
STACK_FREE(sdata);
return;
}

#pragma omp parallel for private(d)
for (i = 0; i < nthr; ++i) {
d.max = (d.min = i * block_size) + block_size;
Expand Down
13 changes: 2 additions & 11 deletions threads/threads.c
Original file line number Diff line number Diff line change
Expand Up @@ -389,15 +389,6 @@ int X(ithreads_init)(void)
return 0; /* no error */
}

typedef void (*spawnloop_function)(spawn_function, spawn_data *, size_t, int, void *);
static spawnloop_function spawnloop_callback = (spawnloop_function) 0;
void *spawnloop_callback_data = (void *) 0;
void X(threads_set_callback)(spawnloop_function spawnloop, void *data)
{
spawnloop_callback = spawnloop;
spawnloop_callback_data = data;
}

/* Distribute a loop from 0 to loopmax-1 over nthreads threads.
proc(d) is called to execute a block of iterations from d->min
to d->max-1. d->thr_num indicate the number of the thread
Expand All @@ -424,7 +415,7 @@ void X(spawn_loop)(int loopmax, int nthr, spawn_function proc, void *data)
block_size = (loopmax + nthr - 1) / nthr;
nthr = (loopmax + block_size - 1) / block_size;

if (spawnloop_callback) { /* user-defined spawnloop backend */
if (X(spawnloop_callback)) { /* user-defined spawnloop backend */
spawn_data *sdata;
STACK_MALLOC(spawn_data *, sdata, sizeof(spawn_data) * nthr);
for (i = 0; i < nthr; ++i) {
Expand All @@ -435,7 +426,7 @@ void X(spawn_loop)(int loopmax, int nthr, spawn_function proc, void *data)
d->thr_num = i;
d->data = data;
}
spawnloop_callback(proc, sdata, sizeof(spawn_data), nthr, spawnloop_callback_data);
X(spawnloop_callback)(proc, sdata, sizeof(spawn_data), nthr, X(spawnloop_callback_data));
STACK_FREE(sdata);
}
else {
Expand Down
8 changes: 6 additions & 2 deletions threads/threads.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,17 @@ void X(spawn_loop)(int loopmax, int nthreads,
int X(ithreads_init)(void);
void X(threads_cleanup)(void);

typedef void (*spawnloop_function)(spawn_function, spawn_data *, size_t, int, void *);
extern spawnloop_function X(spawnloop_callback);
extern void *X(spawnloop_callback_data);

/* configurations */

void X(dft_thr_vrank_geq1_register)(planner *p);
void X(rdft_thr_vrank_geq1_register)(planner *p);
void X(rdft2_thr_vrank_geq1_register)(planner *p);

ct_solver *X(mksolver_ct_threads)(size_t size, INT r, int dec,
ct_solver *X(mksolver_ct_threads)(size_t size, INT r, int dec,
ct_mkinferior mkcldw,
ct_force_vrecursion force_vrecursionp);
hc2hc_solver *X(mksolver_hc2hc_threads)(size_t size, INT r, hc2hc_mkinferior mkcldw);
Expand All @@ -52,5 +56,5 @@ void X(threads_conf_standard)(planner *p);
void X(threads_register_hooks)(void);
void X(threads_unregister_hooks)(void);
void X(threads_register_planner_hooks)(void);

#endif /* __THREADS_H__ */

0 comments on commit 8a9236c

Please sign in to comment.