-
Notifications
You must be signed in to change notification settings - Fork 13.7k
[LLVM][CodeGen] Teach SelectionDAG how to expand FREM to a vector math call. #83859
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,6 +28,8 @@ | |
|
||
#include "llvm/ADT/DenseMap.h" | ||
#include "llvm/ADT/SmallVector.h" | ||
#include "llvm/Analysis/TargetLibraryInfo.h" | ||
#include "llvm/Analysis/VectorUtils.h" | ||
#include "llvm/CodeGen/ISDOpcodes.h" | ||
#include "llvm/CodeGen/SelectionDAG.h" | ||
#include "llvm/CodeGen/SelectionDAGNodes.h" | ||
|
@@ -147,6 +149,14 @@ class VectorLegalizer { | |
void ExpandStrictFPOp(SDNode *Node, SmallVectorImpl<SDValue> &Results); | ||
void ExpandREM(SDNode *Node, SmallVectorImpl<SDValue> &Results); | ||
|
||
bool tryExpandVecMathCall(SDNode *Node, RTLIB::Libcall LC, | ||
SmallVectorImpl<SDValue> &Results); | ||
bool tryExpandVecMathCall(SDNode *Node, RTLIB::Libcall Call_F32, | ||
RTLIB::Libcall Call_F64, RTLIB::Libcall Call_F80, | ||
RTLIB::Libcall Call_F128, | ||
RTLIB::Libcall Call_PPCF128, | ||
SmallVectorImpl<SDValue> &Results); | ||
|
||
void UnrollStrictFPOp(SDNode *Node, SmallVectorImpl<SDValue> &Results); | ||
|
||
/// Implements vector promotion. | ||
|
@@ -1139,6 +1149,13 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) { | |
case ISD::VP_MERGE: | ||
Results.push_back(ExpandVP_MERGE(Node)); | ||
return; | ||
case ISD::FREM: | ||
if (tryExpandVecMathCall(Node, RTLIB::REM_F32, RTLIB::REM_F64, | ||
RTLIB::REM_F80, RTLIB::REM_F128, | ||
RTLIB::REM_PPCF128, Results)) | ||
return; | ||
|
||
break; | ||
} | ||
|
||
SDValue Unrolled = DAG.UnrollVectorOp(Node); | ||
|
@@ -1842,6 +1859,117 @@ void VectorLegalizer::ExpandREM(SDNode *Node, | |
Results.push_back(Result); | ||
} | ||
|
||
// Attempt to replace a libm-style node with a call into a vector math
// library. The caller supplies the scalar libcall equivalent of Node, which
// is used to look up a mapping in TargetLibraryInfo. Only mappings where the
// result and every operand share Node's vector type are considered.
// Predicated nodes are not supported, but masked library routines are still
// usable: an all-true mask is synthesised for them.
bool VectorLegalizer::tryExpandVecMathCall(SDNode *Node, RTLIB::Libcall LC,
                                           SmallVectorImpl<SDValue> &Results) {
  // Strict FP nodes carry a chain that would need propagating, but they are
  // expected to have been converted to their non-strict forms before now.
  assert(!Node->isStrictFPOpcode() && "Unexpected strict fp operation!");

  const char *ScalarName = TLI.getLibcallName(LC);
  if (!ScalarName)
    return false;
  LLVM_DEBUG(dbgs() << "Looking for vector variant of " << ScalarName << "\n");

  EVT VecVT = Node->getValueType(0);
  ElementCount EC = VecVT.getVectorElementCount();

  // Prefer an unmasked mapping, but accept a masked one since we are able to
  // manufacture the predicate ourselves.
  const TargetLibraryInfo &LibInfo = DAG.getLibInfo();
  const VecDesc *Mapping = LibInfo.getVectorMappingInfo(ScalarName, EC, false);
  if (!Mapping)
    Mapping = LibInfo.getVectorMappingInfo(ScalarName, EC, /*Masked=*/true);
  if (!Mapping)
    return false;

  LLVMContext *Context = DAG.getContext();
  Type *VecTy = VecVT.getTypeForEVT(*Context);
  Type *EltTy = VecTy->getScalarType();

  // Model the scalar routine's signature from Node's operands.
  SmallVector<Type *, 8> ScalarArgTys;
  for (const SDValue &Op : Node->op_values()) {
    assert(Op.getValueType() == VecVT && "Expected matching vector types!");
    ScalarArgTys.push_back(EltTy);
  }
  FunctionType *ScalarFnTy = FunctionType::get(EltTy, ScalarArgTys, false);

  // Recover the vector function's shape from its ABI-mangled variant name.
  const std::string ABIName = Mapping->getVectorFunctionABIVariantString();
  auto VFInfo = VFABI::tryDemangleForVFABI(ABIName, ScalarFnTy);
  if (!VFInfo)
    return false;

  LLVM_DEBUG(dbgs() << "Found vector variant " << Mapping->getVectorFnName()
                    << "\n");

  // Guard against a demangled shape carrying unexpected parameters.
  if (VFInfo->Shape.Parameters.size() !=
      Node->getNumOperands() + Mapping->isMasked())
    return false;

  // Build the vector call's argument list.
  SDLoc Loc(Node);
  TargetLowering::ArgListTy CallArgs;
  unsigned NextOperand = 0;
  for (auto &Param : VFInfo->Shape.Parameters) {
    TargetLowering::ArgListEntry Arg;
    Arg.IsSExt = false;
    Arg.IsZExt = false;

    if (Param.ParamKind == VFParamKind::GlobalPredicate) {
      // Masked variant: pass an all-true predicate of the target's setcc
      // result type.
      EVT MaskVT = TLI.getSetCCResultType(DAG.getDataLayout(), *Context, VecVT);
      Arg.Node = DAG.getBoolConstant(true, Loc, MaskVT, VecVT);
      Arg.Ty = MaskVT.getTypeForEVT(*Context);
      CallArgs.push_back(Arg);
      continue;
    }

    // Anything other than a plain vector parameter is unsupported.
    if (Param.ParamKind != VFParamKind::Vector)
      return false;

    Arg.Node = Node->getOperand(NextOperand++);
    Arg.Ty = VecTy;
    CallArgs.push_back(Arg);
  }

  // Lower a call to the vector routine and record its result.
  SDValue Callee = DAG.getExternalSymbol(Mapping->getVectorFnName().data(),
                                         TLI.getPointerTy(DAG.getDataLayout()));
  TargetLowering::CallLoweringInfo CLI(DAG);
  CLI.setDebugLoc(Loc)
      .setChain(DAG.getEntryNode())
      .setLibCallee(CallingConv::C, VecTy, Callee, std::move(CallArgs));

  std::pair<SDValue, SDValue> Lowered = TLI.LowerCallTo(CLI);
  Results.push_back(Lowered.first);
  return true;
}
|
||
/// Try to expand the node to a vector libcall based on the result type. | ||
bool VectorLegalizer::tryExpandVecMathCall( | ||
SDNode *Node, RTLIB::Libcall Call_F32, RTLIB::Libcall Call_F64, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there a reason the Libcall variants are parameters to this function, instead of being hardcoded? (Looking at other functions like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm expecting this function to be used by other ISD nodes (e.g. |
||
RTLIB::Libcall Call_F80, RTLIB::Libcall Call_F128, | ||
RTLIB::Libcall Call_PPCF128, SmallVectorImpl<SDValue> &Results) { | ||
RTLIB::Libcall LC = RTLIB::getFPLibCall( | ||
Node->getValueType(0).getVectorElementType(), Call_F32, Call_F64, | ||
Call_F80, Call_F128, Call_PPCF128); | ||
|
||
if (LC == RTLIB::UNKNOWN_LIBCALL) | ||
return false; | ||
|
||
return tryExpandVecMathCall(Node, LC, Results); | ||
} | ||
|
||
void VectorLegalizer::UnrollStrictFPOp(SDNode *Node, | ||
SmallVectorImpl<SDValue> &Results) { | ||
EVT VT = Node->getValueType(0); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4 | ||
; RUN: llc --disable-replace-with-vec-lib --vector-library=ArmPL < %s -o - | FileCheck --check-prefix=ARMPL %s | ||
; RUN: llc --disable-replace-with-vec-lib --vector-library=sleefgnuabi < %s -o - | FileCheck --check-prefix=SLEEF %s | ||
|
||
target triple = "aarch64-unknown-linux-gnu" | ||
|
||
; FREM on a fixed-length <2 x double> is expected to become a call to the
; library's 2-lane vector fmod (ARMPL: armpl_vfmodq_f64, SLEEF:
; _ZGVnN2vv_fmod). %unused pushes %a/%b out of the first argument registers,
; hence the register moves checked before the call.
define <2 x double> @frem_v2f64(<2 x double> %unused, <2 x double> %a, <2 x double> %b) #0 { | ||
; ARMPL-LABEL: frem_v2f64: | ||
; ARMPL: // %bb.0: | ||
; ARMPL-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill | ||
; ARMPL-NEXT: .cfi_def_cfa_offset 16 | ||
; ARMPL-NEXT: .cfi_offset w30, -16 | ||
; ARMPL-NEXT: mov v0.16b, v1.16b | ||
; ARMPL-NEXT: mov v1.16b, v2.16b | ||
; ARMPL-NEXT: bl armpl_vfmodq_f64 | ||
; ARMPL-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload | ||
; ARMPL-NEXT: ret | ||
; | ||
; SLEEF-LABEL: frem_v2f64: | ||
; SLEEF: // %bb.0: | ||
; SLEEF-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill | ||
; SLEEF-NEXT: .cfi_def_cfa_offset 16 | ||
; SLEEF-NEXT: .cfi_offset w30, -16 | ||
; SLEEF-NEXT: mov v0.16b, v1.16b | ||
; SLEEF-NEXT: mov v1.16b, v2.16b | ||
; SLEEF-NEXT: bl _ZGVnN2vv_fmod | ||
; SLEEF-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload | ||
; SLEEF-NEXT: ret | ||
%res = frem <2 x double> %a, %b | ||
ret <2 x double> %res | ||
} | ||
|
||
; Same expansion for <4 x float> (ARMPL: armpl_vfmodq_f32, SLEEF:
; _ZGVnN4vv_fmodf). This function uses attribute set #1 which, per the
; attributes list at the bottom of the file, includes strictfp; the plain
; frem is still expanded to the vector call.
define <4 x float> @frem_strict_v4f32(<4 x float> %unused, <4 x float> %a, <4 x float> %b) #1 { | ||
; ARMPL-LABEL: frem_strict_v4f32: | ||
; ARMPL: // %bb.0: | ||
; ARMPL-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill | ||
; ARMPL-NEXT: .cfi_def_cfa_offset 16 | ||
; ARMPL-NEXT: .cfi_offset w30, -16 | ||
; ARMPL-NEXT: mov v0.16b, v1.16b | ||
; ARMPL-NEXT: mov v1.16b, v2.16b | ||
; ARMPL-NEXT: bl armpl_vfmodq_f32 | ||
; ARMPL-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload | ||
; ARMPL-NEXT: ret | ||
; | ||
; SLEEF-LABEL: frem_strict_v4f32: | ||
; SLEEF: // %bb.0: | ||
; SLEEF-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill | ||
; SLEEF-NEXT: .cfi_def_cfa_offset 16 | ||
; SLEEF-NEXT: .cfi_offset w30, -16 | ||
; SLEEF-NEXT: mov v0.16b, v1.16b | ||
; SLEEF-NEXT: mov v1.16b, v2.16b | ||
; SLEEF-NEXT: bl _ZGVnN4vv_fmodf | ||
; SLEEF-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload | ||
; SLEEF-NEXT: ret | ||
%res = frem <4 x float> %a, %b | ||
ret <4 x float> %res | ||
} | ||
|
||
; Scalable <vscale x 4 x float> FREM: expected to lower to the masked vector
; fmod routine (ARMPL: armpl_svfmod_f32_x, SLEEF: _ZGVsMxvv_fmodf), with an
; all-true predicate materialised via ptrue before the call.
define <vscale x 4 x float> @frem_nxv4f32(<vscale x 4 x float> %unused, <vscale x 4 x float> %a, <vscale x 4 x float> %b) #0 { | ||
; ARMPL-LABEL: frem_nxv4f32: | ||
; ARMPL: // %bb.0: | ||
; ARMPL-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill | ||
; ARMPL-NEXT: .cfi_def_cfa_offset 16 | ||
; ARMPL-NEXT: .cfi_offset w30, -16 | ||
; ARMPL-NEXT: ptrue p0.s | ||
; ARMPL-NEXT: mov z0.d, z1.d | ||
; ARMPL-NEXT: mov z1.d, z2.d | ||
; ARMPL-NEXT: bl armpl_svfmod_f32_x | ||
; ARMPL-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload | ||
; ARMPL-NEXT: ret | ||
; | ||
; SLEEF-LABEL: frem_nxv4f32: | ||
; SLEEF: // %bb.0: | ||
; SLEEF-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill | ||
; SLEEF-NEXT: .cfi_def_cfa_offset 16 | ||
; SLEEF-NEXT: .cfi_offset w30, -16 | ||
; SLEEF-NEXT: ptrue p0.s | ||
; SLEEF-NEXT: mov z0.d, z1.d | ||
; SLEEF-NEXT: mov z1.d, z2.d | ||
; SLEEF-NEXT: bl _ZGVsMxvv_fmodf | ||
; SLEEF-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload | ||
; SLEEF-NEXT: ret | ||
%res = frem <vscale x 4 x float> %a, %b | ||
ret <vscale x 4 x float> %res | ||
} | ||
|
||
; Scalable <vscale x 2 x double> FREM with attribute set #1 (which, per the
; attributes list at the bottom of the file, includes strictfp): expected to
; lower to the masked vector fmod (ARMPL: armpl_svfmod_f64_x, SLEEF:
; _ZGVsMxvv_fmod) with a ptrue-built all-true predicate.
define <vscale x 2 x double> @frem_strict_nxv2f64(<vscale x 2 x double> %unused, <vscale x 2 x double> %a, <vscale x 2 x double> %b) #1 { | ||
; ARMPL-LABEL: frem_strict_nxv2f64: | ||
; ARMPL: // %bb.0: | ||
; ARMPL-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill | ||
; ARMPL-NEXT: .cfi_def_cfa_offset 16 | ||
; ARMPL-NEXT: .cfi_offset w30, -16 | ||
; ARMPL-NEXT: ptrue p0.d | ||
; ARMPL-NEXT: mov z0.d, z1.d | ||
; ARMPL-NEXT: mov z1.d, z2.d | ||
; ARMPL-NEXT: bl armpl_svfmod_f64_x | ||
; ARMPL-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload | ||
; ARMPL-NEXT: ret | ||
; | ||
; SLEEF-LABEL: frem_strict_nxv2f64: | ||
; SLEEF: // %bb.0: | ||
; SLEEF-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill | ||
; SLEEF-NEXT: .cfi_def_cfa_offset 16 | ||
; SLEEF-NEXT: .cfi_offset w30, -16 | ||
; SLEEF-NEXT: ptrue p0.d | ||
; SLEEF-NEXT: mov z0.d, z1.d | ||
; SLEEF-NEXT: mov z1.d, z2.d | ||
; SLEEF-NEXT: bl _ZGVsMxvv_fmod | ||
; SLEEF-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload | ||
; SLEEF-NEXT: ret | ||
%res = frem <vscale x 2 x double> %a, %b | ||
ret <vscale x 2 x double> %res | ||
} | ||
|
||
attributes #0 = { "target-features"="+sve" } | ||
attributes #1 = { "target-features"="+sve" strictfp } |
Uh oh!
There was an error while loading. Please reload this page.