Skip to content

Commit

Permalink
[flang] AArch64 ABI for BIND(C) VALUE parameters (#118305)
Browse files Browse the repository at this point in the history
This patch adds handling for derived type VALUE parameters in BIND(C)
functions for AArch64.
  • Loading branch information
DavidTruby authored Dec 18, 2024
1 parent 3666de9 commit 44aa476
Show file tree
Hide file tree
Showing 2 changed files with 193 additions and 26 deletions.
146 changes: 120 additions & 26 deletions flang/lib/Optimizer/CodeGen/Target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,8 @@ struct TargetX86_64Win : public GenericTarget<TargetX86_64Win> {
//===----------------------------------------------------------------------===//

namespace {
// AArch64 procedure call standard:
// https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#parameter-passing
struct TargetAArch64 : public GenericTarget<TargetAArch64> {
using GenericTarget::GenericTarget;

Expand Down Expand Up @@ -826,7 +828,7 @@ struct TargetAArch64 : public GenericTarget<TargetAArch64> {
return marshal;
}

// Flatten a RecordType::TypeList containing more record types or array types
static std::optional<std::vector<mlir::Type>>
flattenTypeList(const RecordType::TypeList &types) {
std::vector<mlir::Type> flatTypes;
Expand Down Expand Up @@ -870,52 +872,144 @@ struct TargetAArch64 : public GenericTarget<TargetAArch64> {

// Determine if the type is a Homogenous Floating-point Aggregate (HFA). An
// HFA is a record type with up to 4 floating-point members of the same type.
static bool isHFA(fir::RecordType ty) {
static std::optional<int> usedRegsForHFA(fir::RecordType ty) {
RecordType::TypeList types = ty.getTypeList();
if (types.empty() || types.size() > 4)
return false;
return std::nullopt;

std::optional<std::vector<mlir::Type>> flatTypes = flattenTypeList(types);
if (!flatTypes || flatTypes->size() > 4) {
return false;
return std::nullopt;
}

if (!isa_real(flatTypes->front())) {
return false;
return std::nullopt;
}

return llvm::all_equal(*flatTypes);
return llvm::all_equal(*flatTypes) ? std::optional<int>{flatTypes->size()}
: std::nullopt;
}

// AArch64 procedure call ABI:
// https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#parameter-passing
CodeGenSpecifics::Marshalling
structReturnType(mlir::Location loc, fir::RecordType ty) const override {
CodeGenSpecifics::Marshalling marshal;
// Register demand of one argument: `n` is the number of registers it
// occupies and `isSimd` selects the register class they are taken from
// (SIMD when true, integer otherwise). A default-constructed value stands
// for "no registers used".
struct NRegs {
  int n = 0;
  bool isSimd = false;
};

if (isHFA(ty)) {
// Just return the existing record type
marshal.emplace_back(ty, AT{});
return marshal;
// Classify a record type for parameter passing: HFAs take one SIMD register
// per flattened member, other records of at most 16 bytes take one integer
// register per started 8-byte chunk, and anything larger takes no registers
// (n == 0), which callers interpret as "pass on the stack".
NRegs usedRegsForRecordType(mlir::Location loc, fir::RecordType type) const {
  const std::optional<int> hfaMembers = usedRegsForHFA(type);
  if (hfaMembers)
    return NRegs{*hfaMembers, true};

  const auto sizeAndAlign = fir::getTypeSizeAndAlignmentOrCrash(
      loc, type, getDataLayout(), kindMap);
  const auto byteSize = sizeAndAlign.first;

  // Pass on the stack, i.e. no registers used
  if (byteSize > 16)
    return NRegs{};

  // Round the size up to a whole number of 8-byte registers
  return NRegs{static_cast<int>((byteSize + 7) / 8), false};
}

// Compute how many registers, and of which class, a value of `type`
// occupies when it is passed directly. Used to account for the registers
// already consumed by earlier arguments.
//
// Types without a dedicated case are reported through TODO rather than
// falling off the end of the TypeSwitch.
NRegs usedRegsForType(mlir::Location loc, mlir::Type type) const {
  return llvm::TypeSwitch<mlir::Type, NRegs>(type)
      .Case<mlir::IntegerType>([&](auto intTy) {
        // 128-bit integers are counted as two integer registers
        return intTy.getWidth() == 128 ? NRegs{2, false} : NRegs{1, false};
      })
      .Case<mlir::FloatType>([&](auto) { return NRegs{1, true}; })
      .Case<mlir::ComplexType>([&](auto) { return NRegs{2, true}; })
      .Case<fir::LogicalType>([&](auto) { return NRegs{1, false}; })
      .Case<fir::CharacterType>([&](auto) { return NRegs{1, false}; })
      // A reference is passed as an address, counted as one integer register.
      // NOTE(review): confirm this is also the desired accounting for
      // references that carry attributes other than byval.
      .Case<fir::ReferenceType>([&](auto) { return NRegs{1, false}; })
      .Case<fir::SequenceType>([&](auto ty) {
        assert(ty.getShape().size() == 1 &&
               "invalid array dimensions in BIND(C)");
        // Element demand multiplied by the (single) extent
        NRegs nregs = usedRegsForType(loc, ty.getEleTy());
        nregs.n *= ty.getShape()[0];
        return nregs;
      })
      .Case<fir::RecordType>(
          [&](auto ty) { return usedRegsForRecordType(loc, ty); })
      .Case<fir::VectorType>([&](auto) {
        TODO(loc, "passing vector argument to C by value is not supported");
        return NRegs{};
      })
      .Default([&](mlir::Type) {
        TODO(loc, "unsupported argument type when computing AArch64 register "
                  "usage for BIND(C) VALUE derived type argument");
        return NRegs{};
      });
}

// Decide whether the record `type` can still be passed in registers given
// the arguments that precede it. Starts from 8 integer and 8 SIMD registers,
// subtracts what each previous register-passed argument consumes, and then
// compares the demand of `type` against what is left in its register class.
bool hasEnoughRegisters(mlir::Location loc, fir::RecordType type,
                        const Marshalling &previousArguments) const {
  int availIntRegisters = 8;
  int availSIMDRegisters = 8;

  // Check previous arguments to see how many registers are used already
  for (auto [prevType, prevAttr] : previousArguments) {
    // Stop early only once BOTH classes are exhausted: while one class still
    // has registers left, later arguments of that class must be counted.
    if (availIntRegisters <= 0 && availSIMDRegisters <= 0)
      break;

    if (prevAttr.isByVal())
      continue; // Previous argument passed on the stack

    NRegs used = usedRegsForType(loc, prevType);
    if (used.isSimd)
      availSIMDRegisters -= used.n;
    else
      availIntRegisters -= used.n;
  }

  NRegs nregs = usedRegsForRecordType(loc, type);

  if (nregs.isSimd)
    return nregs.n <= availSIMDRegisters;

  return nregs.n <= availIntRegisters;
}

// Build the marshalling for a value that is passed through memory: a single
// reference to `ty`, tagged byval for arguments or sret for results, with an
// alignment of at least 8 bytes.
CodeGenSpecifics::Marshalling
passOnTheStack(mlir::Location loc, mlir::Type ty, bool isResult) const {
  const unsigned short naturalAlign =
      fir::getTypeSizeAndAlignmentOrCrash(loc, ty, getDataLayout(), kindMap)
          .second;
  // The stack is always 8 byte aligned
  const unsigned short stackAlign =
      std::max<unsigned short>(naturalAlign, 8);

  const bool passByVal = !isResult;
  const bool passSRet = isResult;

  CodeGenSpecifics::Marshalling result;
  result.emplace_back(fir::ReferenceType::get(ty),
                      AT{stackAlign, /*byval=*/passByVal, /*sret=*/passSRet});
  return result;
}

// return in registers if size <= 16 bytes
if (size <= 16) {
std::size_t dwordSize = (size + 7) / 8;
auto newTy = fir::SequenceType::get(
dwordSize, mlir::IntegerType::get(ty.getContext(), 64));
marshal.emplace_back(newTy, AT{});
return marshal;
// Marshalling shared by record arguments and record results.
//  - Records that need no registers are passed through memory.
//  - HFAs keep their record type.
//  - All other register-passed records are rewritten to an array of
//    `nregs.n` 64-bit integers.
CodeGenSpecifics::Marshalling
structType(mlir::Location loc, fir::RecordType type, bool isResult) const {
  NRegs nregs = usedRegsForRecordType(loc, type);

  // If the type needs no registers it must need to be passed on the stack
  if (nregs.n == 0)
    return passOnTheStack(loc, type, isResult);

  mlir::Type pcsType;
  if (nregs.isSimd) {
    pcsType = type;
  } else {
    pcsType = fir::SequenceType::get(
        nregs.n, mlir::IntegerType::get(type.getContext(), 64));
  }

  CodeGenSpecifics::Marshalling marshal;
  marshal.emplace_back(pcsType, AT{});
  return marshal;
}

// Marshalling for a record passed as an argument. When the registers left
// over after `previousArguments` cannot hold it, it is passed through memory;
// otherwise the common register classification in structType applies.
CodeGenSpecifics::Marshalling
structArgumentType(mlir::Location loc, fir::RecordType ty,
                   const Marshalling &previousArguments) const override {
  const bool fitsInRegisters = hasEnoughRegisters(loc, ty, previousArguments);
  if (fitsInRegisters)
    return structType(loc, ty, /*isResult=*/false);

  return passOnTheStack(loc, ty, /*isResult=*/false);
}

// Marshalling for a record returned from a function: the same classification
// as for arguments, with the result flag set.
CodeGenSpecifics::Marshalling
structReturnType(mlir::Location loc, fir::RecordType ty) const override {
  constexpr bool isResult = true;
  return structType(loc, ty, isResult);
}
};
} // namespace

Expand Down
73 changes: 73 additions & 0 deletions flang/test/Fir/struct-passing-aarch64-byval.fir
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Test AArch64 ABI rewrite of struct passed by value (BIND(C), VALUE derived types).
// RUN: fir-opt --target-rewrite="target=aarch64-unknown-linux-gnu" %s | FileCheck %s

// Records of at most 16 bytes that are not homogeneous floating-point
// aggregates are expected to be rewritten to an array of 64-bit integers.
// CHECK-LABEL: func.func private @small_i32(!fir.array<2xi64>)
func.func private @small_i32(!fir.type<small_i32{i:i32,j:i32,k:i32}>)
// CHECK-LABEL: func.func private @small_i64(!fir.array<2xi64>)
func.func private @small_i64(!fir.type<small_i64{i:i64,j:i64}>)
// CHECK-LABEL: func.func private @small_mixed(!fir.array<2xi64>)
func.func private @small_mixed(!fir.type<small_mixed{i:i64,j:f32,k:i32}>)
// CHECK-LABEL: func.func private @small_non_hfa(!fir.array<2xi64>)
func.func private @small_non_hfa(!fir.type<small_non_hfa{i:f64,j:f32,k:f16}>)

// Homogeneous floating-point aggregates (up to four members of one
// floating-point type) are expected to keep their record type.
// CHECK-LABEL: func.func private @hfa_f16(!fir.type<hfa_f16{i:f16,j:f16}>)
func.func private @hfa_f16(!fir.type<hfa_f16{i:f16,j:f16}>)
// CHECK-LABEL: func.func private @hfa_bf16(!fir.type<hfa_bf16{i:bf16,j:bf16,k:bf16,l:bf16}>)
func.func private @hfa_bf16(!fir.type<hfa_bf16{i:bf16,j:bf16,k:bf16,l:bf16}>)
// CHECK-LABEL: func.func private @hfa_f32(!fir.type<hfa_f32{i:f32,j:f32}>)
func.func private @hfa_f32(!fir.type<hfa_f32{i:f32,j:f32}>)
// CHECK-LABEL: func.func private @hfa_f64(!fir.type<hfa_f64{i:f64,j:f64,k:f64}>)
func.func private @hfa_f64(!fir.type<hfa_f64{i:f64,j:f64,k:f64}>)
// CHECK-LABEL: func.func private @hfa_f128(!fir.type<hfa_f128{i:f128,j:f128,k:f128,l:f128}>)
func.func private @hfa_f128(!fir.type<hfa_f128{i:f128,j:f128,k:f128,l:f128}>)

// Several record arguments in one signature while registers of the relevant
// class are still available for each of them.
// CHECK-LABEL: func.func private @multi_small_integer(!fir.array<2xi64>, !fir.array<2xi64>)
func.func private @multi_small_integer(!fir.type<small_i32{i:i32,j:i32,k:i32}>, !fir.type<small_i64{i:i64,j:i64}>)
// CHECK-LABEL: func.func private @multi_hfas(!fir.type<hfa_f16{i:f16,j:f16}>, !fir.type<hfa_f128{i:f128,j:f128,k:f128,l:f128}>)
func.func private @multi_hfas(!fir.type<hfa_f16{i:f16,j:f16}>, !fir.type<hfa_f128{i:f128,j:f128,k:f128,l:f128}>)
// CHECK-LABEL: func.func private @multi_mixed(!fir.type<hfa_f64{i:f64,j:f64,k:f64}>, !fir.array<2xi64>, !fir.type<hfa_f32{i:f32,j:f32}>, !fir.array<2xi64>)
func.func private @multi_mixed(!fir.type<hfa_f64{i:f64,j:f64,k:f64}>,!fir.type<small_non_hfa{i:f64,j:f32,k:f16}>,!fir.type<hfa_f32{i:f32,j:f32}>,!fir.type<small_i64{i:i64,j:i64}>)

// Signatures that use up the registers exactly: four two-register integer
// records and/or two four-register floating-point records.
// CHECK-LABEL: func.func private @int_max(!fir.array<2xi64>,
// CHECK-SAME: !fir.array<2xi64>,
// CHECK-SAME: !fir.array<2xi64>,
// CHECK-SAME: !fir.array<2xi64>)
func.func private @int_max(!fir.type<int_max{i:i64,j:i64}>,
!fir.type<int_max{i:i64,j:i64}>,
!fir.type<int_max{i:i64,j:i64}>,
!fir.type<int_max{i:i64,j:i64}>)
// CHECK-LABEL: func.func private @hfa_max(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>, !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>)
func.func private @hfa_max(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>, !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>)
// CHECK-LABEL: func.func private @max(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
// CHECK-SAME: !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
// CHECK-SAME: !fir.array<2xi64>,
// CHECK-SAME: !fir.array<2xi64>,
// CHECK-SAME: !fir.array<2xi64>,
// CHECK-SAME: !fir.array<2xi64>)
func.func private @max(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
!fir.type<int_max{i:i64,j:i64}>,
!fir.type<int_max{i:i64,j:i64}>,
!fir.type<int_max{i:i64,j:i64}>,
!fir.type<int_max{i:i64,j:i64}>)


// One record more than fits in registers: the last argument is expected to
// become a reference carrying the llvm.byval attribute.
// CHECK-LABEL: func.func private @too_many_int(!fir.array<2xi64>,
// CHECK-SAME: !fir.array<2xi64>,
// CHECK-SAME: !fir.array<2xi64>,
// CHECK-SAME: !fir.array<2xi64>,
// CHECK-SAME: !fir.ref<!fir.type<int_max{i:i64,j:i64}>> {{{.*}}, llvm.byval = !fir.type<int_max{i:i64,j:i64}>})
func.func private @too_many_int(!fir.type<int_max{i:i64,j:i64}>,
!fir.type<int_max{i:i64,j:i64}>,
!fir.type<int_max{i:i64,j:i64}>,
!fir.type<int_max{i:i64,j:i64}>,
!fir.type<int_max{i:i64,j:i64}>)
// CHECK-LABEL: func.func private @too_many_hfa(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
// CHECK-SAME: !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
// CHECK-SAME: !fir.ref<!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>> {{{.*}}, llvm.byval = !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>})
func.func private @too_many_hfa(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>)

// A record larger than 16 bytes (five i32 members) is expected to become a
// byval reference even when it is the only argument.
// CHECK-LABEL: func.func private @too_big(!fir.ref<!fir.type<too_big{i:!fir.array<5xi32>}>> {{{.*}}, llvm.byval = !fir.type<too_big{i:!fir.array<5xi32>}>})
func.func private @too_big(!fir.type<too_big{i:!fir.array<5xi32>}>)

0 comments on commit 44aa476

Please sign in to comment.