Skip to content

Commit

Permalink
[Offload][OMPX] Add the runtime support for multi-dim grid and block (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
shiltian authored Dec 6, 2024
1 parent 2f4eac6 commit 92376c3
Show file tree
Hide file tree
Showing 13 changed files with 168 additions and 95 deletions.
69 changes: 38 additions & 31 deletions offload/plugins-nextgen/amdgpu/src/rtl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -559,15 +559,15 @@ struct AMDGPUKernelTy : public GenericKernelTy {
}

/// Launch the AMDGPU kernel function.
Error launchImpl(GenericDeviceTy &GenericDevice, uint32_t NumThreads,
uint64_t NumBlocks, KernelArgsTy &KernelArgs,
Error launchImpl(GenericDeviceTy &GenericDevice, uint32_t NumThreads[3],
uint32_t NumBlocks[3], KernelArgsTy &KernelArgs,
KernelLaunchParamsTy LaunchParams,
AsyncInfoWrapperTy &AsyncInfoWrapper) const override;

/// Print more elaborate kernel launch info for AMDGPU
Error printLaunchInfoDetails(GenericDeviceTy &GenericDevice,
KernelArgsTy &KernelArgs, uint32_t NumThreads,
uint64_t NumBlocks) const override;
KernelArgsTy &KernelArgs, uint32_t NumThreads[3],
uint32_t NumBlocks[3]) const override;

/// Get group and private segment kernel size.
uint32_t getGroupSize() const { return GroupSize; }
Expand Down Expand Up @@ -719,7 +719,7 @@ struct AMDGPUQueueTy {
/// Push a kernel launch to the queue. The kernel launch requires an output
/// signal and can define an optional input signal (nullptr if none).
Error pushKernelLaunch(const AMDGPUKernelTy &Kernel, void *KernelArgs,
uint32_t NumThreads, uint64_t NumBlocks,
uint32_t NumThreads[3], uint32_t NumBlocks[3],
uint32_t GroupSize, uint64_t StackSize,
AMDGPUSignalTy *OutputSignal,
AMDGPUSignalTy *InputSignal) {
Expand All @@ -746,14 +746,18 @@ struct AMDGPUQueueTy {
assert(Packet && "Invalid packet");

// The first 32 bits of the packet are written after the other fields
uint16_t Setup = UINT16_C(1) << HSA_KERNEL_DISPATCH_PACKET_SETUP_DIMENSIONS;
Packet->workgroup_size_x = NumThreads;
Packet->workgroup_size_y = 1;
Packet->workgroup_size_z = 1;
uint16_t Dims = NumBlocks[2] * NumThreads[2] > 1
? 3
: 1 + (NumBlocks[1] * NumThreads[1] != 1);
uint16_t Setup = UINT16_C(Dims)
<< HSA_KERNEL_DISPATCH_PACKET_SETUP_DIMENSIONS;
Packet->workgroup_size_x = NumThreads[0];
Packet->workgroup_size_y = NumThreads[1];
Packet->workgroup_size_z = NumThreads[2];
Packet->reserved0 = 0;
Packet->grid_size_x = NumBlocks * NumThreads;
Packet->grid_size_y = 1;
Packet->grid_size_z = 1;
Packet->grid_size_x = NumBlocks[0] * NumThreads[0];
Packet->grid_size_y = NumBlocks[1] * NumThreads[1];
Packet->grid_size_z = NumBlocks[2] * NumThreads[2];
Packet->private_segment_size =
Kernel.usesDynamicStack() ? StackSize : Kernel.getPrivateSize();
Packet->group_segment_size = GroupSize;
Expand Down Expand Up @@ -1240,7 +1244,7 @@ struct AMDGPUStreamTy {
/// the kernel finalizes. Once the kernel is finished, the stream will release
/// the kernel args buffer to the specified memory manager.
Error pushKernelLaunch(const AMDGPUKernelTy &Kernel, void *KernelArgs,
uint32_t NumThreads, uint64_t NumBlocks,
uint32_t NumThreads[3], uint32_t NumBlocks[3],
uint32_t GroupSize, uint64_t StackSize,
AMDGPUMemoryManagerTy &MemoryManager) {
if (Queue == nullptr)
Expand Down Expand Up @@ -2827,10 +2831,10 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
AsyncInfoWrapperTy AsyncInfoWrapper(*this, nullptr);

KernelArgsTy KernelArgs = {};
if (auto Err =
AMDGPUKernel.launchImpl(*this, /*NumThread=*/1u,
/*NumBlocks=*/1ul, KernelArgs,
KernelLaunchParamsTy{}, AsyncInfoWrapper))
uint32_t NumBlocksAndThreads[3] = {1u, 1u, 1u};
if (auto Err = AMDGPUKernel.launchImpl(
*this, NumBlocksAndThreads, NumBlocksAndThreads, KernelArgs,
KernelLaunchParamsTy{}, AsyncInfoWrapper))
return Err;

Error Err = Plugin::success();
Expand Down Expand Up @@ -3328,7 +3332,7 @@ struct AMDGPUPluginTy final : public GenericPluginTy {
};

Error AMDGPUKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
uint32_t NumThreads, uint64_t NumBlocks,
uint32_t NumThreads[3], uint32_t NumBlocks[3],
KernelArgsTy &KernelArgs,
KernelLaunchParamsTy LaunchParams,
AsyncInfoWrapperTy &AsyncInfoWrapper) const {
Expand Down Expand Up @@ -3385,13 +3389,15 @@ Error AMDGPUKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
// Only COV5 implicitargs needs to be set. COV4 implicitargs are not used.
if (ImplArgs &&
getImplicitArgsSize() == sizeof(hsa_utils::AMDGPUImplicitArgsTy)) {
ImplArgs->BlockCountX = NumBlocks;
ImplArgs->BlockCountY = 1;
ImplArgs->BlockCountZ = 1;
ImplArgs->GroupSizeX = NumThreads;
ImplArgs->GroupSizeY = 1;
ImplArgs->GroupSizeZ = 1;
ImplArgs->GridDims = 1;
ImplArgs->BlockCountX = NumBlocks[0];
ImplArgs->BlockCountY = NumBlocks[1];
ImplArgs->BlockCountZ = NumBlocks[2];
ImplArgs->GroupSizeX = NumThreads[0];
ImplArgs->GroupSizeY = NumThreads[1];
ImplArgs->GroupSizeZ = NumThreads[2];
ImplArgs->GridDims = NumBlocks[2] * NumThreads[2] > 1
? 3
: 1 + (NumBlocks[1] * NumThreads[1] != 1);
ImplArgs->DynamicLdsSize = KernelArgs.DynCGroupMem;
}

Expand All @@ -3402,8 +3408,8 @@ Error AMDGPUKernelTy::launchImpl(GenericDeviceTy &GenericDevice,

Error AMDGPUKernelTy::printLaunchInfoDetails(GenericDeviceTy &GenericDevice,
KernelArgsTy &KernelArgs,
uint32_t NumThreads,
uint64_t NumBlocks) const {
uint32_t NumThreads[3],
uint32_t NumBlocks[3]) const {
// Only do all this when the output is requested
if (!(getInfoLevel() & OMP_INFOTYPE_PLUGIN_KERNEL))
return Plugin::success();
Expand Down Expand Up @@ -3440,12 +3446,13 @@ Error AMDGPUKernelTy::printLaunchInfoDetails(GenericDeviceTy &GenericDevice,
// S/VGPR Spill Count: how many S/VGPRs are spilled by the kernel
// Tripcount: loop tripcount for the kernel
INFO(OMP_INFOTYPE_PLUGIN_KERNEL, GenericDevice.getDeviceId(),
"#Args: %d Teams x Thrds: %4lux%4u (MaxFlatWorkGroupSize: %u) LDS "
"#Args: %d Teams x Thrds: %4ux%4u (MaxFlatWorkGroupSize: %u) LDS "
"Usage: %uB #SGPRs/VGPRs: %u/%u #SGPR/VGPR Spills: %u/%u Tripcount: "
"%lu\n",
ArgNum, NumGroups, ThreadsPerGroup, MaxFlatWorkgroupSize,
GroupSegmentSize, SGPRCount, VGPRCount, SGPRSpillCount, VGPRSpillCount,
LoopTripCount);
ArgNum, NumGroups[0] * NumGroups[1] * NumGroups[2],
ThreadsPerGroup[0] * ThreadsPerGroup[1] * ThreadsPerGroup[2],
MaxFlatWorkgroupSize, GroupSegmentSize, SGPRCount, VGPRCount,
SGPRSpillCount, VGPRSpillCount, LoopTripCount);

return Plugin::success();
}
Expand Down
15 changes: 8 additions & 7 deletions offload/plugins-nextgen/common/include/PluginInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -269,8 +269,9 @@ struct GenericKernelTy {
Error launch(GenericDeviceTy &GenericDevice, void **ArgPtrs,
ptrdiff_t *ArgOffsets, KernelArgsTy &KernelArgs,
AsyncInfoWrapperTy &AsyncInfoWrapper) const;
virtual Error launchImpl(GenericDeviceTy &GenericDevice, uint32_t NumThreads,
uint64_t NumBlocks, KernelArgsTy &KernelArgs,
virtual Error launchImpl(GenericDeviceTy &GenericDevice,
uint32_t NumThreads[3], uint32_t NumBlocks[3],
KernelArgsTy &KernelArgs,
KernelLaunchParamsTy LaunchParams,
AsyncInfoWrapperTy &AsyncInfoWrapper) const = 0;

Expand Down Expand Up @@ -320,15 +321,15 @@ struct GenericKernelTy {

/// Prints generic kernel launch information.
Error printLaunchInfo(GenericDeviceTy &GenericDevice,
KernelArgsTy &KernelArgs, uint32_t NumThreads,
uint64_t NumBlocks) const;
KernelArgsTy &KernelArgs, uint32_t NumThreads[3],
uint32_t NumBlocks[3]) const;

/// Prints plugin-specific kernel launch information after generic kernel
/// launch information
virtual Error printLaunchInfoDetails(GenericDeviceTy &GenericDevice,
KernelArgsTy &KernelArgs,
uint32_t NumThreads,
uint64_t NumBlocks) const;
uint32_t NumThreads[3],
uint32_t NumBlocks[3]) const;

private:
/// Prepare the arguments before launching the kernel.
Expand All @@ -347,7 +348,7 @@ struct GenericKernelTy {
/// The number of threads \p NumThreads can be adjusted by this method.
/// \p IsNumThreadsFromUser is true is \p NumThreads is defined by user via
/// thread_limit clause.
uint64_t getNumBlocks(GenericDeviceTy &GenericDevice,
uint32_t getNumBlocks(GenericDeviceTy &GenericDevice,
uint32_t BlockLimitClause[3], uint64_t LoopTripCount,
uint32_t &NumThreads, bool IsNumThreadsFromUser) const;

Expand Down
48 changes: 27 additions & 21 deletions offload/plugins-nextgen/common/src/PluginInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -526,20 +526,21 @@ GenericKernelTy::getKernelLaunchEnvironment(

Error GenericKernelTy::printLaunchInfo(GenericDeviceTy &GenericDevice,
KernelArgsTy &KernelArgs,
uint32_t NumThreads,
uint64_t NumBlocks) const {
uint32_t NumThreads[3],
uint32_t NumBlocks[3]) const {
INFO(OMP_INFOTYPE_PLUGIN_KERNEL, GenericDevice.getDeviceId(),
"Launching kernel %s with %" PRIu64
" blocks and %d threads in %s mode\n",
getName(), NumBlocks, NumThreads, getExecutionModeName());
"Launching kernel %s with [%u,%u,%u] blocks and [%u,%u,%u] threads in "
"%s mode\n",
getName(), NumBlocks[0], NumBlocks[1], NumBlocks[2], NumThreads[0],
NumThreads[1], NumThreads[2], getExecutionModeName());
return printLaunchInfoDetails(GenericDevice, KernelArgs, NumThreads,
NumBlocks);
}

Error GenericKernelTy::printLaunchInfoDetails(GenericDeviceTy &GenericDevice,
KernelArgsTy &KernelArgs,
uint32_t NumThreads,
uint64_t NumBlocks) const {
uint32_t NumThreads[3],
uint32_t NumBlocks[3]) const {
return Plugin::success();
}

Expand All @@ -566,10 +567,16 @@ Error GenericKernelTy::launch(GenericDeviceTy &GenericDevice, void **ArgPtrs,
Args, Ptrs, *KernelLaunchEnvOrErr);
}

uint32_t NumThreads = getNumThreads(GenericDevice, KernelArgs.ThreadLimit);
uint64_t NumBlocks =
getNumBlocks(GenericDevice, KernelArgs.NumTeams, KernelArgs.Tripcount,
NumThreads, KernelArgs.ThreadLimit[0] > 0);
uint32_t NumThreads[3] = {KernelArgs.ThreadLimit[0],
KernelArgs.ThreadLimit[1],
KernelArgs.ThreadLimit[2]};
uint32_t NumBlocks[3] = {KernelArgs.NumTeams[0], KernelArgs.NumTeams[1],
KernelArgs.NumTeams[2]};
if (!IsBareKernel) {
NumThreads[0] = getNumThreads(GenericDevice, NumThreads);
NumBlocks[0] = getNumBlocks(GenericDevice, NumBlocks, KernelArgs.Tripcount,
NumThreads[0], KernelArgs.ThreadLimit[0] > 0);
}

// Record the kernel description after we modified the argument count and num
// blocks/threads.
Expand All @@ -578,7 +585,8 @@ Error GenericKernelTy::launch(GenericDeviceTy &GenericDevice, void **ArgPtrs,
RecordReplay.saveImage(getName(), getImage());
RecordReplay.saveKernelInput(getName(), getImage());
RecordReplay.saveKernelDescr(getName(), LaunchParams, KernelArgs.NumArgs,
NumBlocks, NumThreads, KernelArgs.Tripcount);
NumBlocks[0], NumThreads[0],
KernelArgs.Tripcount);
}

if (auto Err =
Expand Down Expand Up @@ -618,11 +626,10 @@ KernelLaunchParamsTy GenericKernelTy::prepareArgs(

uint32_t GenericKernelTy::getNumThreads(GenericDeviceTy &GenericDevice,
uint32_t ThreadLimitClause[3]) const {
assert(ThreadLimitClause[1] == 0 && ThreadLimitClause[2] == 0 &&
"Multi dimensional launch not supported yet.");
assert(!IsBareKernel && "bare kernel should not call this function");

if (IsBareKernel && ThreadLimitClause[0] > 0)
return ThreadLimitClause[0];
assert(ThreadLimitClause[1] == 1 && ThreadLimitClause[2] == 1 &&
"Multi dimensional launch not supported yet.");

if (ThreadLimitClause[0] > 0 && isGenericMode())
ThreadLimitClause[0] += GenericDevice.getWarpSize();
Expand All @@ -632,16 +639,15 @@ uint32_t GenericKernelTy::getNumThreads(GenericDeviceTy &GenericDevice,
: PreferredNumThreads);
}

uint64_t GenericKernelTy::getNumBlocks(GenericDeviceTy &GenericDevice,
uint32_t GenericKernelTy::getNumBlocks(GenericDeviceTy &GenericDevice,
uint32_t NumTeamsClause[3],
uint64_t LoopTripCount,
uint32_t &NumThreads,
bool IsNumThreadsFromUser) const {
assert(NumTeamsClause[1] == 0 && NumTeamsClause[2] == 0 &&
"Multi dimensional launch not supported yet.");
assert(!IsBareKernel && "bare kernel should not call this function");

if (IsBareKernel && NumTeamsClause[0] > 0)
return NumTeamsClause[0];
assert(NumTeamsClause[1] == 1 && NumTeamsClause[2] == 1 &&
"Multi dimensional launch not supported yet.");

if (NumTeamsClause[0] > 0) {
// TODO: We need to honor any value and consequently allow more than the
Expand Down
19 changes: 9 additions & 10 deletions offload/plugins-nextgen/cuda/src/rtl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,8 @@ struct CUDAKernelTy : public GenericKernelTy {
}

/// Launch the CUDA kernel function.
Error launchImpl(GenericDeviceTy &GenericDevice, uint32_t NumThreads,
uint64_t NumBlocks, KernelArgsTy &KernelArgs,
Error launchImpl(GenericDeviceTy &GenericDevice, uint32_t NumThreads[3],
uint32_t NumBlocks[3], KernelArgsTy &KernelArgs,
KernelLaunchParamsTy LaunchParams,
AsyncInfoWrapperTy &AsyncInfoWrapper) const override;

Expand Down Expand Up @@ -1228,10 +1228,10 @@ struct CUDADeviceTy : public GenericDeviceTy {
AsyncInfoWrapperTy AsyncInfoWrapper(*this, nullptr);

KernelArgsTy KernelArgs = {};
if (auto Err =
CUDAKernel.launchImpl(*this, /*NumThread=*/1u,
/*NumBlocks=*/1ul, KernelArgs,
KernelLaunchParamsTy{}, AsyncInfoWrapper))
uint32_t NumBlocksAndThreads[3] = {1u, 1u, 1u};
if (auto Err = CUDAKernel.launchImpl(
*this, NumBlocksAndThreads, NumBlocksAndThreads, KernelArgs,
KernelLaunchParamsTy{}, AsyncInfoWrapper))
return Err;

Error Err = Plugin::success();
Expand Down Expand Up @@ -1274,7 +1274,7 @@ struct CUDADeviceTy : public GenericDeviceTy {
};

Error CUDAKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
uint32_t NumThreads, uint64_t NumBlocks,
uint32_t NumThreads[3], uint32_t NumBlocks[3],
KernelArgsTy &KernelArgs,
KernelLaunchParamsTy LaunchParams,
AsyncInfoWrapperTy &AsyncInfoWrapper) const {
Expand All @@ -1292,9 +1292,8 @@ Error CUDAKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
reinterpret_cast<void *>(&LaunchParams.Size),
CU_LAUNCH_PARAM_END};

CUresult Res = cuLaunchKernel(Func, NumBlocks, /*gridDimY=*/1,
/*gridDimZ=*/1, NumThreads,
/*blockDimY=*/1, /*blockDimZ=*/1,
CUresult Res = cuLaunchKernel(Func, NumBlocks[0], NumBlocks[1], NumBlocks[2],
NumThreads[0], NumThreads[1], NumThreads[2],
MaxDynCGroupMem, Stream, nullptr, Config);
return Plugin::check(Res, "Error in cuLaunchKernel for '%s': %s", getName());
}
Expand Down
4 changes: 2 additions & 2 deletions offload/plugins-nextgen/host/src/rtl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ struct GenELF64KernelTy : public GenericKernelTy {
}

/// Launch the kernel using the libffi.
Error launchImpl(GenericDeviceTy &GenericDevice, uint32_t NumThreads,
uint64_t NumBlocks, KernelArgsTy &KernelArgs,
Error launchImpl(GenericDeviceTy &GenericDevice, uint32_t NumThreads[3],
uint32_t NumBlocks[3], KernelArgsTy &KernelArgs,
KernelLaunchParamsTy LaunchParams,
AsyncInfoWrapperTy &AsyncInfoWrapper) const override {
// Create a vector of ffi_types, one per argument.
Expand Down
25 changes: 15 additions & 10 deletions offload/src/interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -284,14 +284,25 @@ static KernelArgsTy *upgradeKernelArgs(KernelArgsTy *KernelArgs,
LocalKernelArgs.Flags = KernelArgs->Flags;
LocalKernelArgs.DynCGroupMem = 0;
LocalKernelArgs.NumTeams[0] = NumTeams;
LocalKernelArgs.NumTeams[1] = 0;
LocalKernelArgs.NumTeams[2] = 0;
LocalKernelArgs.NumTeams[1] = 1;
LocalKernelArgs.NumTeams[2] = 1;
LocalKernelArgs.ThreadLimit[0] = ThreadLimit;
LocalKernelArgs.ThreadLimit[1] = 0;
LocalKernelArgs.ThreadLimit[2] = 0;
LocalKernelArgs.ThreadLimit[1] = 1;
LocalKernelArgs.ThreadLimit[2] = 1;
return &LocalKernelArgs;
}

// FIXME: This is a WA to "calibrate" the bad work done in the front end.
// Delete this ugly code after the front end emits proper values.
auto CorrectMultiDim = [](uint32_t(&Val)[3]) {
if (Val[1] == 0)
Val[1] = 1;
if (Val[2] == 0)
Val[2] = 1;
};
CorrectMultiDim(KernelArgs->ThreadLimit);
CorrectMultiDim(KernelArgs->NumTeams);

return KernelArgs;
}

Expand Down Expand Up @@ -320,12 +331,6 @@ static inline int targetKernel(ident_t *Loc, int64_t DeviceId, int32_t NumTeams,
KernelArgs =
upgradeKernelArgs(KernelArgs, LocalKernelArgs, NumTeams, ThreadLimit);

assert(KernelArgs->NumTeams[0] == static_cast<uint32_t>(NumTeams) &&
!KernelArgs->NumTeams[1] && !KernelArgs->NumTeams[2] &&
"OpenMP interface should not use multiple dimensions");
assert(KernelArgs->ThreadLimit[0] == static_cast<uint32_t>(ThreadLimit) &&
!KernelArgs->ThreadLimit[1] && !KernelArgs->ThreadLimit[2] &&
"OpenMP interface should not use multiple dimensions");
TIMESCOPE_WITH_DETAILS_AND_IDENT(
"Runtime: target exe",
"NumTeams=" + std::to_string(NumTeams) +
Expand Down
2 changes: 0 additions & 2 deletions offload/src/omptarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1451,8 +1451,6 @@ int target(ident_t *Loc, DeviceTy &Device, void *HostPtr,
Loc);

#ifdef OMPT_SUPPORT
assert(KernelArgs.NumTeams[1] == 0 && KernelArgs.NumTeams[2] == 0 &&
"Multi dimensional launch not supported yet.");
/// RAII to establish tool anchors before and after kernel launch
int32_t NumTeams = KernelArgs.NumTeams[0];
// No need to guard this with OMPT_IF_BUILT
Expand Down
Loading

0 comments on commit 92376c3

Please sign in to comment.