Skip to content

Commit

Permalink
AMReX_FLATTEN_FOR (AMReX-Codes#3855)
Browse files Browse the repository at this point in the history
A new build option to force flattening of ParallelFor and similar
functions for host code. It is enabled by default, except in debug builds.
The CMake option is AMReX_FLATTEN_FOR, whereas for GNU make, it's
USE_FLATTEN_FOR.
  • Loading branch information
WeiqunZhang authored Mar 28, 2024
1 parent 21a7a66 commit 8d57ebc
Show file tree
Hide file tree
Showing 10 changed files with 61 additions and 0 deletions.
1 change: 1 addition & 0 deletions .github/workflows/gcc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ jobs:
cmake .. \
-DCMAKE_BUILD_TYPE=Release \
-DCMAKE_VERBOSE_MAKEFILE=ON \
-DAMReX_FLATTEN_FOR=OFF \
-DAMReX_ASSERTIONS=ON \
-DAMReX_TESTING=ON \
-DAMReX_BOUND_CHECK=ON \
Expand Down
3 changes: 3 additions & 0 deletions Docs/sphinx_documentation/source/BuildingAMReX.rst
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,9 @@ The list of available options is reported in the :ref:`table <tab:cmakevar>` bel
+------------------------------+-------------------------------------------------+-------------------------+-----------------------+
| AMReX_PROBINIT | Enable support for probin file | Platform dependent | YES, NO |
+------------------------------+-------------------------------------------------+-------------------------+-----------------------+
| AMReX_FLATTEN_FOR | Enable flattening of ParallelFor and similar | YES unless for Debug | YES, NO |
| | functions for host code | build | |
+------------------------------+-------------------------------------------------+-------------------------+-----------------------+
.. raw:: latex

\end{center}
Expand Down
6 changes: 6 additions & 0 deletions Src/Base/AMReX_Extension.H
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,12 @@
#define AMREX_FLATTEN
#endif

// AMREX_ATTRIBUTE_FLATTEN_FOR: attribute macro intended for loop-launching
// functions such as For, ParallelFor and Loop. It expands to AMREX_FLATTEN
// when the build defines AMREX_USE_FLATTEN_FOR, and to nothing otherwise.
#if !defined(AMREX_USE_FLATTEN_FOR)
#  define AMREX_ATTRIBUTE_FLATTEN_FOR
#else
#  define AMREX_ATTRIBUTE_FLATTEN_FOR AMREX_FLATTEN
#endif

// unroll loop
#define AMREX_TO_STRING_HELPER(X) #X
#define AMREX_TO_STRING(X) AMREX_TO_STRING_HELPER(X)
Expand Down
9 changes: 9 additions & 0 deletions Src/Base/AMReX_GpuLaunchFunctsC.H
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ void launch (T const& n, L&& f) noexcept
}

