Skip to content
89 changes: 89 additions & 0 deletions llvm/lib/Target/DirectX/DXILLegalizePass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,67 @@ downcastI64toI32InsertExtractElements(Instruction &I,
}
}

static void emitMemcpyExpansion(IRBuilder<> &Builder, Value *Dst, Value *Src,
ConstantInt *Length) {

uint64_t ByteLength = Length->getZExtValue();
// If length to copy is zero, no memcpy is needed.
if (ByteLength == 0)
return;

LLVMContext &Ctx = Builder.getContext();
const DataLayout &DL = Builder.GetInsertBlock()->getModule()->getDataLayout();

auto GetArrTyFromVal = [](Value *Val) -> ArrayType * {
assert(isa<AllocaInst>(Val) ||
isa<GlobalVariable>(Val) &&
"Expected Val to be an Alloca or Global Variable");
if (auto *Alloca = dyn_cast<AllocaInst>(Val))
return dyn_cast<ArrayType>(Alloca->getAllocatedType());
if (auto *GlobalVar = dyn_cast<GlobalVariable>(Val))
return dyn_cast<ArrayType>(GlobalVar->getValueType());
return nullptr;
};

ArrayType *DstArrTy = GetArrTyFromVal(Dst);
assert(DstArrTy && "Expected Dst of memcpy to be a Pointer to an Array Type");
if (auto *DstGlobalVar = dyn_cast<GlobalVariable>(Dst))
assert(!DstGlobalVar->isConstant() &&
"The Dst of memcpy must not be a constant Global Variable");
[[maybe_unused]] ArrayType *SrcArrTy = GetArrTyFromVal(Src);
assert(SrcArrTy && "Expected Src of memcpy to be a Pointer to an Array Type");

Type *DstElemTy = DstArrTy->getElementType();
uint64_t DstElemByteSize = DL.getTypeStoreSize(DstElemTy);
assert(DstElemByteSize > 0 && "Dst element type store size must be set");
Type *SrcElemTy = SrcArrTy->getElementType();
[[maybe_unused]] uint64_t SrcElemByteSize = DL.getTypeStoreSize(SrcElemTy);
assert(SrcElemByteSize > 0 && "Src element type store size must be set");

// This assumption simplifies implementation and covers currently-known
// use-cases for DXIL. It may be relaxed in the future if required.
assert(DstElemTy == SrcElemTy &&
"The element types of Src and Dst arrays must match");

[[maybe_unused]] uint64_t DstArrNumElems = DstArrTy->getArrayNumElements();
assert(DstElemByteSize * DstArrNumElems >= ByteLength &&
"Dst array size must be at least as large as the memcpy length");
[[maybe_unused]] uint64_t SrcArrNumElems = SrcArrTy->getArrayNumElements();
assert(SrcElemByteSize * SrcArrNumElems >= ByteLength &&
"Src array size must be at least as large as the memcpy length");

uint64_t NumElemsToCopy = ByteLength / DstElemByteSize;
assert(ByteLength % DstElemByteSize == 0 &&
"memcpy length must be divisible by array element type");
for (uint64_t I = 0; I < NumElemsToCopy; ++I) {
Value *Offset = ConstantInt::get(Type::getInt32Ty(Ctx), I);
Value *SrcPtr = Builder.CreateInBoundsGEP(SrcElemTy, Src, Offset, "gep");
Value *SrcVal = Builder.CreateLoad(SrcElemTy, SrcPtr);
Value *DstPtr = Builder.CreateInBoundsGEP(DstElemTy, Dst, Offset, "gep");
Builder.CreateStore(SrcVal, DstPtr);
}
}

static void emitMemsetExpansion(IRBuilder<> &Builder, Value *Dst, Value *Val,
ConstantInt *SizeCI,
DenseMap<Value *, Value *> &ReplacedValues) {
Expand Down Expand Up @@ -296,6 +357,33 @@ static void emitMemsetExpansion(IRBuilder<> &Builder, Value *Dst, Value *Val,
}
}

