Skip to content

Commit

Permalink
[flang] Added definition of hlfir.cshift operation. (#118732)
Browse files Browse the repository at this point in the history
CSHIFT intrinsic will be lowered to this operation, which
then can be optimized as inline sequence or lowered into
a runtime call.
  • Loading branch information
vzakhari authored Dec 9, 2024
1 parent 10f315d commit 1ca3927
Show file tree
Hide file tree
Showing 6 changed files with 216 additions and 0 deletions.
3 changes: 3 additions & 0 deletions flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,9 @@ mlir::Value genExprShape(mlir::OpBuilder &builder, const mlir::Location &loc,
/// This has to be cleaned up, when HLFIR is the default.
bool mayHaveAllocatableComponent(mlir::Type ty);

/// Scalar integer or a sequence of integers (via boxed array or expr).
bool isFortranIntegerScalarOrArrayObject(mlir::Type type);

} // namespace hlfir

#endif // FORTRAN_OPTIMIZER_HLFIR_HLFIRDIALECT_H
6 changes: 6 additions & 0 deletions flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,12 @@ def IsPolymorphicObjectPred
def AnyPolymorphicObject : Type<IsPolymorphicObjectPred,
"any polymorphic object">;

def IsFortranIntegerScalarOrArrayPred
: CPred<"::hlfir::isFortranIntegerScalarOrArrayObject($_self)">;
def AnyFortranIntegerScalarOrArrayObject
: Type<IsFortranIntegerScalarOrArrayPred,
"A scalar or array object containing integers">;

def hlfir_CharExtremumPredicateAttr : I32EnumAttr<
"CharExtremumPredicate", "",
[
Expand Down
21 changes: 21 additions & 0 deletions flang/include/flang/Optimizer/HLFIR/HLFIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,27 @@ def hlfir_MatmulTransposeOp : hlfir_Op<"matmul_transpose",
let hasVerifier = 1;
}

def hlfir_CShiftOp
: hlfir_Op<
"cshift", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let summary = "CSHIFT transformational intrinsic";
let description = [{
Circular shift of an array
}];

let arguments = (ins AnyFortranArrayObject:$array,
AnyFortranIntegerScalarOrArrayObject:$shift,
Optional<AnyIntegerType>:$dim);

let results = (outs hlfir_ExprType);

let assemblyFormat = [{
$array $shift (`dim` $dim^)? attr-dict `:` functional-type(operands, results)
}];

let hasVerifier = 1;
}

// An allocation effect is needed because the value produced by the associate
// is "deallocated" by hlfir.end_associate (the end_associate must not be
// removed, and there must be only one hlfir.end_associate).
Expand Down
9 changes: 9 additions & 0 deletions flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,3 +228,12 @@ mlir::Type hlfir::getExprType(mlir::Type variableType) {
return hlfir::ExprType::get(variableType.getContext(), typeShape, type,
isPolymorphic);
}

bool hlfir::isFortranIntegerScalarOrArrayObject(mlir::Type type) {
if (isBoxAddressType(type))
return false;

mlir::Type unwrappedType = fir::unwrapPassByRefType(fir::unwrapRefType(type));
mlir::Type elementType = getFortranElementType(unwrappedType);
return mlir::isa<mlir::IntegerType>(elementType);
}
102 changes: 102 additions & 0 deletions flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1341,6 +1341,108 @@ void hlfir::MatmulTransposeOp::getEffects(
getIntrinsicEffects(getOperation(), effects);
}

//===----------------------------------------------------------------------===//
// CShiftOp
//===----------------------------------------------------------------------===//