template <typename T, typename L, typename M=std::enable_if_t<std::is_integral_v<T>> >
AMREX_ATTRIBUTE_FLATTEN_FOR
void For (T n, L const& f) noexcept
{
for (T i = 0; i < n; ++i) {
Expand Down Expand Up @@ -96,6 +97,7 @@ void For (Gpu::KernelInfo const&, T n, L&& f) noexcept
}

template <typename T, typename L, typename M=std::enable_if_t<std::is_integral_v<T>> >
AMREX_ATTRIBUTE_FLATTEN_FOR
void ParallelFor (T n, L const& f) noexcept
{
AMREX_PRAGMA_SIMD
Expand Down Expand Up @@ -125,6 +127,7 @@ void ParallelFor (Gpu::KernelInfo const&, T n, L&& f) noexcept
}

template <typename L>
AMREX_ATTRIBUTE_FLATTEN_FOR
void For (Box const& box, L const& f) noexcept
{
const auto lo = amrex::lbound(box);
Expand Down Expand Up @@ -157,6 +160,7 @@ void For (Gpu::KernelInfo const&, Box const& box, L&& f) noexcept
}

template <typename L>
AMREX_ATTRIBUTE_FLATTEN_FOR
void ParallelFor (Box const& box, L const& f) noexcept
{
const auto lo = amrex::lbound(box);
Expand Down Expand Up @@ -190,6 +194,7 @@ void ParallelFor (Gpu::KernelInfo const&, Box const& box, L&& f) noexcept
}

template <typename T, typename L, typename M=std::enable_if_t<std::is_integral_v<T>> >
AMREX_ATTRIBUTE_FLATTEN_FOR
void For (Box const& box, T ncomp, L const& f) noexcept
{
const auto lo = amrex::lbound(box);
Expand Down Expand Up @@ -224,6 +229,7 @@ void For (Gpu::KernelInfo const&, Box const& box, T ncomp, L&& f) noexcept
}

template <typename T, typename L, typename M=std::enable_if_t<std::is_integral_v<T>> >
AMREX_ATTRIBUTE_FLATTEN_FOR
void ParallelFor (Box const& box, T ncomp, L const& f) noexcept
{
const auto lo = amrex::lbound(box);
Expand Down Expand Up @@ -1037,6 +1043,7 @@ void HostDeviceFor (Gpu::KernelInfo const&,
}

template <typename T, typename L, typename M=std::enable_if_t<std::is_integral_v<T>> >
AMREX_ATTRIBUTE_FLATTEN_FOR
void ParallelForRNG (T n, L const& f) noexcept
{
for (T i = 0; i < n; ++i) {
Expand All @@ -1045,6 +1052,7 @@ void ParallelForRNG (T n, L const& f) noexcept
}

template <typename L>
AMREX_ATTRIBUTE_FLATTEN_FOR
void ParallelForRNG (Box const& box, L const& f) noexcept
{
const auto lo = amrex::lbound(box);
Expand All @@ -1057,6 +1065,7 @@ void ParallelForRNG (Box const& box, L const& f) noexcept
}

template <typename T, typename L, typename M=std::enable_if_t<std::is_integral_v<T>> >
AMREX_ATTRIBUTE_FLATTEN_FOR
void ParallelForRNG (Box const& box, T ncomp, L const& f) noexcept
{
const auto lo = amrex::lbound(box);
Expand Down
17 changes: 17 additions & 0 deletions Src/Base/AMReX_Loop.H
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
#include <AMReX_Config.H>

#include <AMReX_Box.H>
#include <AMReX_Extension.H>

namespace amrex {

template <class F>
AMREX_GPU_HOST_DEVICE
AMREX_ATTRIBUTE_FLATTEN_FOR
void Loop (Dim3 lo, Dim3 hi, F const& f) noexcept
{
for (int k = lo.z; k <= hi.z; ++k) {
Expand All @@ -19,6 +21,7 @@ void Loop (Dim3 lo, Dim3 hi, F const& f) noexcept

template <class F>
AMREX_GPU_HOST_DEVICE
AMREX_ATTRIBUTE_FLATTEN_FOR
void Loop (Dim3 lo, Dim3 hi, int ncomp, F const& f) noexcept
{
for (int n = 0; n < ncomp; ++n) {
Expand All @@ -31,6 +34,7 @@ void Loop (Dim3 lo, Dim3 hi, int ncomp, F const& f) noexcept

template <class F>
AMREX_GPU_HOST_DEVICE
AMREX_ATTRIBUTE_FLATTEN_FOR
void LoopConcurrent (Dim3 lo, Dim3 hi, F const& f) noexcept
{
for (int k = lo.z; k <= hi.z; ++k) {
Expand All @@ -43,6 +47,7 @@ void LoopConcurrent (Dim3 lo, Dim3 hi, F const& f) noexcept

template <class F>
AMREX_GPU_HOST_DEVICE
AMREX_ATTRIBUTE_FLATTEN_FOR
void LoopConcurrent (Dim3 lo, Dim3 hi, int ncomp, F const& f) noexcept
{
for (int n = 0; n < ncomp; ++n) {
Expand All @@ -56,6 +61,7 @@ void LoopConcurrent (Dim3 lo, Dim3 hi, int ncomp, F const& f) noexcept

template <class F>
AMREX_GPU_HOST_DEVICE
AMREX_ATTRIBUTE_FLATTEN_FOR
void Loop (Box const& bx, F const& f) noexcept
{
const auto lo = amrex::lbound(bx);
Expand All @@ -69,6 +75,7 @@ void Loop (Box const& bx, F const& f) noexcept

template <class F>
AMREX_GPU_HOST_DEVICE
AMREX_ATTRIBUTE_FLATTEN_FOR
void Loop (Box const& bx, int ncomp, F const& f) noexcept
{
const auto lo = amrex::lbound(bx);
Expand All @@ -83,6 +90,7 @@ void Loop (Box const& bx, int ncomp, F const& f) noexcept

template <class F>
AMREX_GPU_HOST_DEVICE
AMREX_ATTRIBUTE_FLATTEN_FOR
void LoopConcurrent (Box const& bx, F const& f) noexcept
{
const auto lo = amrex::lbound(bx);
Expand All @@ -97,6 +105,7 @@ void LoopConcurrent (Box const& bx, F const& f) noexcept

template <class F>
AMREX_GPU_HOST_DEVICE
AMREX_ATTRIBUTE_FLATTEN_FOR
void LoopConcurrent (Box const& bx, int ncomp, F const& f) noexcept
{
const auto lo = amrex::lbound(bx);
Expand All @@ -116,6 +125,7 @@ void LoopConcurrent (Box const& bx, int ncomp, F const& f) noexcept
// of the warning, we have to use the functions below for those situations.

template <class F>
AMREX_ATTRIBUTE_FLATTEN_FOR
void LoopOnCpu (Dim3 lo, Dim3 hi, F const& f) noexcept
{
for (int k = lo.z; k <= hi.z; ++k) {
Expand All @@ -126,6 +136,7 @@ void LoopOnCpu (Dim3 lo, Dim3 hi, F const& f) noexcept
}

template <class F>
AMREX_ATTRIBUTE_FLATTEN_FOR
void LoopOnCpu (Dim3 lo, Dim3 hi, int ncomp, F const& f) noexcept
{
for (int n = 0; n < ncomp; ++n) {
Expand All @@ -137,6 +148,7 @@ void LoopOnCpu (Dim3 lo, Dim3 hi, int ncomp, F const& f) noexcept
}

template <class F>
AMREX_ATTRIBUTE_FLATTEN_FOR
void LoopConcurrentOnCpu (Dim3 lo, Dim3 hi, F const& f) noexcept
{
for (int k = lo.z; k <= hi.z; ++k) {
Expand All @@ -148,6 +160,7 @@ void LoopConcurrentOnCpu (Dim3 lo, Dim3 hi, F const& f) noexcept
}

template <class F>
AMREX_ATTRIBUTE_FLATTEN_FOR
void LoopConcurrentOnCpu (Dim3 lo, Dim3 hi, int ncomp, F const& f) noexcept
{
for (int n = 0; n < ncomp; ++n) {
Expand All @@ -160,6 +173,7 @@ void LoopConcurrentOnCpu (Dim3 lo, Dim3 hi, int ncomp, F const& f) noexcept
}

template <class F>
AMREX_ATTRIBUTE_FLATTEN_FOR
void LoopOnCpu (Box const& bx, F const& f) noexcept
{
const auto lo = amrex::lbound(bx);
Expand All @@ -172,6 +186,7 @@ void LoopOnCpu (Box const& bx, F const& f) noexcept
}

template <class F>
AMREX_ATTRIBUTE_FLATTEN_FOR
void LoopOnCpu (Box const& bx, int ncomp, F const& f) noexcept
{
const auto lo = amrex::lbound(bx);
Expand All @@ -185,6 +200,7 @@ void LoopOnCpu (Box const& bx, int ncomp, F const& f) noexcept
}

template <class F>
AMREX_ATTRIBUTE_FLATTEN_FOR
void LoopConcurrentOnCpu (Box const& bx, F const& f) noexcept
{
const auto lo = amrex::lbound(bx);
Expand All @@ -198,6 +214,7 @@ void LoopConcurrentOnCpu (Box const& bx, F const& f) noexcept
}

template <class F>
AMREX_ATTRIBUTE_FLATTEN_FOR
void LoopConcurrentOnCpu (Box const& bx, int ncomp, F const& f) noexcept
{
const auto lo = amrex::lbound(bx);
Expand Down
2 changes: 2 additions & 0 deletions Tools/CMake/AMReXConfig.cmake.in
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ set(AMReX_HDF5_ZFP_FOUND @AMReX_HDF5_ZFP@)
set(AMReX_FPE_FOUND @AMReX_FPE@)
set(AMReX_PIC_FOUND @AMReX_PIC@)
set(AMReX_ASSERTIONS_FOUND @AMReX_ASSERTIONS@)
set(AMReX_FLATTEN_FOR_FOUND @AMReX_FLATTEN_FOR@)

# Profiling options
set(AMReX_BASEP_FOUND @AMReX_BASE_PROFILE@)
Expand Down Expand Up @@ -141,6 +142,7 @@ set(AMReX_HDF5_ZFP @AMReX_HDF5_ZFP@)
set(AMReX_FPE @AMReX_FPE@)
set(AMReX_PIC @AMReX_PIC@)
set(AMReX_ASSERTIONS @AMReX_ASSERTIONS@)
set(AMReX_FLATTEN_FOR @AMReX_FLATTEN_FOR@)

# Profiling options
set(AMReX_BASE_PROFILE @AMReX_BASE_PROFILE@)
Expand Down
7 changes: 7 additions & 0 deletions Tools/CMake/AMReXOptions.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,13 @@ endif ()

print_option( AMReX_ASSERTIONS )

# AMReX_FLATTEN_FOR: flatten ParallelFor-style functions.
# The default is ON, except when CMAKE_BUILD_TYPE matches "Debug" (then OFF).
if ( NOT "${CMAKE_BUILD_TYPE}" MATCHES "Debug" )
option( AMReX_FLATTEN_FOR "Enable flattening of ParallelFor and other similar functions" ON)
else ()
option( AMReX_FLATTEN_FOR "Enable flattening of ParallelFor and other similar functions" OFF)
endif ()
print_option( AMReX_FLATTEN_FOR )

option(AMReX_BOUND_CHECK "Enable bound checking in Array4 class" OFF)
print_option( AMReX_BOUND_CHECK )

Expand Down
3 changes: 3 additions & 0 deletions Tools/CMake/AMReXSetDefines.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ add_amrex_define( AMREX_${CMAKE_SYSTEM_NAME} )
# Assertions
add_amrex_define( AMREX_USE_ASSERTION NO_LEGACY IF AMReX_ASSERTIONS )

# Flatten
add_amrex_define( AMREX_USE_FLATTEN_FOR NO_LEGACY IF AMReX_FLATTEN_FOR )

# Bound checking
add_amrex_define( AMREX_BOUND_CHECK NO_LEGACY IF AMReX_BOUND_CHECK )

Expand Down
1 change: 1 addition & 0 deletions Tools/CMake/AMReX_Config_ND.H.in
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#cmakedefine BL_SPACEDIM @D@
#endif
#cmakedefine AMREX_USE_ASSERTION
#cmakedefine AMREX_USE_FLATTEN_FOR
#cmakedefine AMREX_BOUND_CHECK
#cmakedefine AMREX_EXPORT_DYNAMIC
#cmakedefine BL_FORT_USE_UNDERSCORE
Expand Down
12 changes: 12 additions & 0 deletions Tools/GNUMake/Make.defs
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,12 @@ else
USE_FORCE_INLINE := FALSE
endif

# USE_FLATTEN_FOR: GNU make switch for flattening of ParallelFor-style
# functions. A non-empty user-supplied value is kept, with surrounding
# whitespace stripped; when the variable is unset or empty it defaults to TRUE.
ifdef USE_FLATTEN_FOR
USE_FLATTEN_FOR := $(strip $(USE_FLATTEN_FOR))
else
USE_FLATTEN_FOR := TRUE
endif

ifdef WARN_ALL
WARN_ALL := $(strip $(WARN_ALL))
else
Expand Down Expand Up @@ -714,6 +720,12 @@ ifeq ($(USE_FORCE_INLINE),TRUE)
CPPFLAGS += -DAMREX_USE_FORCE_INLINE
endif

# Add -DAMREX_USE_FLATTEN_FOR only when USE_FLATTEN_FOR is TRUE and DEBUG is
# not TRUE. A DEBUG=TRUE build never gets the define, even if USE_FLATTEN_FOR
# is TRUE.
ifeq ($(USE_FLATTEN_FOR),TRUE)
ifneq ($(DEBUG),TRUE)
CPPFLAGS += -DAMREX_USE_FLATTEN_FOR
endif
endif

ifeq ($(USE_ACC),TRUE)

USE_GPU := TRUE
Expand Down

0 comments on commit 8d57ebc

Please sign in to comment.