// Expands the instruction `I` into corresponding loads and stores if it is a
// memcpy call. In that case, the call instruction is added to the `ToRemove`
// vector. `ReplacedValues` is unused.
static void legalizeMemCpy(Instruction &I,
SmallVectorImpl<Instruction *> &ToRemove,
DenseMap<Value *, Value *> &ReplacedValues) {

CallInst *CI = dyn_cast<CallInst>(&I);
if (!CI)
return;

Intrinsic::ID ID = CI->getIntrinsicID();
if (ID != Intrinsic::memcpy)
return;

IRBuilder<> Builder(&I);
Value *Dst = CI->getArgOperand(0);
Value *Src = CI->getArgOperand(1);
ConstantInt *Length = dyn_cast<ConstantInt>(CI->getArgOperand(2));
assert(Length && "Expected Length to be a ConstantInt");
ConstantInt *IsVolatile = dyn_cast<ConstantInt>(CI->getArgOperand(3));
assert(IsVolatile && "Expected IsVolatile to be a ConstantInt");
assert(IsVolatile->getZExtValue() == 0 && "Expected IsVolatile to be false");
emitMemcpyExpansion(Builder, Dst, Src, Length);
ToRemove.push_back(CI);
}

static void removeMemSet(Instruction &I,
SmallVectorImpl<Instruction *> &ToRemove,
DenseMap<Value *, Value *> &ReplacedValues) {
Expand Down Expand Up @@ -348,6 +436,7 @@ class DXILLegalizationPipeline {
LegalizationPipeline.push_back(fixI8UseChain);
LegalizationPipeline.push_back(downcastI64toI32InsertExtractElements);
LegalizationPipeline.push_back(legalizeFreeze);
LegalizationPipeline.push_back(legalizeMemCpy);
LegalizationPipeline.push_back(removeMemSet);
}
};
Expand Down
154 changes: 154 additions & 0 deletions llvm/test/CodeGen/DirectX/legalize-memcpy.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -S -dxil-legalize -dxil-finalize-linkage -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s

define void @replace_int_memcpy_test() #0 {
; CHECK-LABEL: define void @replace_int_memcpy_test(
; CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
; CHECK-NEXT: [[TMP1:%.*]] = alloca [1 x i32], align 4
; CHECK-NEXT: [[TMP2:%.*]] = alloca [1 x i32], align 4
; CHECK-NEXT: [[GEP:%.*]] = getelementptr inbounds i32, ptr [[TMP1]], i32 0
; CHECK-NEXT: [[TMP3:%.*]] = load i32, ptr [[GEP]], align 4
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr inbounds i32, ptr [[TMP2]], i32 0
; CHECK-NEXT: store i32 [[TMP3]], ptr [[GEP1]], align 4
; CHECK-NEXT: ret void
;
%1 = alloca [1 x i32], align 4
%2 = alloca [1 x i32], align 4
call void @llvm.memcpy.p0.p0.i32(ptr nonnull align 4 dereferenceable(4) %2, ptr align 4 dereferenceable(4) %1, i32 4, i1 false)
ret void
}

define void @replace_3int_memcpy_test() #0 {
; CHECK-LABEL: define void @replace_3int_memcpy_test(
; CHECK-SAME: ) #[[ATTR0]] {
; CHECK-NEXT: [[TMP1:%.*]] = alloca [3 x i32], align 4
; CHECK-NEXT: [[TMP2:%.*]] = alloca [3 x i32], align 4
; CHECK-NEXT: [[GEP:%.*]] = getelementptr inbounds i32, ptr [[TMP1]], i32 0
; CHECK-NEXT: [[TMP3:%.*]] = load i32, ptr [[GEP]], align 4
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr inbounds i32, ptr [[TMP2]], i32 0
; CHECK-NEXT: store i32 [[TMP3]], ptr [[GEP1]], align 4
; CHECK-NEXT: [[GEP2:%.*]] = getelementptr inbounds i32, ptr [[TMP1]], i32 1
; CHECK-NEXT: [[TMP4:%.*]] = load i32, ptr [[GEP2]], align 4
; CHECK-NEXT: [[GEP3:%.*]] = getelementptr inbounds i32, ptr [[TMP2]], i32 1
; CHECK-NEXT: store i32 [[TMP4]], ptr [[GEP3]], align 4
; CHECK-NEXT: [[GEP4:%.*]] = getelementptr inbounds i32, ptr [[TMP1]], i32 2
; CHECK-NEXT: [[TMP5:%.*]] = load i32, ptr [[GEP4]], align 4
; CHECK-NEXT: [[GEP5:%.*]] = getelementptr inbounds i32, ptr [[TMP2]], i32 2
; CHECK-NEXT: store i32 [[TMP5]], ptr [[GEP5]], align 4
; CHECK-NEXT: ret void
;
%1 = alloca [3 x i32], align 4
%2 = alloca [3 x i32], align 4
call void @llvm.memcpy.p0.p0.i32(ptr nonnull align 4 dereferenceable(12) %2, ptr align 4 dereferenceable(12) %1, i32 12, i1 false)
ret void
}

