Skip to content

Commit

Permalink
Fix 3D loop2; implement 3D indexing.
Browse files (browse the repository at this point in the history)
  • Loading branch information
johncbowman committed Oct 19, 2022
1 parent e2d0e44 commit ab2a45f
Showing 1 changed file with 46 additions and 27 deletions.
73 changes: 46 additions & 27 deletions tests/convolve.h
Original file line number Diff line number Diff line change
Expand Up @@ -1284,34 +1284,34 @@ class Convolution2 : public ThreadBase {
}
}

// Diff hunk for Convolution2::forward: the pre-commit and post-commit
// versions of each changed line appear back to back (the commit renames
// the residue parameter r to rx).
//
// For every a in [start,stop), calls the fftx member function selected by
// the pointer-to-member Forward with:
//   f[a]+offset  the a-th input block, shifted by offset,
//   F[a]         the a-th output block,
//   r / rx       the residue argument forwarded unchanged,
//   W            the member work array.
void forward(Complex **f, Complex **F, unsigned int r,
void forward(Complex **f, Complex **F, unsigned int rx,
unsigned int start, unsigned int stop,
unsigned int offset=0) {
for(unsigned int a=start; a < stop; ++a)
// pre-commit call (parameter named r)
(fftx->*Forward)(f[a]+offset,F[a],r,W);
// post-commit call (parameter named rx)
(fftx->*Forward)(f[a]+offset,F[a],rx,W);
}

// Run the per-thread inner convolutions for the residue block starting at rx.
//
// F       array of data-block pointers handed on to Convolution::convolveRaw.
// C       number of outer iterations (index i).
// stride  distance between consecutive (i,d) slices within F.
// rx      starting residue; D consecutive residues rx..rx+D-1 are processed,
//         with D=fftx->D0 when rx is 0 and D=fftx->D otherwise.
// offset  base offset added to every slice position.
//
// For each pair (i,d) the inner convolver of the current thread has
// indices.index[0] set to fftx->index(rx+d,i) and is then applied to the
// slice at offset+(D*i+d)*stride.
void subconvolution(Complex **F, unsigned int C,
                    unsigned int stride, unsigned int rx,
                    unsigned int offset=0) {
  // D must be chosen from the rx argument: an unqualified r is not a
  // parameter of this function and would resolve to a name outside it.
  unsigned int D=rx == 0 ? fftx->D0 : fftx->D;
  PARALLEL(
    for(unsigned int i=0; i < C; ++i) {
      unsigned int t=ThreadBase::get_thread_num0();
      Convolution *cy=convolvey[t];
      for(unsigned int d=0; d < D; ++d) {
        cy->indices.index[0]=fftx->index(rx+d,i);
        cy->convolveRaw(F,offset+(D*i+d)*stride,&cy->indices);
      }
    }
    );
}

// Diff hunk for Convolution2::backward: the pre-commit and post-commit
// versions of each changed line appear back to back. The commit renames
// the residue parameter r to rx and replaces the fixed loop over all B
// outputs with a loop over the caller-supplied range [start,stop).
//
// For each selected b, calls the fftx member function selected by the
// pointer-to-member Backward with F[b] as input, f[b]+offset as output,
// the residue argument, and the work array W0.
void backward(Complex **F, Complex **f, unsigned int r,
void backward(Complex **F, Complex **f, unsigned int rx,
unsigned int start, unsigned int stop,
unsigned int offset=0, Complex *W0=NULL) {
// pre-commit loop: every output 0..B-1
for(unsigned int b=0; b < B; ++b)
(fftx->*Backward)(F[b],f[b]+offset,r,W0);
// post-commit loop: only outputs in [start,stop)
for(unsigned int b=start; b < stop; ++b)
(fftx->*Backward)(F[b],f[b]+offset,rx,W0);
// When the caller's work array is the member W (and W is non-null),
// invoke the fftx member function selected by Pad on it.
if(W && W == W0) (fftx->*Pad)(W0);
}

Expand Down Expand Up @@ -1468,6 +1468,8 @@ class Convolution3 : public ThreadBase {
unsigned int nloops;
bool overwrite;
public:
Indices indices;

// Constructor taking only a thread count (default fftw::maxthreads), which
// is passed to the ThreadBase base class. The pointers fftz, convolvez,
// convolveyz and W are set to NULL, and the flags allocateW and loop2 to
// false; the body is empty.
Convolution3(unsigned int threads=fftw::maxthreads) :
ThreadBase(threads), fftz(NULL), convolvez(NULL), convolveyz(NULL),
W(NULL), allocateW(false), loop2(false) {}
Expand Down Expand Up @@ -1671,17 +1673,25 @@ class Convolution3 : public ThreadBase {
}

// Run the per-thread inner (Convolution2) convolutions for the residue
// block starting at rx.
//
// F       array of data-block pointers handed on to Convolution2::convolveRaw.
// C       number of outer iterations (index i).
// stride  distance between consecutive (i,d) slices within F.
// rx      starting residue; D consecutive residues rx..rx+D-1 are processed,
//         with D=fftx->D0 when rx is 0 and D=fftx->D otherwise.
// offset  base offset added to every slice position.
//
// For each pair (i,d) the inner convolver of the current thread has
// indices.index[1] set to fftx->index(rx+d,i) and is then applied to the
// slice at offset+(D*i+d)*stride.
void subconvolution(Complex **F, unsigned int C,
                    unsigned int stride, unsigned int rx,
                    unsigned int offset=0) {
  // D must follow the rx argument. Callers invoke this with several
  // different residues (e.g. 0 and then a nonzero one), so testing an
  // outer r instead would pick the wrong block count for some calls.
  unsigned int D=rx == 0 ? fftx->D0 : fftx->D;
  PARALLEL(
    for(unsigned int i=0; i < C; ++i) {
      unsigned int t=ThreadBase::get_thread_num0();
      Convolution2 *cyz=convolveyz[t];
      for(unsigned int d=0; d < D; ++d) {
        cyz->indices.index[1]=fftx->index(rx+d,i);
        cyz->convolveRaw(F,offset+(D*i+d)*stride,&cyz->indices);
      }
    }
    );
}

void backward(Complex **F, Complex **f, unsigned int rx,
unsigned int start, unsigned int stop,
unsigned int offset=0, Complex *W0=NULL) {
for(unsigned int b=0; b < B; ++b) {
for(unsigned int b=start; b < stop; ++b) {
if(Sy == Lz)
(fftx->*Backward)(F[b],f[b]+offset,rx,W0);
else {
Expand Down Expand Up @@ -1712,21 +1722,30 @@ class Convolution3 : public ThreadBase {

// f is a pointer to A distinct data blocks each of size Lx*Sx,
// shifted by offset.
void convolveRaw(Complex **f, unsigned int offset=0) {
void convolveRaw(Complex **f, unsigned int offset=0, Indices *indices=NULL) {
for(unsigned int t=0; t < threads; ++t)
convolveyz[t]->indices.copy(indices,2);

if(overwrite) {
forward(f,F,0,0,A,offset);
subconvolution(f,(fftx->n-1)*lx,Sx,offset);
subconvolution(F,lx,Sx);
backward(F,f,0,offset,W);
unsigned int final=fftx->n-1;
for(unsigned int r=0; r < final; ++r)
subconvolution(f,lx,Sx,r,offset+Sx*r*lx);
subconvolution(F,lx,Sx,final);
backward(F,f,0,0,B,offset,W);
} else {
if(loop2) { // FIXME
if(loop2) {
forward(f,F,0,0,A,offset);
subconvolution(F,fftx->D0*lx,Sx);
forward(f,Fp,r,0,B,offset);
backward(F,f,0,offset,W0);
forward(f,Fp,r,B,A,offset);
subconvolution(Fp,fftx->D*lx,Sx);
backward(Fp,f,r,offset,W0);
subconvolution(F,lx,Sx,0);
unsigned int C=A-B;
unsigned int a=0;
for(; a+C <= B; a += C) {
forward(f,Fp,r,a,a+C,offset);
backward(F,f,0,a,a+C,offset,W0);
}
forward(f,Fp,r,a,A,offset);
subconvolution(Fp,lx,Sx,r);
backward(Fp,f,r,0,B,offset,W0);
} else {
unsigned int Offset;
Complex **h0;
Expand All @@ -1741,8 +1760,8 @@ class Convolution3 : public ThreadBase {

for(unsigned int rx=0; rx < Rx; rx += fftx->increment(rx)) {
forward(f,F,rx,0,A,offset);
subconvolution(F,(rx == 0 ? fftx->D0 : fftx->D)*lx,Sx);
backward(F,h0,rx,Offset,W);
subconvolution(F,lx,Sx,rx);
backward(F,h0,rx,0,B,Offset,W);
}

if(nloops > 1) {
Expand Down

0 comments on commit ab2a45f

Please sign in to comment.