Skip to content

[LLVM][CodeGen] Teach SelectionDAG how to expand FREM to a vector math call. #83859

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 128 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
Expand Down Expand Up @@ -147,6 +149,14 @@ class VectorLegalizer {
void ExpandStrictFPOp(SDNode *Node, SmallVectorImpl<SDValue> &Results);
void ExpandREM(SDNode *Node, SmallVectorImpl<SDValue> &Results);

bool tryExpandVecMathCall(SDNode *Node, RTLIB::Libcall LC,
SmallVectorImpl<SDValue> &Results);
bool tryExpandVecMathCall(SDNode *Node, RTLIB::Libcall Call_F32,
RTLIB::Libcall Call_F64, RTLIB::Libcall Call_F80,
RTLIB::Libcall Call_F128,
RTLIB::Libcall Call_PPCF128,
SmallVectorImpl<SDValue> &Results);

void UnrollStrictFPOp(SDNode *Node, SmallVectorImpl<SDValue> &Results);

/// Implements vector promotion.
Expand Down Expand Up @@ -1139,6 +1149,13 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
case ISD::VP_MERGE:
Results.push_back(ExpandVP_MERGE(Node));
return;
case ISD::FREM:
if (tryExpandVecMathCall(Node, RTLIB::REM_F32, RTLIB::REM_F64,
RTLIB::REM_F80, RTLIB::REM_F128,
RTLIB::REM_PPCF128, Results))
return;

break;
}

SDValue Unrolled = DAG.UnrollVectorOp(Node);
Expand Down Expand Up @@ -1842,6 +1859,117 @@ void VectorLegalizer::ExpandREM(SDNode *Node,
Results.push_back(Result);
}

// Try to expand libm nodes into vector math routine calls. Callers provide the
// LibFunc equivalent of the passed in Node, which is used to lookup mappings
// within TargetLibraryInfo. The only mappings considered are those where the
// result and all operands are the same vector type. While predicated nodes are
// not supported, we will emit calls to masked routines by passing in an all
// true mask.
//
// Returns true (and appends the call's result to \p Results) only when a
// suitable vector variant was found and lowered; otherwise returns false and
// the caller falls back to unrolling.
bool VectorLegalizer::tryExpandVecMathCall(SDNode *Node, RTLIB::Libcall LC,
                                           SmallVectorImpl<SDValue> &Results) {
  // Chain must be propagated but currently strict fp operations are down
  // converted to their non-strict counterpart.
  assert(!Node->isStrictFPOpcode() && "Unexpected strict fp operation!");

  // Without a scalar libcall name there is nothing to look up in
  // TargetLibraryInfo's vector mappings.
  const char *LCName = TLI.getLibcallName(LC);
  if (!LCName)
    return false;
  LLVM_DEBUG(dbgs() << "Looking for vector variant of " << LCName << "\n");

  EVT VT = Node->getValueType(0);
  ElementCount VL = VT.getVectorElementCount();

  // Lookup a vector function equivalent to the specified libcall. Prefer
  // unmasked variants but we will generate a mask if need be.
  const TargetLibraryInfo &TLibInfo = DAG.getLibInfo();
  const VecDesc *VD = TLibInfo.getVectorMappingInfo(LCName, VL, false);
  if (!VD)
    VD = TLibInfo.getVectorMappingInfo(LCName, VL, /*Masked=*/true);
  if (!VD)
    return false;

  LLVMContext *Ctx = DAG.getContext();
  Type *Ty = VT.getTypeForEVT(*Ctx);
  Type *ScalarTy = Ty->getScalarType();

  // Construct a scalar function type based on Node's operands, which is what
  // the VFABI demangler below expects (it maps a scalar signature to the
  // vector variant's shape).
  SmallVector<Type *, 8> ArgTys;
  for (unsigned i = 0; i < Node->getNumOperands(); ++i) {
    assert(Node->getOperand(i).getValueType() == VT &&
           "Expected matching vector types!");
    ArgTys.push_back(ScalarTy);
  }
  FunctionType *ScalarFTy = FunctionType::get(ScalarTy, ArgTys, false);

  // Generate call information for the vector function.
  const std::string MangledName = VD->getVectorFunctionABIVariantString();
  auto OptVFInfo = VFABI::tryDemangleForVFABI(MangledName, ScalarFTy);
  if (!OptVFInfo)
    return false;

  LLVM_DEBUG(dbgs() << "Found vector variant " << VD->getVectorFnName()
                    << "\n");

  // Sanity check just in case OptVFInfo has unexpected parameters.
  if (OptVFInfo->Shape.Parameters.size() !=
      Node->getNumOperands() + VD->isMasked())
    return false;

  // Collect vector call operands.

  SDLoc DL(Node);
  TargetLowering::ArgListTy Args;
  TargetLowering::ArgListEntry Entry;
  Entry.IsSExt = false;
  Entry.IsZExt = false;

  unsigned OpNum = 0;
  for (auto &VFParam : OptVFInfo->Shape.Parameters) {
    if (VFParam.ParamKind == VFParamKind::GlobalPredicate) {
      // Masked routine: pass an all-true predicate so the call behaves like
      // the unmasked operation.
      EVT MaskVT = TLI.getSetCCResultType(DAG.getDataLayout(), *Ctx, VT);
      Entry.Node = DAG.getBoolConstant(true, DL, MaskVT, VT);
      Entry.Ty = MaskVT.getTypeForEVT(*Ctx);
      Args.push_back(Entry);
      continue;
    }

    // Only vector operands are supported.
    if (VFParam.ParamKind != VFParamKind::Vector)
      return false;

    Entry.Node = Node->getOperand(OpNum++);
    Entry.Ty = Ty;
    Args.push_back(Entry);
  }

  // Emit a call to the vector function.
  SDValue Callee = DAG.getExternalSymbol(VD->getVectorFnName().data(),
                                         TLI.getPointerTy(DAG.getDataLayout()));
  TargetLowering::CallLoweringInfo CLI(DAG);
  CLI.setDebugLoc(DL)
      .setChain(DAG.getEntryNode())
      .setLibCallee(CallingConv::C, Ty, Callee, std::move(Args));

  std::pair<SDValue, SDValue> CallResult = TLI.LowerCallTo(CLI);
  Results.push_back(CallResult.first);
  return true;
}