define void @replace_mismatched_size_int_memcpy_test() #0 {
; CHECK-LABEL: define void @replace_mismatched_size_int_memcpy_test(
; CHECK-SAME: ) #[[ATTR0]] {
; CHECK-NEXT: [[TMP1:%.*]] = alloca [2 x i32], align 4
; CHECK-NEXT: [[TMP2:%.*]] = alloca [3 x i32], align 4
; CHECK-NEXT: [[GEP:%.*]] = getelementptr inbounds i32, ptr [[TMP1]], i32 0
; CHECK-NEXT: [[TMP3:%.*]] = load i32, ptr [[GEP]], align 4
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr inbounds i32, ptr [[TMP2]], i32 0
; CHECK-NEXT: store i32 [[TMP3]], ptr [[GEP1]], align 4
; CHECK-NEXT: [[GEP2:%.*]] = getelementptr inbounds i32, ptr [[TMP1]], i32 1
; CHECK-NEXT: [[TMP4:%.*]] = load i32, ptr [[GEP2]], align 4
; CHECK-NEXT: [[GEP3:%.*]] = getelementptr inbounds i32, ptr [[TMP2]], i32 1
; CHECK-NEXT: store i32 [[TMP4]], ptr [[GEP3]], align 4
; CHECK-NEXT: ret void
;
%1 = alloca [2 x i32], align 4
%2 = alloca [3 x i32], align 4
call void @llvm.memcpy.p0.p0.i32(ptr nonnull align 4 dereferenceable(12) %2, ptr align 4 dereferenceable(8) %1, i32 8, i1 false)
ret void
}

define void @replace_int16_memcpy_test() #0 {
; CHECK-LABEL: define void @replace_int16_memcpy_test(
; CHECK-SAME: ) #[[ATTR0]] {
; CHECK-NEXT: [[TMP1:%.*]] = alloca [2 x i16], align 2
; CHECK-NEXT: [[TMP2:%.*]] = alloca [2 x i16], align 2
; CHECK-NEXT: [[GEP:%.*]] = getelementptr inbounds i16, ptr [[TMP1]], i32 0
; CHECK-NEXT: [[TMP3:%.*]] = load i16, ptr [[GEP]], align 2
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr inbounds i16, ptr [[TMP2]], i32 0
; CHECK-NEXT: store i16 [[TMP3]], ptr [[GEP1]], align 2
; CHECK-NEXT: [[GEP2:%.*]] = getelementptr inbounds i16, ptr [[TMP1]], i32 1
; CHECK-NEXT: [[TMP4:%.*]] = load i16, ptr [[GEP2]], align 2
; CHECK-NEXT: [[GEP3:%.*]] = getelementptr inbounds i16, ptr [[TMP2]], i32 1
; CHECK-NEXT: store i16 [[TMP4]], ptr [[GEP3]], align 2
; CHECK-NEXT: ret void
;
%1 = alloca [2 x i16], align 2
%2 = alloca [2 x i16], align 2
call void @llvm.memcpy.p0.p0.i32(ptr nonnull align 2 dereferenceable(4) %2, ptr align 2 dereferenceable(4) %1, i32 4, i1 false)
ret void
}

