Skip to content

Ajm/flang volatile attr #9

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flang/include/flang/Optimizer/Builder/FIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
mlir::Block *getAllocaBlock();

/// Safely create a reference type to the type `eleTy`.
mlir::Type getRefType(mlir::Type eleTy);
mlir::Type getRefType(mlir::Type eleTy, bool isVolatile = false);

/// Create a sequence of `eleTy` with `rank` dimensions of unknown size.
mlir::Type getVarLenSeqTy(mlir::Type eleTy, unsigned rank = 1);
Expand Down
6 changes: 6 additions & 0 deletions flang/include/flang/Optimizer/Dialect/FIRType.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,12 @@ inline bool isa_ref_type(mlir::Type t) {
fir::LLVMPointerType>(t);
}

inline bool isa_volatile_ref_type(mlir::Type t) {
if (auto refTy = mlir::dyn_cast_or_null<fir::ReferenceType>(t))
return refTy.isVolatile();
return false;
}

/// Is `t` a boxed type?
inline bool isa_box_type(mlir::Type t) {
return mlir::isa<fir::BaseBoxType, fir::BoxCharType, fir::BoxProcType>(t);
Expand Down
10 changes: 7 additions & 3 deletions flang/include/flang/Optimizer/Dialect/FIRTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -363,18 +363,22 @@ def fir_ReferenceType : FIR_Type<"Reference", "ref"> {
The type of a reference to an entity in memory.
}];

let parameters = (ins "mlir::Type":$eleTy);
let parameters = (ins
"mlir::Type":$eleTy,
DefaultValuedParameter<"bool", "false">:$isVol);

let skipDefaultBuilders = 1;

