Skip to content

Commit

Permalink
Update test
Browse files Browse the repository at this point in the history
  • Loading branch information
WeiqunZhang committed Oct 17, 2024
1 parent d791eed commit ca17c46
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 63 deletions.
16 changes: 10 additions & 6 deletions Src/FFT/AMReX_FFT.H
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <AMReX_Config.H>

#include <AMReX_MultiFab.H>
#include <AMReX_FFT_Helper.H>
#include <numeric>
#include <tuple>
#include <utility>
Expand All @@ -26,9 +27,6 @@
namespace amrex::FFT
{

enum struct Scaling { full, symmetric, none };
enum struct Direction { forward, backward };

template <typename T = Real>
class R2C
{
Expand All @@ -37,7 +35,7 @@ public:
MultiFab, FabArray<BaseFab<T> > >;
using cMF = FabArray<BaseFab<GpuComplex<T> > >;

R2C (Box const& domain);
explicit R2C (Box const& domain, Info const& info = Info{});

~R2C ();

Expand Down Expand Up @@ -146,6 +144,8 @@ private:
static void destroy_plan (P plan);
static std::pair<FFTPlan2,FFTPlan2> make_c2c_plans (cMF& inout);

Info m_info;

Box m_real_domain;
Box m_spectral_domain_x;
Box m_spectral_domain_y;
Expand Down Expand Up @@ -183,8 +183,9 @@ private:
};

template <typename T>
R2C<T>::R2C (Box const& domain)
: m_real_domain(domain),
R2C<T>::R2C (Box const& domain, Info const& info)
: m_info(info),
m_real_domain(domain),
m_spectral_domain_x(IntVect(0), IntVect(AMREX_D_DECL(domain.length(0)/2,
domain.bigEnd(1),
domain.bigEnd(2)))),
Expand All @@ -197,6 +198,9 @@ R2C<T>::R2C (Box const& domain)
{
static_assert(std::is_same_v<float,T> || std::is_same_v<double,T>);
AMREX_ALWAYS_ASSERT(m_real_domain.smallEnd() == 0 && m_real_domain.cellCentered());
#if (AMREX_SPACEDIM != 3)
AMREX_ALWAYS_ASSERT(false == m_info.batch_mode);
#endif

int myproc = ParallelDescriptor::MyProc();
int nprocs = ParallelDescriptor::NProcs();
Expand Down
73 changes: 73 additions & 0 deletions Src/FFT/AMReX_FFT_Helper.H
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#ifndef AMREX_FFT_HELPER_H_
#define AMREX_FFT_HELPER_H_
#include <AMReX_Config.H>

#include <AMReX.H>
#include <AMReX_Geometry.H>
#include <AMReX_Gpu.H>
#include <AMReX_GpuComplex.H>
#include <AMReX_Math.H>

namespace amrex::FFT
{

//! Scaling choices for a transform.
//! NOTE(review): the exact normalization factor tied to each enumerator is
//! applied by the FFT implementation, not in this header -- confirm there.
enum class Scaling
{
    full,
    symmetric,
    none
};

//! Direction of a transform: forward or backward.
enum class Direction
{
    forward,
    backward
};

//! Options that can be handed to an FFT object at construction time.
struct Info
{
    //! Batch mode is supported only in 3D. When it is enabled, the FFT
    //! is carried out over the first two dimensions only, and the
    //! extent of the third dimension is used as the batch size.
    //! Disabled by default.
    bool batch_mode = false;

    //! Store the requested batch-mode setting and return this object,
    //! so that setter calls can be chained.
    Info& setBatchMode (bool enable)
    {
        batch_mode = enable;
        return *this;
    }
};

/**
 * \brief Spectral-space functor for solving a Poisson problem with FFTs.
 *
 * An instance is meant to be used as the callback that is applied to every
 * spectral value between the forward and the backward transform (e.g. as
 * the last argument of R2C::forwardThenBackward). For each mode it
 *
 *  - evaluates k2 = sum_d 2*(cos(k_d*dx_d) - 1) / dx_d^2, i.e. the symbol
 *    of the 3-point centered-difference Laplacian for that mode,
 *  - divides the spectral value by k2, or sets it to zero when k2 == 0,
 *  - multiplies the result by 1/(number of points in the domain).
 *
 * \tparam T floating-point type of the data (enforced by a static_assert).
 */
template <typename T>
struct PoissonSpectral
{
    /**
     * \brief Cache the geometry-dependent constants used by operator().
     *
     * NOTE(review): the constructor is deliberately left non-explicit so
     * that existing implicit conversions from Geometry keep compiling.
     *
     * \param geom geometry providing the problem length, the cell size and
     *             the index domain.
     */
    PoissonSpectral (Geometry const& geom)
        : fac({AMREX_D_DECL(T(2)*Math::pi<T>()/T(geom.ProbLength(0)),
                            T(2)*Math::pi<T>()/T(geom.ProbLength(1)),
                            T(2)*Math::pi<T>()/T(geom.ProbLength(2)))}),
          dx({AMREX_D_DECL(T(geom.CellSize(0)),
                           T(geom.CellSize(1)),
                           T(geom.CellSize(2)))}),
          scale(T(1.0/geom.Domain().d_numPts())),
          len(geom.Domain().length())
    {
        static_assert(std::is_floating_point_v<T>);
    }

    /**
     * \brief Apply the inverse discrete Laplacian and the FFT normalization
     *        to one spectral value.
     *
     * \param i,j,k          indices of the mode in the spectral array.
     * \param spectral_data  value of that mode; modified in place.
     */
    AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
    void operator() (int i, int j, int k, GpuComplex<T>& spectral_data) const
    {
        // In 1D/2D builds AMREX_D_TERM drops the j and/or k terms, so the
        // corresponding indices would otherwise be reported as unused.
        amrex::ignore_unused(i,j,k);
        // the values in the upper-half of the spectral array in y and z
        // are here interpreted as negative wavenumbers. Only cos() of the
        // wavenumber is used below, so its sign does not affect k2.
        AMREX_D_TERM(T a = fac[0]*i;,
                     T b = (j < len[1]/2) ? fac[1]*j : fac[1]*(len[1]-j);,
                     T c = (k < len[2]/2) ? fac[2]*k : fac[2]*(len[2]-k));
        T k2 = AMREX_D_TERM(T(2)*(std::cos(a*dx[0])-T(1))/(dx[0]*dx[0]),
                            +T(2)*(std::cos(b*dx[1])-T(1))/(dx[1]*dx[1]),
                            +T(2)*(std::cos(c*dx[2])-T(1))/(dx[2]*dx[2]));
        if (k2 != T(0)) {
            spectral_data /= k2;
        } else {
            // interpretation here is that the average value of the
            // solution is zero
            spectral_data = 0;
        }
        spectral_data *= scale;
    }

    GpuArray<T,AMREX_SPACEDIM> fac; //!< 2*pi / problem length, per direction
    GpuArray<T,AMREX_SPACEDIM> dx;  //!< cell size, per direction
    T scale;                        //!< 1 / (number of points in the domain)
    IntVect len;                    //!< domain length (in cells), per direction
};

}

#endif
1 change: 1 addition & 0 deletions Src/FFT/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ foreach(D IN LISTS AMReX_SPACEDIM)
PRIVATE
AMReX_FFT.H
AMReX_FFT.cpp
AMReX_FFT_Helper.H
)

endforeach()
2 changes: 1 addition & 1 deletion Src/FFT/Make.package
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
ifndef AMREX_FFT_MAKE
AMREX_FFT_MAKE := 1

CEXE_headers += AMReX_FFT.H
CEXE_headers += AMReX_FFT.H AMReX_FFT_Helper.H
CEXE_sources += AMReX_FFT.cpp

VPATH_LOCATIONS += $(AMREX_HOME)/Src/FFT
Expand Down
2 changes: 1 addition & 1 deletion Tests/FFT/Poisson/GNUmakefile
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
AMREX_HOME := ../../..

DEBUG = TRUE
DEBUG = FALSE

DIM = 3

Expand Down
98 changes: 43 additions & 55 deletions Tests/FFT/Poisson/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,22 @@ int main (int argc, char* argv[])
{
BL_PROFILE("main");

int n_cell_x = 64;
int n_cell_y = 64;
int n_cell_z = 64;
AMREX_D_TERM(int n_cell_x = 64;,
int n_cell_y = 64;,
int n_cell_z = 64);

Real prob_lo_x = 0.;
Real prob_lo_y = 0.;
Real prob_lo_z = 0.;
Real prob_hi_x = 1.;
Real prob_hi_y = 1.;
Real prob_hi_z = 1.;
AMREX_D_TERM(Real prob_lo_x = 0.;,
Real prob_lo_y = 0.;,
Real prob_lo_z = 0.);
AMREX_D_TERM(Real prob_hi_x = 1.;,
Real prob_hi_y = 1.;,
Real prob_hi_z = 1.);

{
ParmParse pp;
pp.query("n_cell_x", n_cell_x);
pp.query("n_cell_y", n_cell_y);
pp.query("n_cell_z", n_cell_z);
AMREX_D_TERM(pp.query("n_cell_x", n_cell_x);,
pp.query("n_cell_y", n_cell_y);,
pp.query("n_cell_z", n_cell_z));
}

Box domain(IntVect(0),IntVect(AMREX_D_DECL(n_cell_x-1,n_cell_y-1,n_cell_z-1)));
Expand All @@ -49,52 +49,34 @@ int main (int argc, char* argv[])
auto const& rhsma = rhs.arrays();
ParallelFor(rhs, [=] AMREX_GPU_DEVICE (int b, int i, int j, int k)
{
Real x = (i+0.5_rt) * dx[0];
Real y = (AMREX_SPACEDIM>=2) ? (j+0.5_rt) * dx[1] : 0._rt;
Real z = (AMREX_SPACEDIM==3) ? (k+0.5_rt) * dx[2] : 0._rt;
rhsma[b](i,j,k) = std::exp(-10._rt*((x-0.5_rt)*(x-0.5_rt)*1.05_rt +
(y-0.5_rt)*(y-0.5_rt)*0.90_rt +
(z-0.5_rt)*(z-0.5_rt)));
AMREX_D_TERM(Real x = (i+0.5_rt) * dx[0] - 0.5_rt;,
Real y = (j+0.5_rt) * dx[1] - 0.5_rt;,
Real z = (k+0.5_rt) * dx[2] - 0.5_rt);
rhsma[b](i,j,k) = std::exp(-10._rt*
(AMREX_D_TERM(x*x*1.05_rt, + y*y*0.90_rt, + z*z)));
});

// Shift rhs so that its sum is zero.
auto rhosum = rhs.sum(0);
rhs.plus(-rhosum/geom.Domain().d_numPts(), 0, 1);

Real facx = 2._rt*Math::pi<Real>()/std::abs(prob_hi_x-prob_lo_x);
Real facy = 2._rt*Math::pi<Real>()/std::abs(prob_hi_y-prob_lo_y);
Real facz = 2._rt*Math::pi<Real>()/std::abs(prob_hi_z-prob_lo_z);
Real scale = 1._rt/(Real(n_cell_z)*Real(n_cell_y)*Real(n_cell_z));

auto post_forward = [=] AMREX_GPU_DEVICE (int i, int j, int k,
GpuComplex<Real>& spectral_data)
{
amrex::ignore_unused(j,k);
// the values in the upper-half of the spectral array in y and z
// are here interpreted as negative wavenumbers
AMREX_D_TERM(Real a = facx*i;,
Real b = (j < n_cell_y/2) ? facy*j : facy*(n_cell_y-j);,
Real c = (k < n_cell_z/2) ? facz*k : facz*(n_cell_z-k));
Real k2 = AMREX_D_TERM(2._rt*(std::cos(a*dx[0])-1._rt)/(dx[0]*dx[0]),
+2._rt*(std::cos(b*dx[1])-1._rt)/(dx[1]*dx[1]),
+2._rt*(std::cos(c*dx[2])-1._rt)/(dx[2]*dx[2]));
if (k2 != 0._rt) {
spectral_data /= k2;
} else {
// interpretation here is that the average value of the
// solution is zero
spectral_data *= 0._rt;
}
spectral_data *= scale;
};

auto t0 = amrex::second();

FFT::R2C fft(geom.Domain());
fft.forwardThenBackward(rhs, soln, post_forward);
FFT::PoissonSpectral<Real> post_forward(geom);

auto t1 = amrex::second();
amrex::Print() << " AMReX FFT time: " << t1-t0 << "\n";

double tsolve;
for (int n = 0; n < 2; ++n) {
auto ta = amrex::second();
fft.forwardThenBackward(rhs, soln, post_forward);
auto tb = amrex::second();
tsolve = tb-ta;
}

amrex::Print() << " AMReX FFT setup time: " << t1-t0 << ", solve time "
<< tsolve << "\n";

{
MultiFab phi(soln.boxArray(), soln.DistributionMap(), 1, 1);
Expand All @@ -107,16 +89,22 @@ int main (int argc, char* argv[])
ParallelFor(res, [=] AMREX_GPU_DEVICE (int b, int i, int j, int k)
{
auto const& phia = phi_ma[b];
auto lap = (phia(i-1,j,k)-2.*phia(i,j,k)+phia(i+1,j,k)) / (dx[0]*dx[0])
+ (phia(i,j-1,k)-2.*phia(i,j,k)+phia(i,j+1,k)) / (dx[1]*dx[1])
+ (phia(i,j,k-1)-2.*phia(i,j,k)+phia(i,j,k+1)) / (dx[2]*dx[2]);
auto lap = AMREX_D_TERM
(((phia(i-1,j,k)-2.*phia(i,j,k)+phia(i+1,j,k)) / (dx[0]*dx[0])),
+ ((phia(i,j-1,k)-2.*phia(i,j,k)+phia(i,j+1,k)) / (dx[1]*dx[1])),
+ ((phia(i,j,k-1)-2.*phia(i,j,k)+phia(i,j,k+1)) / (dx[2]*dx[2])));
res_ma[b](i,j,k) = rhs_ma[b](i,j,k) - lap;
});
amrex::Print() << " rhs.min & max: " << rhs.min(0) << " " << rhs.max(0) << "\n"
<< " res.min & max: " << res.min(0) << " " << res.max(0) << "\n";
VisMF::Write(soln, "soln");
VisMF::Write(rhs, "rhs");
VisMF::Write(res, "res");
auto bnorm = rhs.norminf();
auto rnorm = res.norminf();
amrex::Print() << " rhs inf norm " << bnorm << "\n"
<< " res inf norm " << rnorm << "\n";
#ifdef AMREX_USE_FLOAT
auto eps = 1.e-3f;
#else
auto eps = 1.e-11;
#endif
AMREX_ALWAYS_ASSERT(rnorm < eps*bnorm);
}
}
amrex::Finalize();
Expand Down

0 comments on commit ca17c46

Please sign in to comment.