Skip to content

Commit

Permalink
Kernel fusing in FabArray Comm (AMReX-Codes#2559)
Browse files Browse the repository at this point in the history
Note that amrex::Add has been moved from AMReX_FabArrayUtility.H to
AMReX_FabArray.H so that it can be used by AMReX_FabArrayCommI.H.
  • Loading branch information
WeiqunZhang authored Jan 5, 2022
1 parent dfee5b7 commit 5726371
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 83 deletions.
45 changes: 45 additions & 0 deletions Src/Base/AMReX_FabArray.H
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,51 @@ Copy (FabArray<FAB>& dst, FabArray<FAB> const& src, int srccomp, int dstcomp, in
}
}

/**
 * \brief Accumulate components of src into dst, with the ghost width given as one int.
 *
 * Convenience overload: builds an IntVect from the single integer \p nghost
 * and forwards every other argument unchanged to the IntVect overload of Add.
 *
 * \param dst      destination FabArray; its data is modified.
 * \param src      source FabArray; read only.
 * \param srccomp  first component read from src.
 * \param dstcomp  first component updated in dst.
 * \param numcomp  number of consecutive components processed.
 * \param nghost   integer used to construct the IntVect passed on as ghost width.
 */
template <class FAB,
          class bar = std::enable_if_t<IsBaseFab<FAB>::value> >
void
Add (FabArray<FAB>& dst, FabArray<FAB> const& src, int srccomp, int dstcomp, int numcomp, int nghost)
{
    const IntVect ng(nghost);
    Add(dst, src, srccomp, dstcomp, numcomp, ng);
}

/**
 * \brief dst(i,j,k,dstcomp+n) += src(i,j,k,srccomp+n) for n in [0,numcomp).
 *
 * Two code paths exist:
 *  - When built with AMREX_USE_GPU and both Gpu::inLaunchRegion() and
 *    dst.isFusingCandidate() hold, a single ParallelFor call spanning dst
 *    (with \p nghost and \p numcomp) performs the update using the array
 *    tables from dst.arrays() / src.const_arrays(), and
 *    Gpu::streamSynchronize() is called before returning.
 *  - Otherwise an MFIter loop over dst visits each box obtained from
 *    growntilebox(nghost); boxes whose ok() is false are skipped.
 *
 * \param dst      destination FabArray; its data is modified.
 * \param src      source FabArray; read only. It is indexed with dst's
 *                 iterator / box number (NOTE(review): presumably it must
 *                 share dst's layout -- confirm against callers).
 * \param srccomp  first component read from src.
 * \param dstcomp  first component updated in dst.
 * \param numcomp  number of consecutive components processed.
 * \param nghost   passed to ParallelFor / growntilebox to size the region.
 */
template <class FAB,
          class bar = std::enable_if_t<IsBaseFab<FAB>::value> >
void
Add (FabArray<FAB>& dst, FabArray<FAB> const& src, int srccomp, int dstcomp, int numcomp, const IntVect& nghost)
{
#ifdef AMREX_USE_GPU
    if (Gpu::inLaunchRegion() && dst.isFusingCandidate()) {
        // One ParallelFor over every box of dst instead of a per-box loop.
        auto const& dst_tables = dst.arrays();
        auto const& src_tables = src.const_arrays();
        ParallelFor(dst, nghost, numcomp,
        [=] AMREX_GPU_DEVICE (int ibox, int i, int j, int k, int n) noexcept
        {
            dst_tables[ibox](i,j,k,dstcomp+n) += src_tables[ibox](i,j,k,srccomp+n);
        });
        Gpu::streamSynchronize();
    } else
#endif
    {
#ifdef AMREX_USE_OMP
#pragma omp parallel if (Gpu::notInLaunchRegion())
#endif
        for (MFIter it(dst,TilingIfNotGPU()); it.isValid(); ++it)
        {
            const Box& region = it.growntilebox(nghost);
            if (!region.ok()) {
                continue; // nothing to update for this box
            }

            auto const from = src.array(it);
            auto into = dst.array(it);
            AMREX_HOST_DEVICE_PARALLEL_FOR_4D( region, numcomp, ii, jj, kk, c,
            {
                into(ii,jj,kk,dstcomp+c) += from(ii,jj,kk,srccomp+c);
            });
        }
    }
}

