Skip to content

Commit cd3f48d

Browse files
author
Greg Roth
authored
[NFC][DXIL] move replace/erase in DXIL intrinsic expansion to caller (#104626)
All expansions end with replacing the previous inrinsic with the new expansion and erasing the old one. By moving this operation to the caller, these expansion functions can be called in more contexts and a small amount of duplicated code is consolidated. Pre-req for #88056
1 parent 981191a commit cd3f48d

File tree

1 file changed

+61
-76
lines changed

1 file changed

+61
-76
lines changed

llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp

Lines changed: 61 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ static bool isIntrinsicExpansion(Function &F) {
5151
return false;
5252
}
5353

54-
static bool expandAbs(CallInst *Orig) {
54+
static Value *expandAbs(CallInst *Orig) {
5555
Value *X = Orig->getOperand(0);
5656
IRBuilder<> Builder(Orig->getParent());
5757
Builder.SetInsertPoint(Orig);
@@ -64,14 +64,11 @@ static bool expandAbs(CallInst *Orig) {
6464
ConstantInt::get(EltTy, 0))
6565
: ConstantInt::get(EltTy, 0);
6666
auto *V = Builder.CreateSub(Zero, X);
67-
auto *MaxCall =
68-
Builder.CreateIntrinsic(Ty, Intrinsic::smax, {X, V}, nullptr, "dx.max");
69-
Orig->replaceAllUsesWith(MaxCall);
70-
Orig->eraseFromParent();
71-
return true;
67+
return Builder.CreateIntrinsic(Ty, Intrinsic::smax, {X, V}, nullptr,
68+
"dx.max");
7269
}
7370

74-
static bool expandIntegerDot(CallInst *Orig, Intrinsic::ID DotIntrinsic) {
71+
static Value *expandIntegerDot(CallInst *Orig, Intrinsic::ID DotIntrinsic) {
7572
assert(DotIntrinsic == Intrinsic::dx_sdot ||
7673
DotIntrinsic == Intrinsic::dx_udot);
7774
Intrinsic::ID MadIntrinsic = DotIntrinsic == Intrinsic::dx_sdot
@@ -97,12 +94,10 @@ static bool expandIntegerDot(CallInst *Orig, Intrinsic::ID DotIntrinsic) {
9794
ArrayRef<Value *>{Elt0, Elt1, Result},
9895
nullptr, "dx.mad");
9996
}
100-
Orig->replaceAllUsesWith(Result);
101-
Orig->eraseFromParent();
102-
return true;
97+
return Result;
10398
}
10499

105-
static bool expandExpIntrinsic(CallInst *Orig) {
100+
static Value *expandExpIntrinsic(CallInst *Orig) {
106101
Value *X = Orig->getOperand(0);
107102
IRBuilder<> Builder(Orig->getParent());
108103
Builder.SetInsertPoint(Orig);
@@ -119,23 +114,21 @@ static bool expandExpIntrinsic(CallInst *Orig) {
119114
Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {NewX}, nullptr, "dx.exp2");
120115
Exp2Call->setTailCall(Orig->isTailCall());
121116
Exp2Call->setAttributes(Orig->getAttributes());
122-
Orig->replaceAllUsesWith(Exp2Call);
123-
Orig->eraseFromParent();
124-
return true;
117+
return Exp2Call;
125118
}
126119

127-
static bool expandAnyIntrinsic(CallInst *Orig) {
120+
static Value *expandAnyIntrinsic(CallInst *Orig) {
128121
Value *X = Orig->getOperand(0);
129122
IRBuilder<> Builder(Orig->getParent());
130123
Builder.SetInsertPoint(Orig);
131124
Type *Ty = X->getType();
132125
Type *EltTy = Ty->getScalarType();
133126

127+
Value *Result = nullptr;
134128
if (!Ty->isVectorTy()) {
135-
Value *Cond = EltTy->isFloatingPointTy()
136-
? Builder.CreateFCmpUNE(X, ConstantFP::get(EltTy, 0))
137-
: Builder.CreateICmpNE(X, ConstantInt::get(EltTy, 0));
138-
Orig->replaceAllUsesWith(Cond);
129+
Result = EltTy->isFloatingPointTy()
130+
? Builder.CreateFCmpUNE(X, ConstantFP::get(EltTy, 0))
131+
: Builder.CreateICmpNE(X, ConstantInt::get(EltTy, 0));
139132
} else {
140133
auto *XVec = dyn_cast<FixedVectorType>(Ty);
141134
Value *Cond =
@@ -148,18 +141,16 @@ static bool expandAnyIntrinsic(CallInst *Orig) {
148141
X, ConstantVector::getSplat(
149142
ElementCount::getFixed(XVec->getNumElements()),
150143
ConstantInt::get(EltTy, 0)));
151-
Value *Result = Builder.CreateExtractElement(Cond, (uint64_t)0);
144+
Result = Builder.CreateExtractElement(Cond, (uint64_t)0);
152145
for (unsigned I = 1; I < XVec->getNumElements(); I++) {
153146
Value *Elt = Builder.CreateExtractElement(Cond, I);
154147
Result = Builder.CreateOr(Result, Elt);
155148
}
156-
Orig->replaceAllUsesWith(Result);
157149
}
158-
Orig->eraseFromParent();
159-
return true;
150+
return Result;
160151
}
161152

162-
static bool expandLengthIntrinsic(CallInst *Orig) {
153+
static Value *expandLengthIntrinsic(CallInst *Orig) {
163154
Value *X = Orig->getOperand(0);
164155
IRBuilder<> Builder(Orig->getParent());
165156
Builder.SetInsertPoint(Orig);
@@ -182,30 +173,23 @@ static bool expandLengthIntrinsic(CallInst *Orig) {
182173
Value *Mul = Builder.CreateFMul(Elt, Elt);
183174
Sum = Builder.CreateFAdd(Sum, Mul);
184175
}
185-
Value *Result = Builder.CreateIntrinsic(
186-
EltTy, Intrinsic::sqrt, ArrayRef<Value *>{Sum}, nullptr, "elt.sqrt");
187-
188-
Orig->replaceAllUsesWith(Result);
189-
Orig->eraseFromParent();
190-
return true;
176+
return Builder.CreateIntrinsic(EltTy, Intrinsic::sqrt, ArrayRef<Value *>{Sum},
177+
nullptr, "elt.sqrt");
191178
}
192179

193-
static bool expandLerpIntrinsic(CallInst *Orig) {
180+
static Value *expandLerpIntrinsic(CallInst *Orig) {
194181
Value *X = Orig->getOperand(0);
195182
Value *Y = Orig->getOperand(1);
196183
Value *S = Orig->getOperand(2);
197184
IRBuilder<> Builder(Orig->getParent());
198185
Builder.SetInsertPoint(Orig);
199186
auto *V = Builder.CreateFSub(Y, X);
200187
V = Builder.CreateFMul(S, V);
201-
auto *Result = Builder.CreateFAdd(X, V, "dx.lerp");
202-
Orig->replaceAllUsesWith(Result);
203-
Orig->eraseFromParent();
204-
return true;
188+
return Builder.CreateFAdd(X, V, "dx.lerp");
205189
}
206190

207-
static bool expandLogIntrinsic(CallInst *Orig,
208-
float LogConstVal = numbers::ln2f) {
191+
static Value *expandLogIntrinsic(CallInst *Orig,
192+
float LogConstVal = numbers::ln2f) {
209193
Value *X = Orig->getOperand(0);
210194
IRBuilder<> Builder(Orig->getParent());
211195
Builder.SetInsertPoint(Orig);
@@ -221,16 +205,13 @@ static bool expandLogIntrinsic(CallInst *Orig,
221205
Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2");
222206
Log2Call->setTailCall(Orig->isTailCall());
223207
Log2Call->setAttributes(Orig->getAttributes());
224-
auto *Result = Builder.CreateFMul(Ln2Const, Log2Call);
225-
Orig->replaceAllUsesWith(Result);
226-
Orig->eraseFromParent();
227-
return true;
208+
return Builder.CreateFMul(Ln2Const, Log2Call);
228209
}
229-
static bool expandLog10Intrinsic(CallInst *Orig) {
210+
static Value *expandLog10Intrinsic(CallInst *Orig) {
230211
return expandLogIntrinsic(Orig, numbers::ln2f / numbers::ln10f);
231212
}
232213

233-
static bool expandNormalizeIntrinsic(CallInst *Orig) {
214+
static Value *expandNormalizeIntrinsic(CallInst *Orig) {
234215
Value *X = Orig->getOperand(0);
235216
Type *Ty = Orig->getType();
236217
Type *EltTy = Ty->getScalarType();
@@ -245,11 +226,7 @@ static bool expandNormalizeIntrinsic(CallInst *Orig) {
245226
report_fatal_error(Twine("Invalid input scalar: length is zero"),
246227
/* gen_crash_diag=*/false);
247228
}
248-
Value *Result = Builder.CreateFDiv(X, X);
249-
250-
Orig->replaceAllUsesWith(Result);
251-
Orig->eraseFromParent();
252-
return true;
229+
return Builder.CreateFDiv(X, X);
253230
}
254231

255232
unsigned XVecSize = XVec->getNumElements();
@@ -291,14 +268,10 @@ static bool expandNormalizeIntrinsic(CallInst *Orig) {
291268
nullptr, "dx.rsqrt");
292269

293270
Value *MultiplicandVec = Builder.CreateVectorSplat(XVecSize, Multiplicand);
294-
Value *Result = Builder.CreateFMul(X, MultiplicandVec);
295-
296-
Orig->replaceAllUsesWith(Result);
297-
Orig->eraseFromParent();
298-
return true;
271+
return Builder.CreateFMul(X, MultiplicandVec);
299272
}
300273

301-
static bool expandPowIntrinsic(CallInst *Orig) {
274+
static Value *expandPowIntrinsic(CallInst *Orig) {
302275

303276
Value *X = Orig->getOperand(0);
304277
Value *Y = Orig->getOperand(1);
@@ -313,9 +286,7 @@ static bool expandPowIntrinsic(CallInst *Orig) {
313286
Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {Mul}, nullptr, "elt.exp2");
314287
Exp2Call->setTailCall(Orig->isTailCall());
315288
Exp2Call->setAttributes(Orig->getAttributes());
316-
Orig->replaceAllUsesWith(Exp2Call);
317-
Orig->eraseFromParent();
318-
return true;
289+
return Exp2Call;
319290
}
320291

321292
static Intrinsic::ID getMaxForClamp(Type *ElemTy,
@@ -344,7 +315,8 @@ static Intrinsic::ID getMinForClamp(Type *ElemTy,
344315
return Intrinsic::minnum;
345316
}
346317

347-
static bool expandClampIntrinsic(CallInst *Orig, Intrinsic::ID ClampIntrinsic) {
318+
static Value *expandClampIntrinsic(CallInst *Orig,
319+
Intrinsic::ID ClampIntrinsic) {
348320
Value *X = Orig->getOperand(0);
349321
Value *Min = Orig->getOperand(1);
350322
Value *Max = Orig->getOperand(2);
@@ -353,41 +325,54 @@ static bool expandClampIntrinsic(CallInst *Orig, Intrinsic::ID ClampIntrinsic) {
353325
Builder.SetInsertPoint(Orig);
354326
auto *MaxCall = Builder.CreateIntrinsic(
355327
Ty, getMaxForClamp(Ty, ClampIntrinsic), {X, Min}, nullptr, "dx.max");
356-
auto *MinCall =
357-
Builder.CreateIntrinsic(Ty, getMinForClamp(Ty, ClampIntrinsic),
358-
{MaxCall, Max}, nullptr, "dx.min");
359-
360-
Orig->replaceAllUsesWith(MinCall);
361-
Orig->eraseFromParent();
362-
return true;
328+
return Builder.CreateIntrinsic(Ty, getMinForClamp(Ty, ClampIntrinsic),
329+
{MaxCall, Max}, nullptr, "dx.min");
363330
}
364331

365332
static bool expandIntrinsic(Function &F, CallInst *Orig) {
333+
Value *Result = nullptr;
366334
switch (F.getIntrinsicID()) {
367335
case Intrinsic::abs:
368-
return expandAbs(Orig);
336+
Result = expandAbs(Orig);
337+
break;
369338
case Intrinsic::exp:
370-
return expandExpIntrinsic(Orig);
339+
Result = expandExpIntrinsic(Orig);
340+
break;
371341
case Intrinsic::log:
372-
return expandLogIntrinsic(Orig);
342+
Result = expandLogIntrinsic(Orig);
343+
break;
373344
case Intrinsic::log10:
374-
return expandLog10Intrinsic(Orig);
345+
Result = expandLog10Intrinsic(Orig);
346+
break;
375347
case Intrinsic::pow:
376-
return expandPowIntrinsic(Orig);
348+
Result = expandPowIntrinsic(Orig);
349+
break;
377350
case Intrinsic::dx_any:
378-
return expandAnyIntrinsic(Orig);
351+
Result = expandAnyIntrinsic(Orig);
352+
break;
379353
case Intrinsic::dx_uclamp:
380354
case Intrinsic::dx_clamp:
381-
return expandClampIntrinsic(Orig, F.getIntrinsicID());
355+
Result = expandClampIntrinsic(Orig, F.getIntrinsicID());
356+
break;
382357
case Intrinsic::dx_lerp:
383-
return expandLerpIntrinsic(Orig);
358+
Result = expandLerpIntrinsic(Orig);
359+
break;
384360
case Intrinsic::dx_length:
385-
return expandLengthIntrinsic(Orig);
361+
Result = expandLengthIntrinsic(Orig);
362+
break;
386363
case Intrinsic::dx_normalize:
387-
return expandNormalizeIntrinsic(Orig);
364+
Result = expandNormalizeIntrinsic(Orig);
365+
break;
388366
case Intrinsic::dx_sdot:
389367
case Intrinsic::dx_udot:
390-
return expandIntegerDot(Orig, F.getIntrinsicID());
368+
Result = expandIntegerDot(Orig, F.getIntrinsicID());
369+
break;
370+
}
371+
372+
if (Result) {
373+
Orig->replaceAllUsesWith(Result);
374+
Orig->eraseFromParent();
375+
return true;
391376
}
392377
return false;
393378
}

0 commit comments

Comments
 (0)