
[SYCL][Reduction] Limit reduction work non-uniform case #8458


Merged
sycl/include/sycl/reduction.hpp: 160 changes (81 additions, 79 deletions)
@@ -1093,10 +1093,9 @@ static inline size_t GreatestPowerOfTwo(size_t N) {
return Ret;
}

- template <typename BarrierTy, typename FuncTy>
- void doTreeReductionHelper(size_t WorkSize, size_t LID, BarrierTy Barrier,
-                            FuncTy Func) {
-   Barrier();
+ template <typename FuncTy>
+ void doTreeReductionHelper(size_t WorkSize, size_t LID, FuncTy Func) {
+   workGroupBarrier();

// Initial pivot is the greatest power-of-two value smaller or equal to the
// work size.
@@ -1113,25 +1112,69 @@ void doTreeReductionHelper(size_t WorkSize, size_t LID, BarrierTy Barrier,
if (Pivot != WorkSize) {
if (Pivot + LID < WorkSize)
Func(LID, Pivot + LID);
-   Barrier();
+   workGroupBarrier();
}

// Now the amount of work must be power-of-two, so do the tree reduction.
for (size_t CurPivot = Pivot >> 1; CurPivot > 0; CurPivot >>= 1) {
if (LID < CurPivot)
Func(LID, CurPivot + LID);
-   Barrier();
+   workGroupBarrier();
}
}

- template <typename LocalRedsTy, typename BinOpTy, typename BarrierTy>
- void doTreeReduction(size_t WorkSize, size_t LID, LocalRedsTy &LocalReds,
-                      BinOpTy &BOp, BarrierTy Barrier) {
-   doTreeReductionHelper(WorkSize, LID, Barrier, [&](size_t I, size_t J) {
+ // Enum for specifying work size guarantees in tree-reduction.
+ enum class WorkSizeGuarantees { None, Equal, LessOrEqual };
+
+ template <WorkSizeGuarantees WSGuarantee, int Dim, typename LocalRedsTy,
+           typename BinOpTy, typename AccessFuncTy>
+ void doTreeReduction(size_t WorkSize, nd_item<Dim> NDIt, LocalRedsTy &LocalReds,
+                      BinOpTy &BOp, AccessFuncTy AccessFunc) {
+   size_t LID = NDIt.get_local_linear_id();
+   size_t AdjustedWorkSize;
+   if constexpr (WSGuarantee == WorkSizeGuarantees::LessOrEqual ||
+                 WSGuarantee == WorkSizeGuarantees::Equal) {
+     // If the amount of work is no greater than the number of work-items, we
+     // just load the work into local memory and start reducing. If we know
+     // they are equal, we can let the optimizer remove the check.
+     if (WSGuarantee == WorkSizeGuarantees::Equal || LID < WorkSize)
+       LocalReds[LID] = AccessFunc(LID);
+     AdjustedWorkSize = WorkSize;
+   } else {
+     // Otherwise we have no guarantee and we need to first reduce the amount
+     // of work to fit into local memory.
+     size_t WGSize = NDIt.get_local_range().size();
+     AdjustedWorkSize = std::min(WorkSize, WGSize);
+     if (LID < AdjustedWorkSize) {
+       auto LocalSum = AccessFunc(LID);
+       for (size_t I = LID + WGSize; I < WorkSize; I += WGSize)
+         LocalSum = BOp(LocalSum, AccessFunc(I));
+
+       LocalReds[LID] = LocalSum;
+     }
+   }
+   doTreeReductionHelper(AdjustedWorkSize, LID, [&](size_t I, size_t J) {
LocalReds[I] = BOp(LocalReds[I], LocalReds[J]);
});
}

+ // Tree-reduction over tuples of accessors. This assumes that WorkSize is
+ // less than or equal to the work-group size.
+ // TODO: For variadics/tuples we don't provide such a high-level abstraction as
+ // for the scalar case above. Is there some C++ magic to level them?
+ template <typename... LocalAccT, typename... BOPsT, size_t... Is>
+ void doTreeReductionOnTuple(size_t WorkSize, size_t LID,
+                             ReduTupleT<LocalAccT...> &LocalAccs,
+                             ReduTupleT<BOPsT...> &BOPs,
+                             std::index_sequence<Is...> ReduIndices) {
+   doTreeReductionHelper(WorkSize, LID, [&](size_t I, size_t J) {
+     auto ProcessOne = [=](auto &LocalAcc, auto &BOp) {
+       LocalAcc[I] = BOp(LocalAcc[I], LocalAcc[J]);
+     };
+     (ProcessOne(std::get<Is>(LocalAccs), std::get<Is>(BOPs)), ...);
+   });
+ }

template <> struct NDRangeReduction<reduction::strategy::range_basic> {
template <typename KernelName, int Dims, typename PropertiesT,
typename KernelType, typename Reduction>
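
Note on the helper's shape: doTreeReductionHelper first folds the tail above the greatest power-of-two pivot back into the front of the data, then runs a plain power-of-two tree reduction. Below is a minimal host-side sketch of the same scheme, using nothing SYCL-specific; sequential loops stand in for the work-items, so the barriers become no-ops, and all names are hypothetical.

    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Greatest power-of-two value <= N (mirrors GreatestPowerOfTwo above).
    static std::size_t greatestPowerOfTwo(std::size_t N) {
      std::size_t Ret = 1;
      while ((Ret << 1) <= N)
        Ret <<= 1;
      return Ret;
    }

    // Host model of doTreeReductionHelper: fold the non-power-of-two tail,
    // then halve the pivot until the result sits in Vals[0].
    void treeReduce(std::vector<int> &Vals) {
      std::size_t WorkSize = Vals.size();
      std::size_t Pivot = greatestPowerOfTwo(WorkSize);
      if (Pivot != WorkSize)
        for (std::size_t LID = 0; Pivot + LID < WorkSize; ++LID)
          Vals[LID] += Vals[Pivot + LID];
      for (std::size_t CurPivot = Pivot >> 1; CurPivot > 0; CurPivot >>= 1)
        for (std::size_t LID = 0; LID < CurPivot; ++LID)
          Vals[LID] += Vals[CurPivot + LID];
    }

    int main() {
      std::vector<int> Vals{1, 2, 3, 4, 5, 6}; // WorkSize = 6, Pivot = 4
      treeReduce(Vals);
      std::cout << Vals[0] << "\n"; // 21
    }

With WorkSize = 6 the pivot is 4, items 4 and 5 are folded into items 0 and 1, and the remaining tree runs with pivots 2 and 1.
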
@@ -1170,11 +1213,9 @@ template <> struct NDRangeReduction<reduction::strategy::range_basic> {
size_t LID = NDId.get_local_linear_id();
for (int E = 0; E < NElements; ++E) {

-   // Copy the element to local memory to prepare it for tree-reduction.
-   LocalReds[LID] = getReducerAccess(Reducer).getElement(E);
-
-   doTreeReduction(WGSize, LID, LocalReds, BOp,
-                   [&]() { workGroupBarrier(); });
+   doTreeReduction<WorkSizeGuarantees::Equal>(
+       WGSize, NDId, LocalReds, BOp,
+       [&](size_t) { return getReducerAccess(Reducer).getElement(E); });

if (LID == 0) {
auto V = LocalReds[0];
@@ -1200,14 +1241,9 @@ template <> struct NDRangeReduction<reduction::strategy::range_basic> {
// Reduce each result separately
// TODO: Opportunity to parallelize across elements
for (int E = 0; E < NElements; ++E) {
-   auto LocalSum = Identity;
-   for (size_t I = LID; I < NWorkGroups; I += WGSize)
-     LocalSum = BOp(LocalSum, PartialSums[I * NElements + E]);
-
-   LocalReds[LID] = LocalSum;
-
-   doTreeReduction(WGSize, LID, LocalReds, BOp,
-                   [&]() { workGroupBarrier(); });
+   doTreeReduction<WorkSizeGuarantees::None>(
+       NWorkGroups, NDId, LocalReds, BOp,
+       [&](size_t I) { return PartialSums[I * NElements + E]; });
if (LID == 0) {
auto V = LocalReds[0];
if (IsUpdateOfUserVar)
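
The three WorkSizeGuarantees modes only differ in how local memory is primed before the tree step. Here is a hedged host-side model of that load phase (names and the std::vector stand-in for local memory are hypothetical; a loop over LID stands in for the work-items):

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    enum class WorkSizeGuarantees { None, Equal, LessOrEqual };

    // What each of the WGSize work-items writes to local memory before the
    // tree reduction. Returns the modeled local-memory contents.
    template <WorkSizeGuarantees WSG, typename AccessFuncTy>
    std::vector<long> loadPhase(std::size_t WorkSize, std::size_t WGSize,
                                AccessFuncTy AccessFunc) {
      std::vector<long> LocalReds(WGSize, 0);
      if constexpr (WSG != WorkSizeGuarantees::None) {
        // Equal/LessOrEqual: one input element per item. The bounds check
        // only matters for LessOrEqual; when Equal is guaranteed the
        // optimizer can drop it.
        for (std::size_t LID = 0; LID < WGSize; ++LID)
          if (WSG == WorkSizeGuarantees::Equal || LID < WorkSize)
            LocalReds[LID] = AccessFunc(LID);
      } else {
        // None: each item strides over the input until the whole work size
        // has been pre-accumulated into at most WGSize partial values.
        for (std::size_t LID = 0; LID < std::min(WorkSize, WGSize); ++LID) {
          long Sum = AccessFunc(LID);
          for (std::size_t I = LID + WGSize; I < WorkSize; I += WGSize)
            Sum += AccessFunc(I);
          LocalReds[LID] = Sum;
        }
      }
      return LocalReds;
    }

For the PartialSums call above (guarantee None), with NWorkGroups = 1000 and WGSize = 256, item 0 pre-accumulates partial sums 0, 256, 512, and 768 before the tree reduction starts.
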
@@ -1288,12 +1324,10 @@ struct NDRangeReduction<
// This prevents local memory from scaling with elements
for (int E = 0; E < NElements; ++E) {

-   // Copy the element to local memory to prepare it for tree-reduction.
-   LocalReds[LID] = getReducerAccess(Reducer).getElement(E);
-
    typename Reduction::binary_operation BOp;
-   doTreeReduction(WGSize, LID, LocalReds, BOp,
-                   [&]() { NDIt.barrier(); });
+   doTreeReduction<WorkSizeGuarantees::Equal>(
+       WGSize, NDIt, LocalReds, BOp,
+       [&](size_t) { return getReducerAccess(Reducer).getElement(E); });

if (LID == 0)
getReducerAccess(Reducer).getElement(E) = LocalReds[0];
@@ -1494,10 +1528,9 @@ template <> struct NDRangeReduction<reduction::strategy::basic> {
// This prevents local memory from scaling with elements
for (int E = 0; E < NElements; ++E) {

-   // Copy the element to local memory to prepare it for tree-reduction.
-   LocalReds[LID] = getReducerAccess(Reducer).getElement(E);
-
-   doTreeReduction(WGSize, LID, LocalReds, BOp, [&]() { NDIt.barrier(); });
+   doTreeReduction<WorkSizeGuarantees::Equal>(
+       WGSize, NDIt, LocalReds, BOp,
+       [&](size_t) { return getReducerAccess(Reducer).getElement(E); });

// Compute the partial sum/reduction for the work-group.
if (LID == 0) {
@@ -1569,20 +1602,19 @@ template <> struct NDRangeReduction<reduction::strategy::basic> {
size_t WGSize = NDIt.get_local_range().size();
size_t LID = NDIt.get_local_linear_id();
size_t GID = NDIt.get_global_linear_id();
+   size_t GrID = NDIt.get_group_linear_id();

for (int E = 0; E < NElements; ++E) {
-     // Copy the element to local memory to prepare it for
-     // tree-reduction.
-     LocalReds[LID] = (UniformPow2WG || GID < NWorkItems)
-                          ? In[GID * NElements + E]
-                          : ReduIdentity;
+     // The last work-group may not have enough work for all its items.
+     size_t RemainingWorkSize =
+         sycl::min(WGSize, NWorkItems - GrID * WGSize);

-     doTreeReduction(WGSize, LID, LocalReds, BOp,
-                     [&]() { NDIt.barrier(); });
+     doTreeReduction<WorkSizeGuarantees::LessOrEqual>(
+         RemainingWorkSize, NDIt, LocalReds, BOp,
+         [&](size_t) { return In[GID * NElements + E]; });

// Compute the partial sum/reduction for the work-group.
if (LID == 0) {
-       size_t GrID = NDIt.get_group_linear_id();
typename Reduction::result_type PSum = LocalReds[0];
if (IsUpdateOfUserVar)
PSum = BOp(Out[0], PSum);
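
A quick check of the RemainingWorkSize arithmetic, with hypothetical sizes:

    #include <algorithm>
    #include <cstddef>
    #include <iostream>

    int main() {
      std::size_t NWorkItems = 1000, WGSize = 256;
      std::size_t NGroups = (NWorkItems + WGSize - 1) / WGSize; // 4 groups
      // Mirrors RemainingWorkSize = sycl::min(WGSize, NWorkItems - GrID * WGSize).
      for (std::size_t GrID = 0; GrID < NGroups; ++GrID)
        std::cout << "group " << GrID << ": "
                  << std::min(WGSize, NWorkItems - GrID * WGSize) << " items\n";
      // Prints 256, 256, 256, 232: only the last group falls short, which is
      // exactly the LessOrEqual guarantee the new call site relies on.
    }
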
@@ -1621,29 +1653,6 @@ auto createReduOutAccs(size_t NWorkGroups, handler &CGH,
CGH)...);
}

- template <typename... LocalAccT, typename... BOPsT, size_t... Is>
- void reduceReduLocalAccs(size_t IndexA, size_t IndexB,
-                          ReduTupleT<LocalAccT...> LocalAccs,
-                          ReduTupleT<BOPsT...> BOPs,
-                          std::index_sequence<Is...>) {
-   auto ProcessOne = [=](auto &LocalAcc, auto &BOp) {
-     LocalAcc[IndexA] = BOp(LocalAcc[IndexA], LocalAcc[IndexB]);
-   };
-   (ProcessOne(std::get<Is>(LocalAccs), std::get<Is>(BOPs)), ...);
- }
-
- template <typename... LocalAccT, typename... BOPsT, size_t... Is,
-           typename BarrierTy>
- void doTreeReduction(size_t WorkSize, size_t LID,
-                      ReduTupleT<LocalAccT...> &LocalAccs,
-                      ReduTupleT<BOPsT...> &BOPs,
-                      std::index_sequence<Is...> ReduIndices,
-                      BarrierTy Barrier) {
-   doTreeReductionHelper(WorkSize, LID, Barrier, [&](size_t I, size_t J) {
-     reduceReduLocalAccs(I, J, LocalAccs, BOPs, ReduIndices);
-   });
- }

template <typename... Reductions, typename... OutAccT, typename... LocalAccT,
typename... BOPsT, typename... Ts, size_t... Is>
void writeReduSumsToOutAccs(
@@ -1766,8 +1775,7 @@ void reduCGFuncImplScalar(
getReducerAccess(std::get<Is>(ReducersTuple)).getElement(0)),
...);

-   doTreeReduction(WGSize, LID, LocalAccsTuple, BOPsTuple, ReduIndices,
-                   [&]() { NDIt.barrier(); });
+   doTreeReductionOnTuple(WGSize, LID, LocalAccsTuple, BOPsTuple, ReduIndices);

// Compute the partial sum/reduction for the work-group.
if (LID == 0) {
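
The tuple variant expands one combine per reduction via a fold expression over the index sequence. Here is a standalone illustration of that mechanism, with std::array standing in for the local accessors (hypothetical types and values, not the library's own code):

    #include <array>
    #include <cstddef>
    #include <functional>
    #include <iostream>
    #include <tuple>
    #include <utility>

    // Apply each reduction's binary op to its own "local accessor" at (I, J),
    // expanding over all reductions with a fold expression.
    template <typename AccTuple, typename OpTuple, std::size_t... Is>
    void combineAt(std::size_t I, std::size_t J, AccTuple &Accs, OpTuple &Ops,
                   std::index_sequence<Is...>) {
      auto ProcessOne = [&](auto &Acc, auto &Op) { Acc[I] = Op(Acc[I], Acc[J]); };
      (ProcessOne(std::get<Is>(Accs), std::get<Is>(Ops)), ...);
    }

    int main() {
      std::tuple Accs{std::array<int, 4>{1, 2, 3, 4},
                      std::array<int, 4>{5, 6, 7, 8}};
      std::tuple Ops{std::plus<int>{}, [](int A, int B) { return A * B; }};
      combineAt(0, 1, Accs, Ops, std::make_index_sequence<2>{});
      std::cout << std::get<0>(Accs)[0] << ' ' << std::get<1>(Accs)[0] << '\n';
      // Prints "3 30": a sum for the first reduction, a product for the second.
    }
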
@@ -1792,11 +1800,9 @@ void reduCGFuncImplArrayHelper(bool IsOneWG, nd_item<Dims> NDIt,
// This prevents local memory from scaling with elements
auto NElements = Reduction::num_elements;
for (size_t E = 0; E < NElements; ++E) {

-   // Copy the element to local memory to prepare it for tree-reduction.
-   LocalReds[LID] = getReducerAccess(Reducer).getElement(E);
-
-   doTreeReduction(WGSize, LID, LocalReds, BOp, [&]() { NDIt.barrier(); });
+   doTreeReduction<WorkSizeGuarantees::Equal>(
+       WGSize, NDIt, LocalReds, BOp,
+       [&](size_t) { return getReducerAccess(Reducer).getElement(E); });

// Add the initial value of user's variable to the final result.
if (LID == 0) {
@@ -1940,8 +1946,8 @@ void reduAuxCGFuncImplScalar(
if (LID < RemainingWorkSize)
((std::get<Is>(LocalAccsTuple)[LID] = std::get<Is>(InAccsTuple)[GID]), ...);

-   doTreeReduction(RemainingWorkSize, LID, LocalAccsTuple, BOPsTuple,
-                   ReduIndices, [&]() { NDIt.barrier(); });
+   doTreeReductionOnTuple(RemainingWorkSize, LID, LocalAccsTuple, BOPsTuple,
+                          ReduIndices);

// Compute the partial sum/reduction for the work-group.
if (LID == 0) {
@@ -1964,13 +1970,9 @@ void reduAuxCGFuncImplArrayHelper(bool UniformPow2WG, bool IsOneWG,
// This prevents local memory from scaling with elements
auto NElements = Reduction::num_elements;
for (size_t E = 0; E < NElements; ++E) {
-   // The end work-group may have less work than the rest, so we only need to
-   // read the value of the elements that still have work left.
-   if (LID < RemainingWorkSize)
-     LocalReds[LID] = In[GID * NElements + E];
-
-   doTreeReduction(RemainingWorkSize, LID, LocalReds, BOp,
-                   [&]() { NDIt.barrier(); });
+   doTreeReduction<WorkSizeGuarantees::LessOrEqual>(
+       RemainingWorkSize, NDIt, LocalReds, BOp,
+       [&](size_t) { return In[GID * NElements + E]; });

// Add the initial value of user's variable to the final result.
if (LID == 0) {