llvm::LogicalResult hlfir::CShiftOp::verify() {
mlir::Value array = getArray();
fir::SequenceType arrayTy = mlir::cast<fir::SequenceType>(
hlfir::getFortranElementOrSequenceType(array.getType()));
llvm::ArrayRef<int64_t> inShape = arrayTy.getShape();
std::size_t arrayRank = inShape.size();
mlir::Type eleTy = arrayTy.getEleTy();
hlfir::ExprType resultTy = mlir::cast<hlfir::ExprType>(getResult().getType());
llvm::ArrayRef<int64_t> resultShape = resultTy.getShape();
std::size_t resultRank = resultShape.size();
mlir::Type resultEleTy = resultTy.getEleTy();
mlir::Value shift = getShift();
mlir::Type shiftTy = hlfir::getFortranElementOrSequenceType(shift.getType());

if (eleTy != resultEleTy) {
if (mlir::isa<fir::CharacterType>(eleTy) &&
mlir::isa<fir::CharacterType>(resultEleTy)) {
auto eleCharTy = mlir::cast<fir::CharacterType>(eleTy);
auto resultCharTy = mlir::cast<fir::CharacterType>(resultEleTy);
if (eleCharTy.getFKind() != resultCharTy.getFKind())
return emitOpError("kind mismatch between input and output arrays");
if (eleCharTy.getLen() != fir::CharacterType::unknownLen() &&
resultCharTy.getLen() != fir::CharacterType::unknownLen() &&
eleCharTy.getLen() != resultCharTy.getLen())
return emitOpError(
"character LEN mismatch between input and output arrays");
} else {
return emitOpError(
"input and output arrays should have the same element type");
}
}

if (arrayRank != resultRank)
return emitOpError("input and output arrays should have the same rank");

constexpr int64_t unknownExtent = fir::SequenceType::getUnknownExtent();
for (auto [inDim, resultDim] : llvm::zip(inShape, resultShape))
if (inDim != unknownExtent && resultDim != unknownExtent &&
inDim != resultDim)
return emitOpError(
"output array's shape conflicts with the input array's shape");

int64_t dimVal = -1;
if (!getDim())
dimVal = 1;
else if (auto dim = fir::getIntIfConstant(getDim()))
dimVal = *dim;

// The DIM argument may be statically invalid (e.g. exceed the
// input array rank) in dead code after constant propagation,
// so avoid some checks unless useStrictIntrinsicVerifier is true.
if (useStrictIntrinsicVerifier && dimVal != -1) {
if (dimVal < 1)
return emitOpError("DIM must be >= 1");
if (dimVal > static_cast<int64_t>(arrayRank))
return emitOpError("DIM must be <= input array's rank");
}

if (auto shiftSeqTy = mlir::dyn_cast<fir::SequenceType>(shiftTy)) {
// SHIFT is an array. Verify the rank and the shape (if DIM is constant).
llvm::ArrayRef<int64_t> shiftShape = shiftSeqTy.getShape();
std::size_t shiftRank = shiftShape.size();
if (shiftRank != arrayRank - 1)
return emitOpError(
"SHIFT's rank must be 1 less than the input array's rank");

if (useStrictIntrinsicVerifier && dimVal != -1) {
// SHIFT's shape must be [d(1), d(2), ..., d(DIM-1), d(DIM+1), ..., d(n)],
// where [d(1), d(2), ..., d(n)] is the shape of the ARRAY.
int64_t arrayDimIdx = 0;
int64_t shiftDimIdx = 0;
for (auto shiftDim : shiftShape) {
if (arrayDimIdx == dimVal - 1)
++arrayDimIdx;

if (inShape[arrayDimIdx] != unknownExtent &&
shiftDim != unknownExtent && inShape[arrayDimIdx] != shiftDim)
return emitOpError("SHAPE(ARRAY)(" + llvm::Twine(arrayDimIdx + 1) +
") must be equal to SHAPE(SHIFT)(" +
llvm::Twine(shiftDimIdx + 1) +
"): " + llvm::Twine(inShape[arrayDimIdx]) +
" != " + llvm::Twine(shiftDim));
++arrayDimIdx;
++shiftDimIdx;
}
}
}

return mlir::success();
}

void hlfir::CShiftOp::getEffects(
llvm::SmallVectorImpl<
mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>>
&effects) {
getIntrinsicEffects(getOperation(), effects);
}

//===----------------------------------------------------------------------===//
// AssociateOp
//===----------------------------------------------------------------------===//
Expand Down
75 changes: 75 additions & 0 deletions flang/test/HLFIR/invalid.fir
Original file line number Diff line number Diff line change
Expand Up @@ -1348,3 +1348,78 @@ func.func @bad_eval_in_mem_3() {
}
return
}