let builders = [
TypeBuilderWithInferredContext<(ins "mlir::Type":$elementType), [{
return Base::get(elementType.getContext(), elementType);
TypeBuilderWithInferredContext<(ins "mlir::Type":$elementType, CArg<"bool", "false">:$isVol), [{
return Base::get(elementType.getContext(), elementType, isVol);
}]>,
];

let extraClassDeclaration = [{
mlir::Type getElementType() const { return getEleTy(); }
bool isVolatile() const { return (bool)getIsVol(); }
static llvm::StringRef getVolatileKeyword() { return "volatile"; }
}];

let genVerifyDecl = 1;
Expand Down
1 change: 0 additions & 1 deletion flang/lib/Lower/CallInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1112,7 +1112,6 @@ class Fortran::lower::CallInterfaceImpl {
if (obj.attrs.test(Attrs::Value))
isValueAttr = true; // TODO: do we want an mlir::Attribute as well?
if (obj.attrs.test(Attrs::Volatile)) {
TODO(loc, "VOLATILE in procedure interface");
addMLIRAttr(fir::getVolatileAttrName());
}
// obj.attrs.test(Attrs::Asynchronous) does not impact the way the argument
Expand Down
48 changes: 42 additions & 6 deletions flang/lib/Lower/ConvertExprToHLFIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,37 @@ class HlfirDesignatorBuilder {
designatorNode, getConverter().getFoldingContext(),
/*namedConstantSectionsAreAlwaysContiguous=*/false))
return fir::BoxType::get(resultValueType);

bool isVolatile = false;

// Check if the base type is volatile
if (partInfo.base.has_value()) {
mlir::Type baseType = partInfo.base.value().getType();
isVolatile = fir::isa_volatile_ref_type(baseType);
}

auto isVolatileSymbol = [](const Fortran::semantics::Symbol &symbol) {
return symbol.GetUltimate().attrs().test(
Fortran::semantics::Attr::VOLATILE);
};

// Check if this should be a volatile reference
if constexpr (std::is_same_v<std::decay_t<T>,
Fortran::evaluate::SymbolRef>) {
if (isVolatileSymbol(designatorNode.get()))
isVolatile = true;
} else if constexpr (std::is_same_v<std::decay_t<T>,
Fortran::evaluate::Component>) {
if (isVolatileSymbol(designatorNode.GetLastSymbol()))
isVolatile = true;
}

// If it's a reference to a ref, account for it
if (auto refTy = mlir::dyn_cast<fir::ReferenceType>(resultValueType))
resultValueType = refTy.getEleTy();

// Other designators can be handled as raw addresses.
return fir::ReferenceType::get(resultValueType);
return fir::ReferenceType::get(resultValueType, isVolatile);
}

template <typename T>
Expand Down Expand Up @@ -414,10 +443,16 @@ class HlfirDesignatorBuilder {
.Case<fir::SequenceType>([&](fir::SequenceType seqTy) -> mlir::Type {
return fir::SequenceType::get(seqTy.getShape(), newEleTy);
})
.Case<fir::PointerType, fir::HeapType, fir::ReferenceType, fir::BoxType,
fir::ClassType>([&](auto t) -> mlir::Type {
using FIRT = decltype(t);
return FIRT::get(changeElementType(t.getEleTy(), newEleTy));
// TODO: handle volatility for other types
.Case<fir::PointerType, fir::HeapType, fir::BoxType, fir::ClassType>(
[&](auto t) -> mlir::Type {
using FIRT = decltype(t);
return FIRT::get(changeElementType(t.getEleTy(), newEleTy));
})
.Case<fir::ReferenceType>([&](fir::ReferenceType refTy) -> mlir::Type {
return fir::ReferenceType::get(
changeElementType(refTy.getEleTy(), newEleTy),
refTy.isVolatile());
})
.Default([newEleTy](mlir::Type t) -> mlir::Type { return newEleTy; });
}
Expand Down Expand Up @@ -1808,6 +1843,7 @@ class HlfirBuilder {
auto &expr = std::get<const Fortran::lower::SomeExpr &>(iter);
auto &baseOp = std::get<hlfir::EntityWithAttributes>(iter);
std::string name = converter.getRecordTypeFieldName(sym);
const bool isVolatile = fir::isa_volatile_ref_type(baseOp.getType());

// Generate DesignateOp for the component.
// The designator's result type is just a reference to the component type,
Expand All @@ -1818,7 +1854,7 @@ class HlfirBuilder {
assert(compType && "failed to retrieve component type");
mlir::Value compShape =
designatorBuilder.genComponentShape(sym, compType);
mlir::Type designatorType = builder.getRefType(compType);
mlir::Type designatorType = builder.getRefType(compType, isVolatile);

mlir::Type fieldElemType = hlfir::getFortranElementType(compType);
llvm::SmallVector<mlir::Value, 1> typeParams;
Expand Down
4 changes: 2 additions & 2 deletions flang/lib/Optimizer/Builder/FIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,9 @@ fir::FirOpBuilder::getNamedGlobal(mlir::ModuleOp modOp,
return modOp.lookupSymbol<fir::GlobalOp>(name);
}

mlir::Type fir::FirOpBuilder::getRefType(mlir::Type eleTy) {
mlir::Type fir::FirOpBuilder::getRefType(mlir::Type eleTy, bool isVolatile) {
assert(!mlir::isa<fir::ReferenceType>(eleTy) && "cannot be a reference type");
return fir::ReferenceType::get(eleTy);
return fir::ReferenceType::get(eleTy, isVolatile);
}

mlir::Type fir::FirOpBuilder::getVarLenSeqTy(mlir::Type eleTy, unsigned rank) {
Expand Down
3 changes: 2 additions & 1 deletion flang/lib/Optimizer/Builder/HLFIRTools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -809,7 +809,8 @@ mlir::Type hlfir::getVariableElementType(hlfir::Entity variable) {
} else if (fir::isRecordWithTypeParameters(eleTy)) {
return fir::BoxType::get(eleTy);
}
return fir::ReferenceType::get(eleTy);
const bool isVolatile = fir::isa_volatile_ref_type(variable.getType());
return fir::ReferenceType::get(eleTy, isVolatile);
}

mlir::Type hlfir::getEntityElementType(hlfir::Entity entity) {
Expand Down
17 changes: 12 additions & 5 deletions flang/lib/Optimizer/CodeGen/CodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3218,6 +3218,8 @@ struct LoadOpConversion : public fir::FIROpConversion<fir::LoadOp> {
mlir::ConversionPatternRewriter &rewriter) const override {

mlir::Type llvmLoadTy = convertObjectType(load.getType());
const bool isVolatile =
fir::isa_volatile_ref_type(load.getMemref().getType());
if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(load.getType())) {
// fir.box is a special case because it is considered an ssa value in
// fir, but it is lowered as a pointer to a descriptor. So
Expand Down Expand Up @@ -3247,16 +3249,18 @@ struct LoadOpConversion : public fir::FIROpConversion<fir::LoadOp> {
mlir::Value boxSize =
computeBoxSize(loc, boxTypePair, inputBoxStorage, rewriter);
auto memcpy = rewriter.create<mlir::LLVM::MemcpyOp>(
loc, newBoxStorage, inputBoxStorage, boxSize, /*isVolatile=*/false);
loc, newBoxStorage, inputBoxStorage, boxSize, isVolatile);

if (std::optional<mlir::ArrayAttr> optionalTag = load.getTbaa())
memcpy.setTBAATags(*optionalTag);
else
attachTBAATag(memcpy, boxTy, boxTy, nullptr);
rewriter.replaceOp(load, newBoxStorage);
} else {
// TODO: are we losing any attributes from the load op?
auto memref = adaptor.getOperands()[0];
auto loadOp = rewriter.create<mlir::LLVM::LoadOp>(
load.getLoc(), llvmLoadTy, adaptor.getOperands(), load->getAttrs());
load.getLoc(), llvmLoadTy, memref, /*alignment=*/0, isVolatile);
if (std::optional<mlir::ArrayAttr> optionalTag = load.getTbaa())
loadOp.setTBAATags(*optionalTag);
else
Expand Down Expand Up @@ -3534,17 +3538,20 @@ struct StoreOpConversion : public fir::FIROpConversion<fir::StoreOp> {
mlir::Value llvmValue = adaptor.getValue();
mlir::Value llvmMemref = adaptor.getMemref();
mlir::LLVM::AliasAnalysisOpInterface newOp;
const bool isVolatile =
fir::isa_volatile_ref_type(store.getMemref().getType());
if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(storeTy)) {
mlir::Type llvmBoxTy = lowerTy().convertBoxTypeAsStruct(boxTy);
// Always use memcpy because LLVM is not as effective at optimizing
// aggregate loads/stores as it is optimizing memcpy.
TypePair boxTypePair{boxTy, llvmBoxTy};
mlir::Value boxSize =
computeBoxSize(loc, boxTypePair, llvmValue, rewriter);
newOp = rewriter.create<mlir::LLVM::MemcpyOp>(
loc, llvmMemref, llvmValue, boxSize, /*isVolatile=*/false);
newOp = rewriter.create<mlir::LLVM::MemcpyOp>(loc, llvmMemref, llvmValue,
boxSize, isVolatile);
} else {
newOp = rewriter.create<mlir::LLVM::StoreOp>(loc, llvmValue, llvmMemref);
newOp = rewriter.create<mlir::LLVM::StoreOp>(loc, llvmValue, llvmMemref,
/*alignment=*/0, isVolatile);
}
if (std::optional<mlir::ArrayAttr> optionalTag = store.getTbaa())
newOp.setTBAATags(*optionalTag);
Expand Down
57 changes: 43 additions & 14 deletions flang/lib/Optimizer/Dialect/FIRType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -649,12 +649,17 @@ mlir::Type changeElementType(mlir::Type type, mlir::Type newElementType,
.Case<fir::SequenceType>([&](fir::SequenceType seqTy) -> mlir::Type {
return fir::SequenceType::get(seqTy.getShape(), newElementType);
})
.Case<fir::PointerType, fir::HeapType, fir::ReferenceType,
fir::ClassType>([&](auto t) -> mlir::Type {
using FIRT = decltype(t);
return FIRT::get(
changeElementType(t.getEleTy(), newElementType, turnBoxIntoClass));
.Case<fir::ReferenceType>([&](fir::ReferenceType refTy) -> mlir::Type {
auto newEleTy = changeElementType(refTy.getEleTy(), newElementType,
turnBoxIntoClass);
return fir::ReferenceType::get(newEleTy, refTy.isVolatile());
})
.Case<fir::PointerType, fir::HeapType, fir::ClassType>(
[&](auto t) -> mlir::Type {
using FIRT = decltype(t);
return FIRT::get(changeElementType(t.getEleTy(), newElementType,
turnBoxIntoClass));
})
.Case<fir::BoxType>([&](fir::BoxType t) -> mlir::Type {
mlir::Type newInnerType =
changeElementType(t.getEleTy(), newElementType, false);
Expand Down Expand Up @@ -1057,18 +1062,38 @@ unsigned fir::RecordType::getFieldIndex(llvm::StringRef ident) {
// ReferenceType
//===----------------------------------------------------------------------===//

// `ref` `<` type `>`
// `ref` `<` type (`, volatile` $volatile^)? (`, async` $async^)? `>`
mlir::Type fir::ReferenceType::parse(mlir::AsmParser &parser) {
return parseTypeSingleton<fir::ReferenceType>(parser);
if (parser.parseLess())
return {};

mlir::Type eleTy;
if (parser.parseType(eleTy))
return {};

bool isVolatile = false;
if (!parser.parseOptionalComma()) {
if (parser.parseKeyword(getVolatileKeyword())) {
return {};
}
isVolatile = true;
}

if (parser.parseGreater())
return {};
return get(eleTy, isVolatile);
}

void fir::ReferenceType::print(mlir::AsmPrinter &printer) const {
printer << "<" << getEleTy() << '>';
printer << "<" << getEleTy();
if (isVolatile())
printer << ", " << getVolatileKeyword();
printer << '>';
}

llvm::LogicalResult fir::ReferenceType::verify(
llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
mlir::Type eleTy) {
llvm::function_ref<mlir::InFlightDiagnostic()> emitError, mlir::Type eleTy,
bool isVolatile) {
if (mlir::isa<ShapeType, ShapeShiftType, SliceType, FieldType, LenType,
ReferenceType, TypeDescType>(eleTy))
return emitError() << "cannot build a reference to type: " << eleTy << '\n';
Expand Down Expand Up @@ -1319,11 +1344,15 @@ changeTypeShape(mlir::Type type,
return fir::SequenceType::get(*newShape, seqTy.getEleTy());
return seqTy.getEleTy();
})
.Case<fir::PointerType, fir::HeapType, fir::ReferenceType, fir::BoxType,
fir::ClassType>([&](auto t) -> mlir::Type {
using FIRT = decltype(t);
return FIRT::get(changeTypeShape(t.getEleTy(), newShape));
.Case<fir::ReferenceType>([&](fir::ReferenceType rt) -> mlir::Type {
return fir::ReferenceType::get(changeTypeShape(rt.getEleTy(), newShape),
rt.isVolatile());
})
.Case<fir::PointerType, fir::HeapType, fir::BoxType, fir::ClassType>(
[&](auto t) -> mlir::Type {
using FIRT = decltype(t);
return FIRT::get(changeTypeShape(t.getEleTy(), newShape));
})
.Default([&](mlir::Type t) -> mlir::Type {
assert((fir::isa_trivial(t) || llvm::isa<fir::RecordType>(t) ||
llvm::isa<mlir::NoneType>(t)) &&
Expand Down
7 changes: 7 additions & 0 deletions flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,13 @@ void hlfir::DeclareOp::build(mlir::OpBuilder &builder,
auto nameAttr = builder.getStringAttr(uniq_name);
mlir::Type inputType = memref.getType();
bool hasExplicitLbs = hasExplicitLowerBounds(shape);
if (fortran_attrs && mlir::isa<fir::ReferenceType>(inputType) &&
bitEnumContainsAny(fortran_attrs.getFlags(),
fir::FortranVariableFlagsEnum::fortran_volatile)) {
auto refType = mlir::cast<fir::ReferenceType>(inputType);
inputType = fir::ReferenceType::get(refType.getEleTy(), true);
memref = builder.create<fir::ConvertOp>(memref.getLoc(), inputType, memref);
}
mlir::Type hlfirVariableType =
getHLFIRVariableType(inputType, hasExplicitLbs);
build(builder, result, {hlfirVariableType, inputType}, memref, shape,
Expand Down
7 changes: 5 additions & 2 deletions flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,9 @@ class DesignateOpConversion
firstElementIndices.push_back(indices[i]);
i = i + (isTriplet ? 3 : 1);
}
mlir::Type arrayCoorType = fir::ReferenceType::get(baseEleTy);
mlir::Type originalDesignateType = designate.getResult().getType();
const bool isVolatile = fir::isa_volatile_ref_type(originalDesignateType);
mlir::Type arrayCoorType = fir::ReferenceType::get(baseEleTy, isVolatile);
base = builder.create<fir::ArrayCoorOp>(
loc, arrayCoorType, base, shape,
/*slice=*/mlir::Value{}, firstElementIndices, firBaseTypeParameters);
Expand All @@ -441,6 +443,7 @@ class DesignateOpConversion
TODO(loc, "hlfir::designate load of pointer or allocatable");

mlir::Type designateResultType = designate.getResult().getType();
const bool isVolatile = fir::isa_volatile_ref_type(designateResultType);
llvm::SmallVector<mlir::Value> firBaseTypeParameters;
auto [base, shape] = hlfir::genVariableFirBaseShapeAndParams(
loc, builder, baseEntity, firBaseTypeParameters);
Expand All @@ -464,7 +467,7 @@ class DesignateOpConversion
mlir::Type componentType =
mlir::cast<fir::RecordType>(baseEleTy).getType(
designate.getComponent().value());
mlir::Type coorTy = fir::ReferenceType::get(componentType);
mlir::Type coorTy = fir::ReferenceType::get(componentType, isVolatile);
base = builder.create<fir::CoordinateOp>(loc, coorTy, base, fieldIndex);
if (mlir::isa<fir::BaseBoxType>(componentType)) {
auto variableInterface = mlir::cast<fir::FortranVariableOpInterface>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1126,7 +1126,8 @@ class ReductionMaskConversion : public mlir::OpRewritePattern<Op> {
builder.create<fir::StoreOp>(loc, flagSet, flagRef);
mlir::Type resultElemTy =
hlfir::getFortranElementType(resultArr.getType());
mlir::Type returnRefTy = builder.getRefType(resultElemTy);
mlir::Type returnRefTy = builder.getRefType(
resultElemTy, fir::isa_volatile_ref_type(flagRef.getType()));
mlir::IndexType idxTy = builder.getIndexType();

for (unsigned int i = 0; i < rank; ++i) {
Expand All @@ -1153,7 +1154,8 @@ class ReductionMaskConversion : public mlir::OpRewritePattern<Op> {
auto getAddrFn = [](fir::FirOpBuilder builder, mlir::Location loc,
const mlir::Type &resultElemType, mlir::Value resultArr,
mlir::Value index) {
mlir::Type resultRefTy = builder.getRefType(resultElemType);
mlir::Type resultRefTy = builder.getRefType(
resultElemType, fir::isa_volatile_ref_type(resultArr.getType()));
mlir::Value oneIdx =
builder.createIntegerConstant(loc, builder.getIndexType(), 1);
index = builder.create<mlir::arith::AddIOp>(loc, index, oneIdx);
Expand All @@ -1162,8 +1164,9 @@ class ReductionMaskConversion : public mlir::OpRewritePattern<Op> {
};

// Initialize the result
const bool isVolatile = fir::isa_volatile_ref_type(resultArr.getType());
mlir::Type resultElemTy = hlfir::getFortranElementType(resultArr.getType());
mlir::Type resultRefTy = builder.getRefType(resultElemTy);
mlir::Type resultRefTy = builder.getRefType(resultElemTy, isVolatile);
mlir::Value returnValue =
builder.createIntegerConstant(loc, resultElemTy, 0);
for (unsigned int i = 0; i < rank; ++i) {
Expand Down
18 changes: 18 additions & 0 deletions flang/test/Fir/volatile.fir
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// RUN: fir-opt --fir-to-llvm-ir="target=x86_64-unknown-linux-gnu" %s -o - | FileCheck %s
// CHECK: llvm.store volatile %{{.+}}, %{{.+}} : i32, !llvm.ptr
// CHECK: %{{.+}} = llvm.load volatile %{{.+}} : !llvm.ptr -> i32
func.func @foo() {
%true = arith.constant true
%false = arith.constant false
%0 = fir.alloca !fir.logical<4> {bindc_name = "a", uniq_name = "_QFEa"}
%1 = fir.convert %0 : (!fir.ref<!fir.logical<4>>) -> !fir.ref<!fir.logical<4>, volatile>
%2 = fir.alloca !fir.logical<4> {bindc_name = "b", uniq_name = "_QFEb"}
%3 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFEi"}
%4 = fir.convert %false : (i1) -> !fir.logical<4>
fir.store %4 to %1 : !fir.ref<!fir.logical<4>, volatile>
%5 = fir.load %1 : !fir.ref<!fir.logical<4>, volatile>
fir.store %5 to %2 : !fir.ref<!fir.logical<4>>
%6 = fir.convert %true : (i1) -> !fir.logical<4>
fir.store %6 to %1 : !fir.ref<!fir.logical<4>, volatile>
return
}
11 changes: 11 additions & 0 deletions flang/test/Integration/volatile.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
! RUN: bbc %s -o - | FileCheck %s
logical, volatile :: a
logical :: b
integer :: i
a = .false.
b = a
a = .true.
end

! CHECK: %{{.+}} = fir.load %{{.+}} : !fir.ref<!fir.logical<4>, volatile>
! CHECK: hlfir.assign %{{.+}} to %{{.+}} : !fir.logical<4>, !fir.ref<!fir.logical<4>, volatile>
Loading