Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang] Enable loop-versioning for slices. #120344

Merged
merged 1 commit into from
Dec 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 89 additions & 27 deletions flang/lib/Optimizer/Transforms/LoopVersioning.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,45 @@ struct ArgsUsageInLoop {
};
} // namespace

static fir::SequenceType getAsSequenceType(mlir::Value *v) {
mlir::Type argTy = fir::unwrapPassByRefType(fir::unwrapRefType(v->getType()));
static fir::SequenceType getAsSequenceType(mlir::Value v) {
mlir::Type argTy = fir::unwrapPassByRefType(fir::unwrapRefType(v.getType()));
return mlir::dyn_cast<fir::SequenceType>(argTy);
}

/// Return the rank and the element size (in bytes) of the given
/// value \p v. If it is not an array or the element type is not
/// supported, then return <0, 0>. Only trivial data types
/// are currently supported.
/// When \p isArgument is true, \p v is assumed to be a function
/// argument. If \p v's type does not look like a type of an assumed
/// shape array, then the function returns <0, 0>.
/// When \p isArgument is false, array types with known innermost
/// dimension are allowed to proceed.
static std::pair<unsigned, size_t>
getRankAndElementSize(const fir::KindMapping &kindMap,
const mlir::DataLayout &dl, mlir::Value v,
bool isArgument = false) {
if (auto seqTy = getAsSequenceType(v)) {
unsigned rank = seqTy.getDimension();
if (rank > 0 &&
(!isArgument ||
seqTy.getShape()[0] == fir::SequenceType::getUnknownExtent())) {
size_t typeSize = 0;
mlir::Type elementType = fir::unwrapSeqOrBoxedSeqType(v.getType());
if (fir::isa_trivial(elementType)) {
auto [eleSize, eleAlign] = fir::getTypeSizeAndAlignmentOrCrash(
v.getLoc(), elementType, dl, kindMap);
typeSize = llvm::alignTo(eleSize, eleAlign);
}
if (typeSize)
return {rank, typeSize};
}
}

LLVM_DEBUG(llvm::dbgs() << "Unsupported rank/type: " << v << '\n');
return {0, 0};
}

/// if a value comes from a fir.declare, follow it to the original source,
/// otherwise return the value
static mlir::Value unwrapFirDeclare(mlir::Value val) {
Expand All @@ -160,12 +194,48 @@ static mlir::Value unwrapFirDeclare(mlir::Value val) {
return val;
}

/// Return true, if \p rebox operation keeps the input array
/// continuous in the innermost dimension, if it is initially continuous
/// in the innermost dimension.
static bool reboxPreservesContinuity(fir::ReboxOp rebox) {
// If slicing is not involved, then the rebox does not affect
// the continuity of the array.
auto sliceArg = rebox.getSlice();
if (!sliceArg)
return true;

// A slice with step=1 in the innermost dimension preserves
// the continuity of the array in the innermost dimension.
if (auto sliceOp =
mlir::dyn_cast_or_null<fir::SliceOp>(sliceArg.getDefiningOp())) {
if (sliceOp.getFields().empty() && sliceOp.getSubstr().empty()) {
auto triples = sliceOp.getTriples();
if (triples.size() > 2)
if (auto innermostStep = fir::getIntIfConstant(triples[2]))
if (*innermostStep == 1)
return true;
}

LLVM_DEBUG(llvm::dbgs()
<< "REBOX with slicing may produce non-contiguous array: "
<< sliceOp << '\n'
<< rebox << '\n');
return false;
}

LLVM_DEBUG(llvm::dbgs() << "REBOX with unknown slice" << sliceArg << '\n'
<< rebox << '\n');
return false;
}

/// if a value comes from a fir.rebox, follow the rebox to the original source,
/// of the value, otherwise return the value
static mlir::Value unwrapReboxOp(mlir::Value val) {
// don't support reboxes of reboxes
if (fir::ReboxOp rebox = val.getDefiningOp<fir::ReboxOp>())
while (fir::ReboxOp rebox = val.getDefiningOp<fir::ReboxOp>()) {
if (!reboxPreservesContinuity(rebox))
break;
val = rebox.getBox();
}
return val;
}

Expand Down Expand Up @@ -257,25 +327,10 @@ void LoopVersioningPass::runOnOperation() {
continue;
}

if (auto seqTy = getAsSequenceType(&arg)) {
unsigned rank = seqTy.getDimension();
if (rank > 0 &&
seqTy.getShape()[0] == fir::SequenceType::getUnknownExtent()) {
size_t typeSize = 0;
mlir::Type elementType = fir::unwrapSeqOrBoxedSeqType(arg.getType());
if (mlir::isa<mlir::FloatType>(elementType) ||
mlir::isa<mlir::IntegerType>(elementType) ||
mlir::isa<mlir::ComplexType>(elementType)) {
auto [eleSize, eleAlign] = fir::getTypeSizeAndAlignmentOrCrash(
arg.getLoc(), elementType, *dl, kindMap);
typeSize = llvm::alignTo(eleSize, eleAlign);
}
if (typeSize)
argsOfInterest.push_back({arg, typeSize, rank, {}});
else
LLVM_DEBUG(llvm::dbgs() << "Type not supported\n");
}
}
auto [rank, typeSize] =
getRankAndElementSize(kindMap, *dl, arg, /*isArgument=*/true);
if (rank != 0 && typeSize != 0)
argsOfInterest.push_back({arg, typeSize, rank, {}});
}

if (argsOfInterest.empty()) {
Expand Down Expand Up @@ -326,6 +381,13 @@ void LoopVersioningPass::runOnOperation() {
if (arrayCoor.getSlice())
argsInLoop.cannotTransform.insert(a.arg);

// We need to compute the rank and element size
// based on the operand, not the original argument,
// because array slicing may affect it.
std::tie(a.rank, a.size) = getRankAndElementSize(kindMap, *dl, a.arg);
if (a.rank == 0 || a.size == 0)
argsInLoop.cannotTransform.insert(a.arg);

if (argsInLoop.cannotTransform.contains(a.arg)) {
// Remove any previously recorded usage, if any.
argsInLoop.usageInfo.erase(a.arg);
Expand Down Expand Up @@ -416,8 +478,8 @@ void LoopVersioningPass::runOnOperation() {
mlir::Location loc = builder.getUnknownLoc();
mlir::IndexType idxTy = builder.getIndexType();

LLVM_DEBUG(llvm::dbgs() << "Module Before transformation:");
LLVM_DEBUG(module->dump());
LLVM_DEBUG(llvm::dbgs() << "Func Before transformation:\n");
LLVM_DEBUG(func->dump());

LLVM_DEBUG(llvm::dbgs() << "loopsOfInterest: " << loopsOfInterest.size()
<< "\n");
Expand Down Expand Up @@ -551,8 +613,8 @@ void LoopVersioningPass::runOnOperation() {
}
}

LLVM_DEBUG(llvm::dbgs() << "After transform:\n");
LLVM_DEBUG(module->dump());
LLVM_DEBUG(llvm::dbgs() << "Func After transform:\n");
LLVM_DEBUG(func->dump());

LLVM_DEBUG(llvm::dbgs() << "=== End " DEBUG_TYPE " ===\n");
}
Loading
Loading