define void @replace_float_memcpy_test() #0 {
; CHECK-LABEL: define void @replace_float_memcpy_test(
; CHECK-SAME: ) #[[ATTR0]] {
; CHECK-NEXT: [[TMP1:%.*]] = alloca [2 x float], align 4
; CHECK-NEXT: [[TMP2:%.*]] = alloca [2 x float], align 4
; CHECK-NEXT: [[GEP:%.*]] = getelementptr inbounds float, ptr [[TMP1]], i32 0
; CHECK-NEXT: [[TMP3:%.*]] = load float, ptr [[GEP]], align 4
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr inbounds float, ptr [[TMP2]], i32 0
; CHECK-NEXT: store float [[TMP3]], ptr [[GEP1]], align 4
; CHECK-NEXT: [[GEP2:%.*]] = getelementptr inbounds float, ptr [[TMP1]], i32 1
; CHECK-NEXT: [[TMP4:%.*]] = load float, ptr [[GEP2]], align 4
; CHECK-NEXT: [[GEP3:%.*]] = getelementptr inbounds float, ptr [[TMP2]], i32 1
; CHECK-NEXT: store float [[TMP4]], ptr [[GEP3]], align 4
; CHECK-NEXT: ret void
;
%1 = alloca [2 x float], align 4
%2 = alloca [2 x float], align 4
call void @llvm.memcpy.p0.p0.i32(ptr nonnull align 4 dereferenceable(8) %2, ptr align 4 dereferenceable(8) %1, i32 8, i1 false)
ret void
}

define void @replace_double_memcpy_test() #0 {
; CHECK-LABEL: define void @replace_double_memcpy_test(
; CHECK-SAME: ) #[[ATTR0]] {
; CHECK-NEXT: [[TMP1:%.*]] = alloca [2 x double], align 4
; CHECK-NEXT: [[TMP2:%.*]] = alloca [2 x double], align 4
; CHECK-NEXT: [[GEP:%.*]] = getelementptr inbounds double, ptr [[TMP1]], i32 0
; CHECK-NEXT: [[TMP3:%.*]] = load double, ptr [[GEP]], align 8
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr inbounds double, ptr [[TMP2]], i32 0
; CHECK-NEXT: store double [[TMP3]], ptr [[GEP1]], align 8
; CHECK-NEXT: [[GEP2:%.*]] = getelementptr inbounds double, ptr [[TMP1]], i32 1
; CHECK-NEXT: [[TMP4:%.*]] = load double, ptr [[GEP2]], align 8
; CHECK-NEXT: [[GEP3:%.*]] = getelementptr inbounds double, ptr [[TMP2]], i32 1
; CHECK-NEXT: store double [[TMP4]], ptr [[GEP3]], align 8
; CHECK-NEXT: ret void
;
%1 = alloca [2 x double], align 4
%2 = alloca [2 x double], align 4
call void @llvm.memcpy.p0.p0.i32(ptr nonnull align 4 dereferenceable(8) %2, ptr align 4 dereferenceable(8) %1, i32 16, i1 false)
ret void
}

define void @replace_half_memcpy_test() #0 {
; CHECK-LABEL: define void @replace_half_memcpy_test(
; CHECK-SAME: ) #[[ATTR0]] {
; CHECK-NEXT: [[TMP1:%.*]] = alloca [2 x half], align 2
; CHECK-NEXT: [[TMP2:%.*]] = alloca [2 x half], align 2
; CHECK-NEXT: [[GEP:%.*]] = getelementptr inbounds half, ptr [[TMP1]], i32 0
; CHECK-NEXT: [[TMP3:%.*]] = load half, ptr [[GEP]], align 2
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr inbounds half, ptr [[TMP2]], i32 0
; CHECK-NEXT: store half [[TMP3]], ptr [[GEP1]], align 2
; CHECK-NEXT: [[GEP2:%.*]] = getelementptr inbounds half, ptr [[TMP1]], i32 1
; CHECK-NEXT: [[TMP4:%.*]] = load half, ptr [[GEP2]], align 2
; CHECK-NEXT: [[GEP3:%.*]] = getelementptr inbounds half, ptr [[TMP2]], i32 1
; CHECK-NEXT: store half [[TMP4]], ptr [[GEP3]], align 2
; CHECK-NEXT: ret void
;
%1 = alloca [2 x half], align 2
%2 = alloca [2 x half], align 2
call void @llvm.memcpy.p0.p0.i32(ptr nonnull align 2 dereferenceable(4) %2, ptr align 2 dereferenceable(4) %1, i32 4, i1 false)
ret void
}

attributes #0 = {"hlsl.export"}

declare void @llvm.memcpy.p0.p2.i32(ptr noalias, ptr addrspace(2) noalias readonly, i32, i1)
declare void @llvm.memcpy.p0.p0.i32(ptr noalias, ptr noalias readonly, i32, i1)