Skip to content

Commit

Permalink
Add calculation of reduction size for teams op.
Browse files Browse the repository at this point in the history
  • Loading branch information
jsjodin committed Sep 24, 2024
1 parent 97e4085 commit 3861586
Showing 1 changed file with 28 additions and 26 deletions.
54 changes: 28 additions & 26 deletions mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3713,6 +3713,29 @@ static OpTy castOrGetParentOfType(Operation *op, bool immediateParent = false) {
return op->getParentOfType<OpTy>();
}

static uint64_t getTypeByteSize(mlir::Type type, DataLayout dl) {
uint64_t sizeInBits = dl.getTypeSizeInBits(type);
uint64_t sizeInBytes = sizeInBits / 8;
return sizeInBytes;
}

template <typename OpTy>
static uint64_t getReductionDataSize(OpTy &op) {
if (op.getNumReductionVars() > 0) {
assert(op.getNumReductionVars() &&
"Only 1 reduction variable currently supported");
mlir::Type reductionVarTy = op.getReductionVars()[0].getType();
Operation *opp = op.getOperation();
DataLayout dl = DataLayout(opp->getParentOfType<ModuleOp>());
return getTypeByteSize(reductionVarTy, dl);
}
return 0;
}

static uint64_t getTeamsReductionDataSize(mlir::omp::TeamsOp &teamsOp) {
return getReductionDataSize<mlir::omp::TeamsOp>(teamsOp);
}

/// Populate default `MinTeams`, `MaxTeams` and `MaxThreads` to their default
/// values as stated by the corresponding clauses, if constant.
///
Expand Down Expand Up @@ -3807,34 +3830,13 @@ static void initTargetDefaultBounds(

// Calculate reduction data size, limited to single reduction variable
// for now.
// FIXME: This treats 'DO SIMD' as if it was a 'DO' construct. Reductions
// on other constructs apart from 'DO' aren't considered either.
int32_t reductionDataSize = 0;
if (isGPU && innermostCapturedOmpOp) {
if (auto loopNestOp =
mlir::dyn_cast<mlir::omp::LoopNestOp>(innermostCapturedOmpOp)) {
// FIXME: This treats 'DO SIMD' as if it was a 'DO' construct. Reductions
// on other constructs apart from 'DO' aren't considered either.
mlir::omp::WsloopOp wsloopOp = nullptr;
SmallVector<mlir::omp::LoopWrapperInterface> wrappers;
loopNestOp.gatherWrappers(wrappers);
for (auto wrapper : wrappers) {
wsloopOp = mlir::dyn_cast<mlir::omp::WsloopOp>(*wrapper);
if (wsloopOp)
break;
}
if (wsloopOp) {
if (wsloopOp.getNumReductionVars() > 0) {
assert(wsloopOp.getNumReductionVars() &&
"Only 1 reduction variable currently supported");
mlir::Value reductionVar = wsloopOp.getReductionVars()[0];
DataLayout dl =
DataLayout(innermostCapturedOmpOp->getParentOfType<ModuleOp>());

mlir::Type reductionVarTy = reductionVar.getType();
uint64_t sizeInBits = dl.getTypeSizeInBits(reductionVarTy);
uint64_t sizeInBytes = sizeInBits / 8;
reductionDataSize = sizeInBytes;
}
}
if (auto teamsOp =
castOrGetParentOfType<omp::TeamsOp>(innermostCapturedOmpOp)) {
reductionDataSize = getTeamsReductionDataSize(teamsOp);
}
}

Expand Down

0 comments on commit 3861586

Please sign in to comment.