@@ -181,10 +181,17 @@ static void upcastI8AllocasAndUses(Instruction &I,
181
181
if (!Load)
182
182
continue ;
183
183
for (User *LU : Load->users ()) {
184
- auto *Cast = dyn_cast<CastInst>(LU);
185
- if (!Cast)
184
+ Type *Ty = nullptr ;
185
+ if (auto *Cast = dyn_cast<CastInst>(LU))
186
+ Ty = Cast->getType ();
187
+ if (CallInst *CI = dyn_cast<CallInst>(LU)) {
188
+ if (CI->getIntrinsicID () == Intrinsic::memset )
189
+ Ty = Type::getInt32Ty (CI->getContext ());
190
+ }
191
+
192
+ if (!Ty)
186
193
continue ;
187
- Type *Ty = Cast-> getType ();
194
+
188
195
if (!SmallestType ||
189
196
Ty->getPrimitiveSizeInBits () < SmallestType->getPrimitiveSizeInBits ())
190
197
SmallestType = Ty;
@@ -240,8 +247,9 @@ downcastI64toI32InsertExtractElements(Instruction &I,
240
247
}
241
248
}
242
249
243
- void emitMemset (IRBuilder<> &Builder, Value *Dst, Value *Val,
244
- ConstantInt *SizeCI) {
250
+ void emitMemsetExpansion (IRBuilder<> &Builder, Value *Dst, Value *Val,
251
+ ConstantInt *SizeCI,
252
+ DenseMap<Value *, Value *> &ReplacedValues) {
245
253
LLVMContext &Ctx = Builder.getContext ();
246
254
[[maybe_unused]] DataLayout DL =
247
255
Builder.GetInsertBlock ()->getModule ()->getDataLayout ();
@@ -266,9 +274,19 @@ void emitMemset(IRBuilder<> &Builder, Value *Dst, Value *Val,
266
274
assert (OrigSize == ElemSize * Size && " Size in bytes must match" );
267
275
268
276
Value *TypedVal = Val;
269
- if (Val->getType () != ElemTy)
270
- TypedVal = Builder.CreateIntCast (Val, ElemTy,
271
- false ); // Or use CreateBitCast for float
277
+
278
+ if (Val->getType () != ElemTy) {
279
+ // Note for i8 replacements if we know them we should use them.
280
+ // Further if this is a constant ReplacedValues will return null
281
+ // so we will stick to TypedVal = Val
282
+ if (ReplacedValues[Val])
283
+ TypedVal = ReplacedValues[Val];
284
+ // This case Val is a ConstantInt so the cast folds away.
285
+ // However if we don't do the cast the store below ends up being
286
+ // an i8.
287
+ else
288
+ TypedVal = Builder.CreateIntCast (Val, ElemTy, false );
289
+ }
272
290
273
291
for (uint64_t I = 0 ; I < Size ; ++I) {
274
292
Value *Offset = ConstantInt::get (Type::getInt32Ty (Ctx), I);
@@ -279,7 +297,7 @@ void emitMemset(IRBuilder<> &Builder, Value *Dst, Value *Val,
279
297
280
298
static void removeMemSet (Instruction &I,
281
299
SmallVectorImpl<Instruction *> &ToRemove,
282
- DenseMap<Value *, Value *>) {
300
+ DenseMap<Value *, Value *> &ReplacedValues ) {
283
301
if (CallInst *CI = dyn_cast<CallInst>(&I)) {
284
302
Intrinsic::ID ID = CI->getIntrinsicID ();
285
303
if (ID == Intrinsic::memset ) {
@@ -289,7 +307,7 @@ static void removeMemSet(Instruction &I,
289
307
[[maybe_unused]] ConstantInt *Size =
290
308
dyn_cast<ConstantInt>(CI->getArgOperand (2 ));
291
309
assert (Size && " Expected Size to be a ConstantInt" );
292
- emitMemset (Builder, Dst, Val, Size );
310
+ emitMemsetExpansion (Builder, Dst, Val, Size , ReplacedValues );
293
311
ToRemove.push_back (CI);
294
312
}
295
313
}
@@ -322,11 +340,11 @@ class DXILLegalizationPipeline {
322
340
LegalizationPipeline;
323
341
324
342
void initializeLegalizationPipeline () {
325
- LegalizationPipeline.push_back (removeMemSet);
326
343
LegalizationPipeline.push_back (upcastI8AllocasAndUses);
327
344
LegalizationPipeline.push_back (fixI8UseChain);
328
345
LegalizationPipeline.push_back (downcastI64toI32InsertExtractElements);
329
346
LegalizationPipeline.push_back (legalizeFreeze);
347
+ LegalizationPipeline.push_back (removeMemSet);
330
348
}
331
349
};
332
350
0 commit comments