/// Try to expand the node to a vector libcall based on the result type.
bool VectorLegalizer::tryExpandVecMathCall(
SDNode *Node, RTLIB::Libcall Call_F32, RTLIB::Libcall Call_F64,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason the Libcall variants are parameters to this function, instead of being hardcoded? (Looking at other functions like DAGTypeLegalizer::SoftenFloatRes_FREM I see the list being hardcoded).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm expecting this function to be used by other ISD nodes (e.g. ISD::FSIN, ISD::FSIN etc) so I followed the idiom used by ExpandFPLibCall.

RTLIB::Libcall Call_F80, RTLIB::Libcall Call_F128,
RTLIB::Libcall Call_PPCF128, SmallVectorImpl<SDValue> &Results) {
RTLIB::Libcall LC = RTLIB::getFPLibCall(
Node->getValueType(0).getVectorElementType(), Call_F32, Call_F64,
Call_F80, Call_F128, Call_PPCF128);

if (LC == RTLIB::UNKNOWN_LIBCALL)
return false;

return tryExpandVecMathCall(Node, LC, Results);
}

void VectorLegalizer::UnrollStrictFPOp(SDNode *Node,
SmallVectorImpl<SDValue> &Results) {
EVT VT = Node->getValueType(0);
Expand Down
6 changes: 5 additions & 1 deletion llvm/lib/CodeGen/TargetPassConfig.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,10 @@ static cl::opt<bool> MISchedPostRA(
static cl::opt<bool> EarlyLiveIntervals("early-live-intervals", cl::Hidden,
cl::desc("Run live interval analysis earlier in the pipeline"));

// Allows skipping the IR-level ReplaceWithVeclib pass so that codegen tests
// can exercise SelectionDAG's own vector math call expansion directly.
static cl::opt<bool> DisableReplaceWithVecLib(
    "disable-replace-with-vec-lib", cl::Hidden,
    cl::desc("Disable replace with vector math call pass"));

/// Option names for limiting the codegen pipeline.
/// Those are used in error reporting and we didn't want
/// to duplicate their names all over the place.
Expand Down Expand Up @@ -856,7 +860,7 @@ void TargetPassConfig::addIRPasses() {
if (getOptLevel() != CodeGenOptLevel::None && !DisableConstantHoisting)
addPass(createConstantHoistingPass());

if (getOptLevel() != CodeGenOptLevel::None)
if (getOptLevel() != CodeGenOptLevel::None && !DisableReplaceWithVecLib)
addPass(createReplaceWithVeclibLegacyPass());

if (getOptLevel() != CodeGenOptLevel::None && !DisablePartialLibcallInlining)
Expand Down
116 changes: 116 additions & 0 deletions llvm/test/CodeGen/AArch64/fp-veclib-expansion.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
; RUN: llc --disable-replace-with-vec-lib --vector-library=ArmPL < %s -o - | FileCheck --check-prefix=ARMPL %s
; RUN: llc --disable-replace-with-vec-lib --vector-library=sleefgnuabi < %s -o - | FileCheck --check-prefix=SLEEF %s

target triple = "aarch64-unknown-linux-gnu"

define <2 x double> @frem_v2f64(<2 x double> %unused, <2 x double> %a, <2 x double> %b) #0 {
; ARMPL-LABEL: frem_v2f64:
; ARMPL: // %bb.0:
; ARMPL-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
; ARMPL-NEXT: .cfi_def_cfa_offset 16
; ARMPL-NEXT: .cfi_offset w30, -16
; ARMPL-NEXT: mov v0.16b, v1.16b
; ARMPL-NEXT: mov v1.16b, v2.16b
; ARMPL-NEXT: bl armpl_vfmodq_f64
; ARMPL-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
; ARMPL-NEXT: ret
;
; SLEEF-LABEL: frem_v2f64:
; SLEEF: // %bb.0:
; SLEEF-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
; SLEEF-NEXT: .cfi_def_cfa_offset 16
; SLEEF-NEXT: .cfi_offset w30, -16
; SLEEF-NEXT: mov v0.16b, v1.16b
; SLEEF-NEXT: mov v1.16b, v2.16b
; SLEEF-NEXT: bl _ZGVnN2vv_fmod
; SLEEF-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
; SLEEF-NEXT: ret
  ; Fixed-width frem should lower to a call into the configured vector library.
  %res = frem <2 x double> %a, %b
  ret <2 x double> %res
}

define <4 x float> @frem_strict_v4f32(<4 x float> %unused, <4 x float> %a, <4 x float> %b) #1 {
; ARMPL-LABEL: frem_strict_v4f32:
; ARMPL: // %bb.0:
; ARMPL-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
; ARMPL-NEXT: .cfi_def_cfa_offset 16
; ARMPL-NEXT: .cfi_offset w30, -16
; ARMPL-NEXT: mov v0.16b, v1.16b
; ARMPL-NEXT: mov v1.16b, v2.16b
; ARMPL-NEXT: bl armpl_vfmodq_f32
; ARMPL-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
; ARMPL-NEXT: ret
;
; SLEEF-LABEL: frem_strict_v4f32:
; SLEEF: // %bb.0:
; SLEEF-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
; SLEEF-NEXT: .cfi_def_cfa_offset 16
; SLEEF-NEXT: .cfi_offset w30, -16
; SLEEF-NEXT: mov v0.16b, v1.16b
; SLEEF-NEXT: mov v1.16b, v2.16b
; SLEEF-NEXT: bl _ZGVnN4vv_fmodf
; SLEEF-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
; SLEEF-NEXT: ret
  ; Same as frem_v2f64 but f32 and inside a strictfp function (#1); the
  ; expansion still emits the non-strict vector call.
  %res = frem <4 x float> %a, %b
  ret <4 x float> %res
}

define <vscale x 4 x float> @frem_nxv4f32(<vscale x 4 x float> %unused, <vscale x 4 x float> %a, <vscale x 4 x float> %b) #0 {
; ARMPL-LABEL: frem_nxv4f32:
; ARMPL: // %bb.0:
; ARMPL-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
; ARMPL-NEXT: .cfi_def_cfa_offset 16
; ARMPL-NEXT: .cfi_offset w30, -16
; ARMPL-NEXT: ptrue p0.s
; ARMPL-NEXT: mov z0.d, z1.d
; ARMPL-NEXT: mov z1.d, z2.d
; ARMPL-NEXT: bl armpl_svfmod_f32_x
; ARMPL-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
; ARMPL-NEXT: ret
;
; SLEEF-LABEL: frem_nxv4f32:
; SLEEF: // %bb.0:
; SLEEF-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
; SLEEF-NEXT: .cfi_def_cfa_offset 16
; SLEEF-NEXT: .cfi_offset w30, -16
; SLEEF-NEXT: ptrue p0.s
; SLEEF-NEXT: mov z0.d, z1.d
; SLEEF-NEXT: mov z1.d, z2.d
; SLEEF-NEXT: bl _ZGVsMxvv_fmodf
; SLEEF-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
; SLEEF-NEXT: ret
  ; Scalable-vector frem: only masked routines exist, so an all-true
  ; predicate (ptrue) is materialised for the call.
  %res = frem <vscale x 4 x float> %a, %b
  ret <vscale x 4 x float> %res
}

define <vscale x 2 x double> @frem_strict_nxv2f64(<vscale x 2 x double> %unused, <vscale x 2 x double> %a, <vscale x 2 x double> %b) #1 {
; ARMPL-LABEL: frem_strict_nxv2f64:
; ARMPL: // %bb.0:
; ARMPL-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
; ARMPL-NEXT: .cfi_def_cfa_offset 16
; ARMPL-NEXT: .cfi_offset w30, -16
; ARMPL-NEXT: ptrue p0.d
; ARMPL-NEXT: mov z0.d, z1.d
; ARMPL-NEXT: mov z1.d, z2.d
; ARMPL-NEXT: bl armpl_svfmod_f64_x
; ARMPL-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
; ARMPL-NEXT: ret
;
; SLEEF-LABEL: frem_strict_nxv2f64:
; SLEEF: // %bb.0:
; SLEEF-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
; SLEEF-NEXT: .cfi_def_cfa_offset 16
; SLEEF-NEXT: .cfi_offset w30, -16
; SLEEF-NEXT: ptrue p0.d
; SLEEF-NEXT: mov z0.d, z1.d
; SLEEF-NEXT: mov z1.d, z2.d
; SLEEF-NEXT: bl _ZGVsMxvv_fmod
; SLEEF-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
; SLEEF-NEXT: ret
  ; Scalable-vector frem in a strictfp function (#1): still lowers to a
  ; masked vector call with an all-true predicate.
  %res = frem <vscale x 2 x double> %a, %b
  ret <vscale x 2 x double> %res
}

attributes #0 = { "target-features"="+sve" }
attributes #1 = { "target-features"="+sve" strictfp }