Skip to content

Commit 0e45322

Browse files
frederik-h authored and DharuniRAcharya committed
[AMDGPU] expand-fp: unify scalarization (NFC) (llvm#158588)
Extend the existing "scalarize" function, which is used for the expansion of the fp-integer conversion instructions, to BinaryOperator instructions and reuse it for the frem expansion. A similar function for scalarizing BinaryOperator instructions exists in the ExpandLargeDivRem pass, and this change is a step towards merging that pass with ExpandFp. Further refactoring: scalarize directly instead of using the "ReplaceVector" vector as a worklist, rename the "Replace" vector to "Worklist", and hoist the check for unsupported scalable vectors to the top of the instruction visiting loop.
1 parent fa409a2 commit 0e45322

File tree

4 files changed

+985
-1010
lines changed

4 files changed

+985
-1010
lines changed

llvm/lib/CodeGen/ExpandFp.cpp

Lines changed: 55 additions & 80 deletions
Original file line number | Diff line number | Diff line change
@@ -82,7 +82,7 @@ class FRemExpander {
8282
}
8383

8484
static FRemExpander create(IRBuilder<> &B, Type *Ty) {
85-
assert(canExpandType(Ty));
85+
assert(canExpandType(Ty) && "Expected supported floating point type");
8686

8787
// The type to use for the computation of the remainder. This may be
8888
// wider than the input/result type which affects the ...
@@ -356,8 +356,9 @@ Value *FRemExpander::buildFRem(Value *X, Value *Y,
356356
static bool expandFRem(BinaryOperator &I, std::optional<SimplifyQuery> &SQ) {
357357
LLVM_DEBUG(dbgs() << "Expanding instruction: " << I << '\n');
358358

359-
Type *ReturnTy = I.getType();
360-
assert(FRemExpander::canExpandType(ReturnTy->getScalarType()));
359+
Type *Ty = I.getType();
360+
assert(FRemExpander::canExpandType(Ty) &&
361+
"Expected supported floating point type");
361362

362363
FastMathFlags FMF = I.getFastMathFlags();
363364
// TODO Make use of those flags for optimization?
@@ -368,32 +369,10 @@ static bool expandFRem(BinaryOperator &I, std::optional<SimplifyQuery> &SQ) {
368369
B.setFastMathFlags(FMF);
369370
B.SetCurrentDebugLocation(I.getDebugLoc());
370371

371-
Type *ElemTy = ReturnTy->getScalarType();
372-
const FRemExpander Expander = FRemExpander::create(B, ElemTy);
373-
374-
Value *Ret;
375-
if (ReturnTy->isFloatingPointTy())
376-
Ret = FMF.approxFunc()
377-
? Expander.buildApproxFRem(I.getOperand(0), I.getOperand(1))
378-
: Expander.buildFRem(I.getOperand(0), I.getOperand(1), SQ);
379-
else {
380-
auto *VecTy = cast<FixedVectorType>(ReturnTy);
381-
382-
// This could use SplitBlockAndInsertForEachLane but the interface
383-
// is a bit awkward for a constant number of elements and it will
384-
// boil down to the same code.
385-
// TODO Expand the FRem instruction only once and reuse the code.
386-
Value *Nums = I.getOperand(0);
387-
Value *Denums = I.getOperand(1);
388-
Ret = PoisonValue::get(I.getType());
389-
for (int I = 0, E = VecTy->getNumElements(); I != E; ++I) {
390-
Value *Num = B.CreateExtractElement(Nums, I);
391-
Value *Denum = B.CreateExtractElement(Denums, I);
392-
Value *Rem = FMF.approxFunc() ? Expander.buildApproxFRem(Num, Denum)
393-
: Expander.buildFRem(Num, Denum, SQ);
394-
Ret = B.CreateInsertElement(Ret, Rem, I);
395-
}
396-
}
372+
const FRemExpander Expander = FRemExpander::create(B, Ty);
373+
Value *Ret = FMF.approxFunc()
374+
? Expander.buildApproxFRem(I.getOperand(0), I.getOperand(1))
375+
: Expander.buildFRem(I.getOperand(0), I.getOperand(1), SQ);
397376

398377
I.replaceAllUsesWith(Ret);
399378
Ret->takeName(&I);
@@ -939,7 +918,8 @@ static void expandIToFP(Instruction *IToFP) {
939918
IToFP->eraseFromParent();
940919
}
941920

942-
static void scalarize(Instruction *I, SmallVectorImpl<Instruction *> &Replace) {
921+
static void scalarize(Instruction *I,
922+
SmallVectorImpl<Instruction *> &Worklist) {
943923
VectorType *VTy = cast<FixedVectorType>(I->getType());
944924

945925
IRBuilder<> Builder(I);
@@ -948,12 +928,25 @@ static void scalarize(Instruction *I, SmallVectorImpl<Instruction *> &Replace) {
948928
Value *Result = PoisonValue::get(VTy);
949929
for (unsigned Idx = 0; Idx < NumElements; ++Idx) {
950930
Value *Ext = Builder.CreateExtractElement(I->getOperand(0), Idx);
951-
Value *Cast = Builder.CreateCast(cast<CastInst>(I)->getOpcode(), Ext,
952-
I->getType()->getScalarType());
953-
Result = Builder.CreateInsertElement(Result, Cast, Idx);
954-
if (isa<Instruction>(Cast))
955-
Replace.push_back(cast<Instruction>(Cast));
931+
932+
Value *NewOp = nullptr;
933+
if (auto *BinOp = dyn_cast<BinaryOperator>(I))
934+
NewOp = Builder.CreateBinOp(
935+
BinOp->getOpcode(), Ext,
936+
Builder.CreateExtractElement(I->getOperand(1), Idx));
937+
else if (auto *CastI = dyn_cast<CastInst>(I))
938+
NewOp = Builder.CreateCast(CastI->getOpcode(), Ext,
939+
I->getType()->getScalarType());
940+
else
941+
llvm_unreachable("Unsupported instruction type");
942+
943+
Result = Builder.CreateInsertElement(Result, NewOp, Idx);
944+
if (auto *ScalarizedI = dyn_cast<Instruction>(NewOp)) {
945+
ScalarizedI->copyIRFlags(I, true);
946+
Worklist.push_back(ScalarizedI);
947+
}
956948
}
949+
957950
I->replaceAllUsesWith(Result);
958951
I->dropAllReferences();
959952
I->eraseFromParent();
@@ -989,10 +982,17 @@ static bool targetSupportsFrem(const TargetLowering &TLI, Type *Ty) {
989982
return TLI.getLibcallName(fremToLibcall(Ty->getScalarType()));
990983
}
991984

985+
static void addToWorklist(Instruction &I,
986+
SmallVector<Instruction *, 4> &Worklist) {
987+
if (I.getOperand(0)->getType()->isVectorTy())
988+
scalarize(&I, Worklist);
989+
else
990+
Worklist.push_back(&I);
991+
}
992+
992993
static bool runImpl(Function &F, const TargetLowering &TLI,
993994
AssumptionCache *AC) {
994-
SmallVector<Instruction *, 4> Replace;
995-
SmallVector<Instruction *, 4> ReplaceVector;
995+
SmallVector<Instruction *, 4> Worklist;
996996
bool Modified = false;
997997

998998
unsigned MaxLegalFpConvertBitWidth =
@@ -1003,73 +1003,48 @@ static bool runImpl(Function &F, const TargetLowering &TLI,
10031003
if (MaxLegalFpConvertBitWidth >= llvm::IntegerType::MAX_INT_BITS)
10041004
return false;
10051005

1006-
for (auto &I : instructions(F)) {
1007-
switch (I.getOpcode()) {
1008-
case Instruction::FRem: {
1009-
Type *Ty = I.getType();
1010-
// TODO: This pass doesn't handle scalable vectors.
1011-
if (Ty->isScalableTy())
1012-
continue;
1013-
1014-
if (targetSupportsFrem(TLI, Ty) ||
1015-
!FRemExpander::canExpandType(Ty->getScalarType()))
1016-
continue;
1017-
1018-
Replace.push_back(&I);
1019-
Modified = true;
1006+
for (auto It = inst_begin(&F), End = inst_end(F); It != End;) {
1007+
Instruction &I = *It++;
1008+
Type *Ty = I.getType();
1009+
// TODO: This pass doesn't handle scalable vectors.
1010+
if (Ty->isScalableTy())
1011+
continue;
10201012

1013+
switch (I.getOpcode()) {
1014+
case Instruction::FRem:
1015+
if (!targetSupportsFrem(TLI, Ty) &&
1016+
FRemExpander::canExpandType(Ty->getScalarType())) {
1017+
addToWorklist(I, Worklist);
1018+
Modified = true;
1019+
}
10211020
break;
1022-
}
10231021
case Instruction::FPToUI:
10241022
case Instruction::FPToSI: {
1025-
// TODO: This pass doesn't handle scalable vectors.
1026-
if (I.getOperand(0)->getType()->isScalableTy())
1027-
continue;
1028-
1029-
auto *IntTy = cast<IntegerType>(I.getType()->getScalarType());
1023+
auto *IntTy = cast<IntegerType>(Ty->getScalarType());
10301024
if (IntTy->getIntegerBitWidth() <= MaxLegalFpConvertBitWidth)
10311025
continue;
10321026

1033-
if (I.getOperand(0)->getType()->isVectorTy())
1034-
ReplaceVector.push_back(&I);
1035-
else
1036-
Replace.push_back(&I);
1027+
addToWorklist(I, Worklist);
10371028
Modified = true;
10381029
break;
10391030
}
10401031
case Instruction::UIToFP:
10411032
case Instruction::SIToFP: {
1042-
// TODO: This pass doesn't handle scalable vectors.
1043-
if (I.getOperand(0)->getType()->isScalableTy())
1044-
continue;
1045-
10461033
auto *IntTy =
10471034
cast<IntegerType>(I.getOperand(0)->getType()->getScalarType());
10481035
if (IntTy->getIntegerBitWidth() <= MaxLegalFpConvertBitWidth)
10491036
continue;
10501037

1051-
if (I.getOperand(0)->getType()->isVectorTy())
1052-
ReplaceVector.push_back(&I);
1053-
else
1054-
Replace.push_back(&I);
1055-
Modified = true;
1038+
addToWorklist(I, Worklist);
10561039
break;
10571040
}
10581041
default:
10591042
break;
10601043
}
10611044
}
10621045

1063-
while (!ReplaceVector.empty()) {
1064-
Instruction *I = ReplaceVector.pop_back_val();
1065-
scalarize(I, Replace);
1066-
}
1067-
1068-
if (Replace.empty())
1069-
return false;
1070-
1071-
while (!Replace.empty()) {
1072-
Instruction *I = Replace.pop_back_val();
1046+
while (!Worklist.empty()) {
1047+
Instruction *I = Worklist.pop_back_val();
10731048
if (I->getOpcode() == Instruction::FRem) {
10741049
auto SQ = [&]() -> std::optional<SimplifyQuery> {
10751050
if (AC) {

0 commit comments

Comments
 (0)