template <class FAB>
class FabArray
:
Expand Down
43 changes: 6 additions & 37 deletions Src/Base/AMReX_FabArrayCommI.H
Original file line number Diff line number Diff line change
Expand Up @@ -338,31 +338,13 @@ FabArray<FAB>::ParallelCopy_nowait (const FabArray<FAB>& src,
// we're doing plus()s on cell-centered data. Don't do plus()s on
// non-cell-centered data this simplistic way.
//
#ifdef AMREX_USE_OMP
#pragma omp parallel if (Gpu::notInLaunchRegion())
#endif
for (MFIter fai(*this,TilingIfNotGPU()); fai.isValid(); ++fai)
{
const Box& bx = fai.tilebox();

// avoid self copy or plus
if (this != &src) {
auto const sfab = src.array(fai);
auto dfab = this->array(fai);
if (op == FabArrayBase::COPY) {
AMREX_HOST_DEVICE_PARALLEL_FOR_4D ( bx, ncomp, i, j, k, n,
{
dfab(i,j,k,dcomp+n) = sfab(i,j,k,scomp+n);
});
} else {
AMREX_HOST_DEVICE_PARALLEL_FOR_4D ( bx, ncomp, i, j, k, n,
{
dfab(i,j,k,dcomp+n) += sfab(i,j,k,scomp+n);
});
}
if (this != &src) { // avoid self copy or plus
if (op == FabArrayBase::COPY) {
Copy(*this, src, scomp, dcomp, ncomp, IntVect(0));
} else {
Add(*this, src, scomp, dcomp, ncomp, IntVect(0));
}
}

return;
}

Expand Down Expand Up @@ -845,20 +827,7 @@ FabArray<FAB>::Redistribute (const FabArray<FAB>& src,

if (ParallelContext::NProcsSub() == 1)
{
#ifdef AMREX_USE_OMP
#pragma omp parallel if (Gpu::notInLaunchRegion())
#endif
for (MFIter fai(*this,true); fai.isValid(); ++fai)
{
const Box& bx = fai.growntilebox(nghost);
auto const sfab = src.array(fai);
auto dfab = this->array(fai);
AMREX_HOST_DEVICE_PARALLEL_FOR_4D ( bx, ncomp, i, j, k, n,
{
dfab(i,j,k,n+dcomp) = sfab(i,j,k,n+scomp);
});
}

Copy(*this, src, scomp, dcomp, ncomp, nghost);
return;
}

Expand Down
46 changes: 0 additions & 46 deletions Src/Base/AMReX_FabArrayUtility.H
Original file line number Diff line number Diff line change
Expand Up @@ -1102,52 +1102,6 @@ printCell (FabArray<FAB> const& mf, const IntVect& cell, int comp = -1,
}
}


// Add (int ghost-width overload).
// Wraps nghost in an IntVect and forwards all other arguments unchanged to
// the overload of Add that takes `const IntVect& nghost`.
//   dst     - destination FabArray (modified by the forwarded call)
//   src     - source FabArray (read only)
//   srccomp - first source component
//   dstcomp - first destination component
//   numcomp - number of components
//   nghost  - integer used to construct the IntVect ghost width
template <class FAB,
class bar = std::enable_if_t<IsBaseFab<FAB>::value> >
void
Add (FabArray<FAB>& dst, FabArray<FAB> const& src, int srccomp, int dstcomp, int numcomp, int nghost)
{
Add(dst,src,srccomp,dstcomp,numcomp,IntVect(nghost));
}

// Add (IntVect ghost-width overload).
// Performs dst(i,j,k,n+dstcomp) += src(i,j,k,n+srccomp) for n in [0,numcomp).
//   dst     - destination FabArray (modified)
//   src     - source FabArray (read only); indexed with dst's iterator /
//             box number (NOTE(review): presumably must share dst's layout
//             -- confirm against callers)
//   srccomp - first source component
//   dstcomp - first destination component
//   numcomp - number of components
//   nghost  - passed to ParallelFor / growntilebox to size the region
template <class FAB,
class bar = std::enable_if_t<IsBaseFab<FAB>::value> >
void
Add (FabArray<FAB>& dst, FabArray<FAB> const& src, int srccomp, int dstcomp, int numcomp, const IntVect& nghost)
{
#ifdef AMREX_USE_GPU
// GPU build: taken only inside a launch region and when dst reports itself
// as a fusing candidate; otherwise control falls to the loop below.
if (Gpu::inLaunchRegion() && dst.isFusingCandidate()) {
auto const& dstfa = dst.arrays();
auto const& srcfa = src.const_arrays();
// A single ParallelFor call over dst; box_no selects the per-box array.
ParallelFor(dst, nghost, numcomp,
[=] AMREX_GPU_DEVICE (int box_no, int i, int j, int k, int n) noexcept
{
dstfa[box_no](i,j,k,n+dstcomp) += srcfa[box_no](i,j,k,n+srccomp);
});
// Synchronize the stream before returning to the caller.
Gpu::streamSynchronize();
} else
#endif
{
// Per-box path: OpenMP parallel region is enabled only when not in a
// GPU launch region.
#ifdef AMREX_USE_OMP
#pragma omp parallel if (Gpu::notInLaunchRegion())
#endif
for (MFIter mfi(dst,TilingIfNotGPU()); mfi.isValid(); ++mfi)
{
const Box& bx = mfi.growntilebox(nghost);
// Skip boxes for which ok() is false.
if (bx.ok())
{
auto const srcFab = src.array(mfi);
auto dstFab = dst.array(mfi);
AMREX_HOST_DEVICE_PARALLEL_FOR_4D( bx, numcomp, i, j, k, n,
{
dstFab(i,j,k,n+dstcomp) += srcFab(i,j,k,n+srccomp);
});
}
}
}
}

template <class FAB,
class bar = std::enable_if_t<IsBaseFab<FAB>::value> >
void
Expand Down

0 comments on commit 5726371

Please sign in to comment.