@@ -82,7 +82,7 @@ class FRemExpander {
8282 }
8383
8484 static FRemExpander create (IRBuilder<> &B, Type *Ty) {
85- assert (canExpandType (Ty));
85+ assert (canExpandType (Ty) && " Expected supported floating point type " );
8686
8787 // The type to use for the computation of the remainder. This may be
8888 // wider than the input/result type which affects the ...
@@ -356,8 +356,9 @@ Value *FRemExpander::buildFRem(Value *X, Value *Y,
356356static bool expandFRem (BinaryOperator &I, std::optional<SimplifyQuery> &SQ) {
357357 LLVM_DEBUG (dbgs () << " Expanding instruction: " << I << ' \n ' );
358358
359- Type *ReturnTy = I.getType ();
360- assert (FRemExpander::canExpandType (ReturnTy->getScalarType ()));
359+ Type *Ty = I.getType ();
360+ assert (FRemExpander::canExpandType (Ty) &&
361+ " Expected supported floating point type" );
361362
362363 FastMathFlags FMF = I.getFastMathFlags ();
363364 // TODO Make use of those flags for optimization?
@@ -368,32 +369,10 @@ static bool expandFRem(BinaryOperator &I, std::optional<SimplifyQuery> &SQ) {
368369 B.setFastMathFlags (FMF);
369370 B.SetCurrentDebugLocation (I.getDebugLoc ());
370371
371- Type *ElemTy = ReturnTy->getScalarType ();
372- const FRemExpander Expander = FRemExpander::create (B, ElemTy);
373-
374- Value *Ret;
375- if (ReturnTy->isFloatingPointTy ())
376- Ret = FMF.approxFunc ()
377- ? Expander.buildApproxFRem (I.getOperand (0 ), I.getOperand (1 ))
378- : Expander.buildFRem (I.getOperand (0 ), I.getOperand (1 ), SQ);
379- else {
380- auto *VecTy = cast<FixedVectorType>(ReturnTy);
381-
382- // This could use SplitBlockAndInsertForEachLane but the interface
383- // is a bit awkward for a constant number of elements and it will
384- // boil down to the same code.
385- // TODO Expand the FRem instruction only once and reuse the code.
386- Value *Nums = I.getOperand (0 );
387- Value *Denums = I.getOperand (1 );
388- Ret = PoisonValue::get (I.getType ());
389- for (int I = 0 , E = VecTy->getNumElements (); I != E; ++I) {
390- Value *Num = B.CreateExtractElement (Nums, I);
391- Value *Denum = B.CreateExtractElement (Denums, I);
392- Value *Rem = FMF.approxFunc () ? Expander.buildApproxFRem (Num, Denum)
393- : Expander.buildFRem (Num, Denum, SQ);
394- Ret = B.CreateInsertElement (Ret, Rem, I);
395- }
396- }
372+ const FRemExpander Expander = FRemExpander::create (B, Ty);
373+ Value *Ret = FMF.approxFunc ()
374+ ? Expander.buildApproxFRem (I.getOperand (0 ), I.getOperand (1 ))
375+ : Expander.buildFRem (I.getOperand (0 ), I.getOperand (1 ), SQ);
397376
398377 I.replaceAllUsesWith (Ret);
399378 Ret->takeName (&I);
@@ -939,7 +918,8 @@ static void expandIToFP(Instruction *IToFP) {
939918 IToFP->eraseFromParent ();
940919}
941920
942- static void scalarize (Instruction *I, SmallVectorImpl<Instruction *> &Replace) {
921+ static void scalarize (Instruction *I,
922+ SmallVectorImpl<Instruction *> &Worklist) {
943923 VectorType *VTy = cast<FixedVectorType>(I->getType ());
944924
945925 IRBuilder<> Builder (I);
@@ -948,12 +928,25 @@ static void scalarize(Instruction *I, SmallVectorImpl<Instruction *> &Replace) {
948928 Value *Result = PoisonValue::get (VTy);
949929 for (unsigned Idx = 0 ; Idx < NumElements; ++Idx) {
950930 Value *Ext = Builder.CreateExtractElement (I->getOperand (0 ), Idx);
951- Value *Cast = Builder.CreateCast (cast<CastInst>(I)->getOpcode (), Ext,
952- I->getType ()->getScalarType ());
953- Result = Builder.CreateInsertElement (Result, Cast, Idx);
954- if (isa<Instruction>(Cast))
955- Replace.push_back (cast<Instruction>(Cast));
931+
932+ Value *NewOp = nullptr ;
933+ if (auto *BinOp = dyn_cast<BinaryOperator>(I))
934+ NewOp = Builder.CreateBinOp (
935+ BinOp->getOpcode (), Ext,
936+ Builder.CreateExtractElement (I->getOperand (1 ), Idx));
937+ else if (auto *CastI = dyn_cast<CastInst>(I))
938+ NewOp = Builder.CreateCast (CastI->getOpcode (), Ext,
939+ I->getType ()->getScalarType ());
940+ else
941+ llvm_unreachable (" Unsupported instruction type" );
942+
943+ Result = Builder.CreateInsertElement (Result, NewOp, Idx);
944+ if (auto *ScalarizedI = dyn_cast<Instruction>(NewOp)) {
945+ ScalarizedI->copyIRFlags (I, true );
946+ Worklist.push_back (ScalarizedI);
947+ }
956948 }
949+
957950 I->replaceAllUsesWith (Result);
958951 I->dropAllReferences ();
959952 I->eraseFromParent ();
@@ -989,10 +982,17 @@ static bool targetSupportsFrem(const TargetLowering &TLI, Type *Ty) {
989982 return TLI.getLibcallName (fremToLibcall (Ty->getScalarType ()));
990983}
991984
985+ static void addToWorklist (Instruction &I,
986+ SmallVector<Instruction *, 4 > &Worklist) {
987+ if (I.getOperand (0 )->getType ()->isVectorTy ())
988+ scalarize (&I, Worklist);
989+ else
990+ Worklist.push_back (&I);
991+ }
992+
992993static bool runImpl (Function &F, const TargetLowering &TLI,
993994 AssumptionCache *AC) {
994- SmallVector<Instruction *, 4 > Replace;
995- SmallVector<Instruction *, 4 > ReplaceVector;
995+ SmallVector<Instruction *, 4 > Worklist;
996996 bool Modified = false ;
997997
998998 unsigned MaxLegalFpConvertBitWidth =
@@ -1003,73 +1003,48 @@ static bool runImpl(Function &F, const TargetLowering &TLI,
10031003 if (MaxLegalFpConvertBitWidth >= llvm::IntegerType::MAX_INT_BITS)
10041004 return false ;
10051005
1006- for (auto &I : instructions (F)) {
1007- switch (I.getOpcode ()) {
1008- case Instruction::FRem: {
1009- Type *Ty = I.getType ();
1010- // TODO: This pass doesn't handle scalable vectors.
1011- if (Ty->isScalableTy ())
1012- continue ;
1013-
1014- if (targetSupportsFrem (TLI, Ty) ||
1015- !FRemExpander::canExpandType (Ty->getScalarType ()))
1016- continue ;
1017-
1018- Replace.push_back (&I);
1019- Modified = true ;
1006+ for (auto It = inst_begin (&F), End = inst_end (F); It != End;) {
1007+ Instruction &I = *It++;
1008+ Type *Ty = I.getType ();
1009+ // TODO: This pass doesn't handle scalable vectors.
1010+ if (Ty->isScalableTy ())
1011+ continue ;
10201012
1013+ switch (I.getOpcode ()) {
1014+ case Instruction::FRem:
1015+ if (!targetSupportsFrem (TLI, Ty) &&
1016+ FRemExpander::canExpandType (Ty->getScalarType ())) {
1017+ addToWorklist (I, Worklist);
1018+ Modified = true ;
1019+ }
10211020 break ;
1022- }
10231021 case Instruction::FPToUI:
10241022 case Instruction::FPToSI: {
1025- // TODO: This pass doesn't handle scalable vectors.
1026- if (I.getOperand (0 )->getType ()->isScalableTy ())
1027- continue ;
1028-
1029- auto *IntTy = cast<IntegerType>(I.getType ()->getScalarType ());
1023+ auto *IntTy = cast<IntegerType>(Ty->getScalarType ());
10301024 if (IntTy->getIntegerBitWidth () <= MaxLegalFpConvertBitWidth)
10311025 continue ;
10321026
1033- if (I.getOperand (0 )->getType ()->isVectorTy ())
1034- ReplaceVector.push_back (&I);
1035- else
1036- Replace.push_back (&I);
1027+ addToWorklist (I, Worklist);
10371028 Modified = true ;
10381029 break ;
10391030 }
10401031 case Instruction::UIToFP:
10411032 case Instruction::SIToFP: {
1042- // TODO: This pass doesn't handle scalable vectors.
1043- if (I.getOperand (0 )->getType ()->isScalableTy ())
1044- continue ;
1045-
10461033 auto *IntTy =
10471034 cast<IntegerType>(I.getOperand (0 )->getType ()->getScalarType ());
10481035 if (IntTy->getIntegerBitWidth () <= MaxLegalFpConvertBitWidth)
10491036 continue ;
10501037
1051- if (I.getOperand (0 )->getType ()->isVectorTy ())
1052- ReplaceVector.push_back (&I);
1053- else
1054- Replace.push_back (&I);
1055- Modified = true ;
1038+ addToWorklist (I, Worklist);
10561039 break ;
10571040 }
10581041 default :
10591042 break ;
10601043 }
10611044 }
10621045
1063- while (!ReplaceVector.empty ()) {
1064- Instruction *I = ReplaceVector.pop_back_val ();
1065- scalarize (I, Replace);
1066- }
1067-
1068- if (Replace.empty ())
1069- return false ;
1070-
1071- while (!Replace.empty ()) {
1072- Instruction *I = Replace.pop_back_val ();
1046+ while (!Worklist.empty ()) {
1047+ Instruction *I = Worklist.pop_back_val ();
10731048 if (I->getOpcode () == Instruction::FRem) {
10741049 auto SQ = [&]() -> std::optional<SimplifyQuery> {
10751050 if (AC) {
0 commit comments