// -----

func.func @bad_cshift1(%arg0: !hlfir.expr<?x?xi32>, %arg1: i32) {
// expected-error@+1 {{'hlfir.cshift' op input and output arrays should have the same element type}}
%0 = hlfir.cshift %arg0 %arg1 : (!hlfir.expr<?x?xi32>, i32) -> !hlfir.expr<?x?xf32>
return
}

// -----

func.func @bad_cshift2(%arg0: !hlfir.expr<?x?xi32>, %arg1: i32) {
// expected-error@+1 {{'hlfir.cshift' op input and output arrays should have the same rank}}
%0 = hlfir.cshift %arg0 %arg1 : (!hlfir.expr<?x?xi32>, i32) -> !hlfir.expr<?xi32>
return
}

// -----

func.func @bad_cshift3(%arg0: !hlfir.expr<2x2xi32>, %arg1: i32) {
// expected-error@+1 {{'hlfir.cshift' op output array's shape conflicts with the input array's shape}}
%0 = hlfir.cshift %arg0 %arg1 : (!hlfir.expr<2x2xi32>, i32) -> !hlfir.expr<2x3xi32>
return
}

// -----

func.func @bad_cshift4(%arg0: !hlfir.expr<2x2xi32>, %arg1: i32) {
%c0 = arith.constant 0 : index
// expected-error@+1 {{'hlfir.cshift' op DIM must be >= 1}}
%0 = hlfir.cshift %arg0 %arg1 dim %c0 : (!hlfir.expr<2x2xi32>, i32, index) -> !hlfir.expr<2x2xi32>
return
}

// -----

func.func @bad_cshift5(%arg0: !hlfir.expr<2x2xi32>, %arg1: i32) {
%c10 = arith.constant 10 : index
// expected-error@+1 {{'hlfir.cshift' op DIM must be <= input array's rank}}
%0 = hlfir.cshift %arg0 %arg1 dim %c10 : (!hlfir.expr<2x2xi32>, i32, index) -> !hlfir.expr<2x2xi32>
return
}

// -----

func.func @bad_cshift6(%arg0: !hlfir.expr<2x2xi32>, %arg1: !hlfir.expr<2x2xi32>) {
// expected-error@+1 {{'hlfir.cshift' op SHIFT's rank must be 1 less than the input array's rank}}
%0 = hlfir.cshift %arg0 %arg1 : (!hlfir.expr<2x2xi32>, !hlfir.expr<2x2xi32>) -> !hlfir.expr<2x2xi32>
return
}

// -----

func.func @bad_cshift7(%arg0: !hlfir.expr<?x2xi32>, %arg1: !hlfir.expr<3xi32>) {
%c1 = arith.constant 1 : index
// expected-error@+1 {{'hlfir.cshift' op SHAPE(ARRAY)(2) must be equal to SHAPE(SHIFT)(1): 2 != 3}}
%0 = hlfir.cshift %arg0 %arg1 dim %c1 : (!hlfir.expr<?x2xi32>, !hlfir.expr<3xi32>, index) -> !hlfir.expr<2x2xi32>
return
}

// -----

func.func @bad_cshift8(%arg0: !hlfir.expr<?x!fir.char<1,?>>, %arg1: i32) {
// expected-error@+1 {{'hlfir.cshift' op kind mismatch between input and output arrays}}
%0 = hlfir.cshift %arg0 %arg1 : (!hlfir.expr<?x!fir.char<1,?>>, i32) -> !hlfir.expr<?x!fir.char<2,?>>
return
}

// -----

func.func @bad_cshift9(%arg0: !hlfir.expr<?x!fir.char<1,1>>, %arg1: i32) {
// expected-error@+1 {{'hlfir.cshift' op character LEN mismatch between input and output arrays}}
%0 = hlfir.cshift %arg0 %arg1 : (!hlfir.expr<?x!fir.char<1,1>>, i32) -> !hlfir.expr<?x!fir.char<1,2>>
return
}

0 comments on commit 1ca3927

Please sign in to comment.