Skip to content

Commit

Permalink
TagParallelFor: Add assertion against integer overflow (AMReX-Codes#3790)
Browse files Browse the repository at this point in the history
  • Loading branch information
WeiqunZhang authored Mar 7, 2024
1 parent 8f4127a commit 944d4b4
Showing 1 changed file with 17 additions and 11 deletions.
28 changes: 17 additions & 11 deletions Src/Base/AMReX_TagParallelFor.H
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <AMReX_Box.H>
#include <AMReX_GpuLaunch.H>
#include <AMReX_Vector.H>
#include <limits>
#include <utility>

namespace amrex {
Expand Down Expand Up @@ -74,31 +75,31 @@ struct Array4BoxValTag {
template <class T>
struct VectorTag {
T* p;
int m_size;
Long m_size;

[[nodiscard]] AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
int size () const noexcept { return m_size; }
Long size () const noexcept { return m_size; }
};

#ifdef AMREX_USE_GPU

namespace detail {

template <typename T>
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
std::enable_if_t<std::is_same<std::decay_t<decltype(std::declval<T>().box())>, Box>::value,
int>
Long>
get_tag_size (T const& tag) noexcept
{
AMREX_ASSERT(tag.box().numPts() < Long(std::numeric_limits<int>::max()));
return static_cast<int>(tag.box().numPts());
}

template <typename T>
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
std::enable_if_t<std::is_integral<std::decay_t<decltype(std::declval<T>().size())> >::value,
int>
Long>
get_tag_size (T const& tag) noexcept
{
AMREX_ASSERT(tag.size() < Long(std::numeric_limits<int>::max()));
return tag.size();
}

Expand Down Expand Up @@ -151,15 +152,17 @@ ParallelFor_doit (Vector<TagType> const& tags, F && f)
const int ntags = tags.size();
if (ntags == 0) { return; }

Long l_ntotwarps = 0;
int ntotwarps = 0;
Vector<int> nwarps;
nwarps.reserve(ntags+1);
for (int i = 0; i < ntags; ++i)
{
auto& tag = tags[i];
nwarps.push_back(ntotwarps);
ntotwarps += static_cast<int>((get_tag_size(tag)
+ Gpu::Device::warp_size-1) / Gpu::Device::warp_size);
auto nw = (get_tag_size(tag) + Gpu::Device::warp_size-1) / Gpu::Device::warp_size;
l_ntotwarps += nw;
ntotwarps += static_cast<int>(nw);
}
nwarps.push_back(ntotwarps);

Expand All @@ -182,6 +185,9 @@ ParallelFor_doit (Vector<TagType> const& tags, F && f)
constexpr int nwarps_per_block = nthreads/Gpu::Device::warp_size;
int nblocks = (ntotwarps + nwarps_per_block-1) / nwarps_per_block;

amrex::ignore_unused(l_ntotwarps);
AMREX_ASSERT(l_ntotwarps+nwarps_per_block-1 < Long(std::numeric_limits<int>::max()));

amrex::launch(nblocks, nthreads, Gpu::gpuStream(),
#ifdef AMREX_USE_SYCL
[=] AMREX_GPU_DEVICE (sycl::nd_item<1> const& item) noexcept
Expand All @@ -192,11 +198,11 @@ ParallelFor_doit (Vector<TagType> const& tags, F && f)
#endif
{
#ifdef AMREX_USE_SYCL
int g_tid = item.get_global_id(0);
std::size_t g_tid = item.get_global_id(0);
#else
int g_tid = blockDim.x*blockIdx.x + threadIdx.x;
auto g_tid = std::size_t(blockDim.x)*blockIdx.x + threadIdx.x;
#endif
int g_wid = g_tid / Gpu::Device::warp_size;
auto g_wid = int(g_tid / Gpu::Device::warp_size);
if (g_wid >= ntotwarps) { return; }

int tag_id = amrex::bisect(d_nwarps, 0, ntags, g_wid);
Expand Down

0 comments on commit 944d4b4

Please sign in to comment.