@@ -174,10 +174,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
174
174
if (op.getHint ())
175
175
op.emitWarning (" hint clause discarded" );
176
176
};
177
- auto checkHostEval = [&todo](auto op, LogicalResult &result) {
178
- if (!op.getHostEvalVars ().empty ())
179
- result = todo (" host_eval" );
180
- };
181
177
auto checkIf = [&todo](auto op, LogicalResult &result) {
182
178
if (op.getIfExpr ())
183
179
result = todo (" if" );
@@ -224,10 +220,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
224
220
op.getReductionSyms ())
225
221
result = todo (" reduction" );
226
222
};
227
- auto checkThreadLimit = [&todo](auto op, LogicalResult &result) {
228
- if (op.getThreadLimit ())
229
- result = todo (" thread_limit" );
230
- };
231
223
auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
232
224
if (!op.getTaskReductionVars ().empty () || op.getTaskReductionByref () ||
233
225
op.getTaskReductionSyms ())
@@ -289,7 +281,16 @@ static LogicalResult checkImplementationStatus(Operation &op) {
289
281
checkBare (op, result);
290
282
checkDevice (op, result);
291
283
checkHasDeviceAddr (op, result);
292
- checkHostEval (op, result);
284
+
285
+ // Host evaluated clauses are supported, except for target SPMD loop
286
+ // bounds.
287
+ for (BlockArgument arg :
288
+ cast<omp::BlockArgOpenMPOpInterface>(*op).getHostEvalBlockArgs ())
289
+ for (Operation *user : arg.getUsers ())
290
+ if (isa<omp::LoopNestOp>(user))
291
+ result = op.emitError (" not yet implemented: host evaluation of "
292
+ " loop bounds in omp.target operation" );
293
+
293
294
checkIf (op, result);
294
295
checkInReduction (op, result);
295
296
checkIsDevicePtr (op, result);
@@ -306,7 +307,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
306
307
result = todo (" firstprivate" );
307
308
}
308
309
}
309
- checkThreadLimit (op, result);
310
310
})
311
311
.Default ([](Operation &) {
312
312
// Assume all clauses for an operation can be translated unless they are
@@ -3889,6 +3889,215 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
3889
3889
return builder.saveIP ();
3890
3890
}
3891
3891
3892
+ // / Follow uses of `host_eval`-defined block arguments of the given `omp.target`
3893
+ // / operation and populate output variables with their corresponding host value
3894
+ // / (i.e. operand evaluated outside of the target region), based on their uses
3895
+ // / inside of the target region.
3896
+ // /
3897
+ // / Loop bounds and steps are only optionally populated, if output vectors are
3898
+ // / provided.
3899
+ static void extractHostEvalClauses (omp::TargetOp targetOp, Value &numThreads,
3900
+ Value &numTeamsLower, Value &numTeamsUpper,
3901
+ Value &threadLimit) {
3902
+ auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
3903
+ for (auto item : llvm::zip_equal (targetOp.getHostEvalVars (),
3904
+ blockArgIface.getHostEvalBlockArgs ())) {
3905
+ Value hostEvalVar = std::get<0 >(item), blockArg = std::get<1 >(item);
3906
+
3907
+ for (Operation *user : blockArg.getUsers ()) {
3908
+ llvm::TypeSwitch<Operation *>(user)
3909
+ .Case ([&](omp::TeamsOp teamsOp) {
3910
+ if (teamsOp.getNumTeamsLower () == blockArg)
3911
+ numTeamsLower = hostEvalVar;
3912
+ else if (teamsOp.getNumTeamsUpper () == blockArg)
3913
+ numTeamsUpper = hostEvalVar;
3914
+ else if (teamsOp.getThreadLimit () == blockArg)
3915
+ threadLimit = hostEvalVar;
3916
+ else
3917
+ llvm_unreachable (" unsupported host_eval use" );
3918
+ })
3919
+ .Case ([&](omp::ParallelOp parallelOp) {
3920
+ if (parallelOp.getNumThreads () == blockArg)
3921
+ numThreads = hostEvalVar;
3922
+ else
3923
+ llvm_unreachable (" unsupported host_eval use" );
3924
+ })
3925
+ .Case ([&](omp::LoopNestOp loopOp) {
3926
+ // TODO: Extract bounds and step values.
3927
+ })
3928
+ .Default ([](Operation *) {
3929
+ llvm_unreachable (" unsupported host_eval use" );
3930
+ });
3931
+ }
3932
+ }
3933
+ }
3934
+
3935
+ // / If \p op is of the given type parameter, return it casted to that type.
3936
+ // / Otherwise, if its immediate parent operation (or some other higher-level
3937
+ // / parent, if \p immediateParent is false) is of that type, return that parent
3938
+ // / casted to the given type.
3939
+ // /
3940
+ // / If \p op is \c null or neither it or its parent(s) are of the specified
3941
+ // / type, return a \c null operation.
3942
+ template <typename OpTy>
3943
+ static OpTy castOrGetParentOfType (Operation *op, bool immediateParent = false ) {
3944
+ if (!op)
3945
+ return OpTy ();
3946
+
3947
+ if (OpTy casted = dyn_cast<OpTy>(op))
3948
+ return casted;
3949
+
3950
+ if (immediateParent)
3951
+ return dyn_cast_if_present<OpTy>(op->getParentOp ());
3952
+
3953
+ return op->getParentOfType <OpTy>();
3954
+ }
3955
+
3956
+ // / Populate default `MinTeams`, `MaxTeams` and `MaxThreads` to their default
3957
+ // / values as stated by the corresponding clauses, if constant.
3958
+ // /
3959
+ // / These default values must be set before the creation of the outlined LLVM
3960
+ // / function for the target region, so that they can be used to initialize the
3961
+ // / corresponding global `ConfigurationEnvironmentTy` structure.
3962
+ static void
3963
+ initTargetDefaultAttrs (omp::TargetOp targetOp,
3964
+ llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
3965
+ bool isTargetDevice) {
3966
+ // TODO: Handle constant 'if' clauses.
3967
+ Operation *capturedOp = targetOp.getInnermostCapturedOmpOp ();
3968
+
3969
+ Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
3970
+ if (!isTargetDevice) {
3971
+ extractHostEvalClauses (targetOp, numThreads, numTeamsLower, numTeamsUpper,
3972
+ threadLimit);
3973
+ } else {
3974
+ // In the target device, values for these clauses are not passed as
3975
+ // host_eval, but instead evaluated prior to entry to the region. This
3976
+ // ensures values are mapped and available inside of the target region.
3977
+ if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
3978
+ numTeamsLower = teamsOp.getNumTeamsLower ();
3979
+ numTeamsUpper = teamsOp.getNumTeamsUpper ();
3980
+ threadLimit = teamsOp.getThreadLimit ();
3981
+ }
3982
+
3983
+ if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp))
3984
+ numThreads = parallelOp.getNumThreads ();
3985
+ }
3986
+
3987
+ auto extractConstInteger = [](Value value) -> std::optional<int64_t > {
3988
+ if (auto constOp =
3989
+ dyn_cast_if_present<LLVM::ConstantOp>(value.getDefiningOp ()))
3990
+ if (auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue ()))
3991
+ return constAttr.getInt ();
3992
+
3993
+ return std::nullopt;
3994
+ };
3995
+
3996
+ // Handle clauses impacting the number of teams.
3997
+
3998
+ int32_t minTeamsVal = 1 , maxTeamsVal = -1 ;
3999
+ if (castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
4000
+ // TODO: Use `hostNumTeamsLower` to initialize `minTeamsVal`. For now, match
4001
+ // clang and set min and max to the same value.
4002
+ if (numTeamsUpper) {
4003
+ if (auto val = extractConstInteger (numTeamsUpper))
4004
+ minTeamsVal = maxTeamsVal = *val;
4005
+ } else {
4006
+ minTeamsVal = maxTeamsVal = 0 ;
4007
+ }
4008
+ } else if (castOrGetParentOfType<omp::ParallelOp>(capturedOp,
4009
+ /* immediateParent=*/ true ) ||
4010
+ castOrGetParentOfType<omp::SimdOp>(capturedOp,
4011
+ /* immediateParent=*/ true )) {
4012
+ minTeamsVal = maxTeamsVal = 1 ;
4013
+ } else {
4014
+ minTeamsVal = maxTeamsVal = -1 ;
4015
+ }
4016
+
4017
+ // Handle clauses impacting the number of threads.
4018
+
4019
+ auto setMaxValueFromClause = [&extractConstInteger](Value clauseValue,
4020
+ int32_t &result) {
4021
+ if (!clauseValue)
4022
+ return ;
4023
+
4024
+ if (auto val = extractConstInteger (clauseValue))
4025
+ result = *val;
4026
+
4027
+ // Found an applicable clause, so it's not undefined. Mark as unknown
4028
+ // because it's not constant.
4029
+ if (result < 0 )
4030
+ result = 0 ;
4031
+ };
4032
+
4033
+ // Extract 'thread_limit' clause from 'target' and 'teams' directives.
4034
+ int32_t targetThreadLimitVal = -1 , teamsThreadLimitVal = -1 ;
4035
+ setMaxValueFromClause (targetOp.getThreadLimit (), targetThreadLimitVal);
4036
+ setMaxValueFromClause (threadLimit, teamsThreadLimitVal);
4037
+
4038
+ // Extract 'max_threads' clause from 'parallel' or set to 1 if it's SIMD.
4039
+ int32_t maxThreadsVal = -1 ;
4040
+ if (castOrGetParentOfType<omp::ParallelOp>(capturedOp))
4041
+ setMaxValueFromClause (numThreads, maxThreadsVal);
4042
+ else if (castOrGetParentOfType<omp::SimdOp>(capturedOp,
4043
+ /* immediateParent=*/ true ))
4044
+ maxThreadsVal = 1 ;
4045
+
4046
+ // For max values, < 0 means unset, == 0 means set but unknown. Select the
4047
+ // minimum value between 'max_threads' and 'thread_limit' clauses that were
4048
+ // set.
4049
+ int32_t combinedMaxThreadsVal = targetThreadLimitVal;
4050
+ if (combinedMaxThreadsVal < 0 ||
4051
+ (teamsThreadLimitVal >= 0 && teamsThreadLimitVal < combinedMaxThreadsVal))
4052
+ combinedMaxThreadsVal = teamsThreadLimitVal;
4053
+
4054
+ if (combinedMaxThreadsVal < 0 ||
4055
+ (maxThreadsVal >= 0 && maxThreadsVal < combinedMaxThreadsVal))
4056
+ combinedMaxThreadsVal = maxThreadsVal;
4057
+
4058
+ // Update kernel bounds structure for the `OpenMPIRBuilder` to use.
4059
+ attrs.MinTeams = minTeamsVal;
4060
+ attrs.MaxTeams .front () = maxTeamsVal;
4061
+ attrs.MinThreads = 1 ;
4062
+ attrs.MaxThreads .front () = combinedMaxThreadsVal;
4063
+ }
4064
+
4065
+ // / Gather LLVM runtime values for all clauses evaluated in the host that are
4066
+ // / passed to the kernel invocation.
4067
+ // /
4068
+ // / This function must be called only when compiling for the host. Also, it will
4069
+ // / only provide correct results if it's called after the body of \c targetOp
4070
+ // / has been fully generated.
4071
+ static void
4072
+ initTargetRuntimeAttrs (llvm::IRBuilderBase &builder,
4073
+ LLVM::ModuleTranslation &moduleTranslation,
4074
+ omp::TargetOp targetOp,
4075
+ llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
4076
+ Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
4077
+ extractHostEvalClauses (targetOp, numThreads, numTeamsLower, numTeamsUpper,
4078
+ teamsThreadLimit);
4079
+
4080
+ // TODO: Handle constant 'if' clauses.
4081
+ if (Value targetThreadLimit = targetOp.getThreadLimit ())
4082
+ attrs.TargetThreadLimit .front () =
4083
+ moduleTranslation.lookupValue (targetThreadLimit);
4084
+
4085
+ if (numTeamsLower)
4086
+ attrs.MinTeams = moduleTranslation.lookupValue (numTeamsLower);
4087
+
4088
+ if (numTeamsUpper)
4089
+ attrs.MaxTeams .front () = moduleTranslation.lookupValue (numTeamsUpper);
4090
+
4091
+ if (teamsThreadLimit)
4092
+ attrs.TeamsThreadLimit .front () =
4093
+ moduleTranslation.lookupValue (teamsThreadLimit);
4094
+
4095
+ if (numThreads)
4096
+ attrs.MaxThreads = moduleTranslation.lookupValue (numThreads);
4097
+
4098
+ // TODO: Populate attrs.LoopTripCount if it is target SPMD.
4099
+ }
4100
+
3892
4101
static LogicalResult
3893
4102
convertOmpTarget (Operation &opInst, llvm::IRBuilderBase &builder,
3894
4103
LLVM::ModuleTranslation &moduleTranslation) {
@@ -3898,7 +4107,9 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
3898
4107
3899
4108
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
3900
4109
bool isTargetDevice = ompBuilder->Config .isTargetDevice ();
4110
+
3901
4111
auto parentFn = opInst.getParentOfType <LLVM::LLVMFuncOp>();
4112
+ auto argIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
3902
4113
auto &targetRegion = targetOp.getRegion ();
3903
4114
// Holds the private vars that have been mapped along with the block argument
3904
4115
// that corresponds to the MapInfoOp corresponding to the private var in
@@ -3913,8 +4124,7 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
3913
4124
llvm::DenseMap<Value, Value> mappedPrivateVars;
3914
4125
DataLayout dl = DataLayout (opInst.getParentOfType <ModuleOp>());
3915
4126
SmallVector<Value> mapVars = targetOp.getMapVars ();
3916
- ArrayRef<BlockArgument> mapBlockArgs =
3917
- cast<omp::BlockArgOpenMPOpInterface>(opInst).getMapBlockArgs ();
4127
+ ArrayRef<BlockArgument> mapBlockArgs = argIface.getMapBlockArgs ();
3918
4128
llvm::Function *llvmOutlinedFn = nullptr ;
3919
4129
3920
4130
// TODO: It can also be false if a compile-time constant `false` IF clause is
@@ -3928,7 +4138,6 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
3928
4138
// to quickly look up the corresponding map variable, if any for each
3929
4139
// private variable.
3930
4140
if (!targetOp.getPrivateVars ().empty () && !targetOp.getMapVars ().empty ()) {
3931
- auto argIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
3932
4141
OperandRange privateVars = targetOp.getPrivateVars ();
3933
4142
std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms ();
3934
4143
std::optional<DenseI64ArrayAttr> privateMapIndices =
@@ -4002,7 +4211,7 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
4002
4211
// Do privatization after moduleTranslation has already recorded
4003
4212
// mapped values.
4004
4213
MutableArrayRef<BlockArgument> privateBlockArgs =
4005
- cast<omp::BlockArgOpenMPOpInterface>(opInst) .getPrivateBlockArgs ();
4214
+ argIface .getPrivateBlockArgs ();
4006
4215
SmallVector<mlir::Value> mlirPrivateVars;
4007
4216
SmallVector<llvm::Value *> llvmPrivateVars;
4008
4217
SmallVector<omp::PrivateClauseOp> privateDecls;
@@ -4085,14 +4294,29 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
4085
4294
allocaIP, codeGenIP);
4086
4295
};
4087
4296
4088
- // TODO: Populate default and runtime attributes based on the construct and
4089
- // clauses.
4090
4297
llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
4091
- llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs = {
4092
- /* ExecFlags=*/ llvm::omp::OMP_TGT_EXEC_MODE_GENERIC, /* MaxTeams=*/ {-1 },
4093
- /* MinTeams=*/ 0 , /* MaxThreads=*/ {0 }, /* MinThreads=*/ 0 };
4298
+ llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
4299
+ initTargetDefaultAttrs (targetOp, defaultAttrs, isTargetDevice);
4094
4300
4301
+ // Collect host-evaluated values needed to properly launch the kernel from the
4302
+ // host.
4303
+ if (!isTargetDevice)
4304
+ initTargetRuntimeAttrs (builder, moduleTranslation, targetOp, runtimeAttrs);
4305
+
4306
+ // Pass host-evaluated values as parameters to the kernel / host fallback,
4307
+ // except if they are constants. In any case, map the MLIR block argument to
4308
+ // the corresponding LLVM values.
4095
4309
llvm::SmallVector<llvm::Value *, 4 > kernelInput;
4310
+ SmallVector<Value> hostEvalVars = targetOp.getHostEvalVars ();
4311
+ ArrayRef<BlockArgument> hostEvalBlockArgs = argIface.getHostEvalBlockArgs ();
4312
+ for (auto [arg, var] : llvm::zip_equal (hostEvalBlockArgs, hostEvalVars)) {
4313
+ llvm::Value *value = moduleTranslation.lookupValue (var);
4314
+ moduleTranslation.mapValue (arg, value);
4315
+
4316
+ if (!llvm::isa<llvm::Constant>(value))
4317
+ kernelInput.push_back (value);
4318
+ }
4319
+
4096
4320
for (size_t i = 0 ; i < mapVars.size (); ++i) {
4097
4321
// declare target arguments are not passed to kernels as arguments
4098
4322
// TODO: We currently do not handle cases where a member is explicitly
0 commit comments