Skip to content

[ConstantFolding] Add folding for [de]interleave2, insert and extract #141301

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 116 additions & 0 deletions llvm/lib/Analysis/ConstantFolding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1619,6 +1619,10 @@ bool llvm::canConstantFoldCallTo(const CallBase *Call, const Function *F) {
case Intrinsic::vector_reduce_smax:
case Intrinsic::vector_reduce_umin:
case Intrinsic::vector_reduce_umax:
case Intrinsic::vector_extract:
case Intrinsic::vector_insert:
case Intrinsic::vector_interleave2:
case Intrinsic::vector_deinterleave2:
// Target intrinsics
case Intrinsic::amdgcn_perm:
case Intrinsic::amdgcn_wave_reduce_umin:
Expand Down Expand Up @@ -3734,6 +3738,98 @@ static Constant *ConstantFoldFixedVectorCall(
}
return nullptr;
}
case Intrinsic::vector_extract: {
  // Fold llvm.vector.extract on a constant fixed-width vector: the result
  // takes FVTy->getNumElements() consecutive elements of Vec, starting at the
  // constant index Idx. Lanes that would read past the end of Vec are poison.
  auto *Vec = dyn_cast<Constant>(Operands[0]);
  // Use dyn_cast rather than cast: cast<> asserts on a type mismatch and never
  // yields null, which would make the !Idx guard below dead code.
  auto *Idx = dyn_cast<ConstantInt>(Operands[1]);
  if (!Vec || !Idx || !isa<FixedVectorType>(Vec->getType()))
    return nullptr;

  unsigned NumElements = FVTy->getNumElements();
  unsigned VecNumElements =
      cast<FixedVectorType>(Vec->getType())->getNumElements();

  // Check the full 64-bit index before narrowing it to unsigned. Every lane
  // is out of bounds when the start is at or past the end of Vec, and bailing
  // out here also keeps StartingIndex + NumElements below from wrapping.
  if (Idx->getZExtValue() >= VecNumElements)
    return PoisonValue::get(FVTy);
  unsigned StartingIndex = Idx->getZExtValue();

  // Extracting entire vector is nop
  if (NumElements == VecNumElements && StartingIndex == 0)
    return Vec;

  // Copy the in-bounds lanes; NonPoisonNumElements is the exclusive end of
  // the readable range inside Vec.
  const unsigned NonPoisonNumElements =
      std::min(StartingIndex + NumElements, VecNumElements);
  for (unsigned I = StartingIndex; I < NonPoisonNumElements; ++I) {
    Constant *Elt = Vec->getAggregateElement(I);
    if (!Elt)
      return nullptr;
    Result[I - StartingIndex] = Elt;
  }
Comment on lines +3756 to +3763
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Stylistic nit: you could avoid NonPoisonNumElements and the second loop by handling the out-of-bounds elements in the main loop.

Suggested change
const unsigned NonPoisonNumElements =
std::min(StartingIndex + NumElements, VecNumElements);
for (unsigned I = StartingIndex; I < NonPoisonNumElements; ++I) {
Constant *Elt = Vec->getAggregateElement(I);
if (!Elt)
return nullptr;
Result[I - StartingIndex] = Elt;
}
for (unsigned I = 0; I < NumElements; ++I) {
// Out of bounds elements are poison
if (StartingIndex + I >= VecNumElements) {
Result[I] = PoisonValue::get(FVTy->getElementType());
continue;
}
Constant *Elt = Vec->getAggregateElement(StartingIndex + I);
if (!Elt)
return nullptr;
Result[I] = Elt;
}


// Remaining elements are poison since they are out of bounds.
for (unsigned I = NonPoisonNumElements, E = StartingIndex + NumElements;
I < E; ++I)
Result[I - StartingIndex] = PoisonValue::get(FVTy->getElementType());

return ConstantVector::get(Result);
}
case Intrinsic::vector_insert: {
auto *Vec = dyn_cast<Constant>(Operands[0]);
auto *SubVec = dyn_cast<Constant>(Operands[1]);
Comment on lines +3773 to +3774
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, I think you can drop the dyn_casts

Suggested change
auto *Vec = dyn_cast<Constant>(Operands[0]);
auto *SubVec = dyn_cast<Constant>(Operands[1]);
Constant *Vec = Operands[0];
Constant *SubVec = Operands[1];

auto *Idx = dyn_cast<ConstantInt>(Operands[2]);
if (!Vec || !SubVec || !Idx || !isa<FixedVectorType>(Vec->getType()))
return nullptr;

unsigned SubVecNumElements =
cast<FixedVectorType>(SubVec->getType())->getNumElements();
unsigned VecNumElements =
cast<FixedVectorType>(Vec->getType())->getNumElements();
unsigned IdxN = Idx->getZExtValue();
// Replacing entire vector with a subvec is nop
if (SubVecNumElements == VecNumElements && IdxN == 0)
return SubVec;

// Make sure indices are in the range [0, VecNumElements), otherwise the
// result is a poison value.
if (IdxN >= VecNumElements || IdxN + SubVecNumElements > VecNumElements ||
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SubVecNumElements is always >= 1 so if IdxN >= VecNumElements then IdxN + SubVecNumElements > VecNumElements. So I think you can remove the first check

Suggested change
if (IdxN >= VecNumElements || IdxN + SubVecNumElements > VecNumElements ||
if (IdxN + SubVecNumElements > VecNumElements ||

(IdxN && (SubVecNumElements % IdxN) != 0))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this check whether the index isn't a multiple of the subvector length? I.e.

Suggested change
(IdxN && (SubVecNumElements % IdxN) != 0))
IdxN % SubVecNumElements)

Is it possible to add a test for this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't that incorrect when IdxN = 3 and SubVecNumElements = 6?

idx must be a constant multiple of subvec’s known minimum vector length

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IdxN must be 0,6,12,18... right? So 3 should be poison and 3%6=3 would be true for this check

return PoisonValue::get(FVTy);

unsigned I = 0;
for (; I < IdxN; ++I) {
Constant *Elt = Vec->getAggregateElement(I);
if (!Elt)
return nullptr;
Result[I] = Elt;
}
for (; I < IdxN + SubVecNumElements; ++I) {
Constant *Elt = SubVec->getAggregateElement(I - IdxN);
if (!Elt)
return nullptr;
Result[I] = Elt;
}
for (; I < VecNumElements; ++I) {
Constant *Elt = Vec->getAggregateElement(I);
if (!Elt)
return nullptr;
Result[I] = Elt;
}
Comment on lines +3794 to +3812
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Stylistic: I think it's easier to read if they're merged into one loop.

Suggested change
unsigned I = 0;
for (; I < IdxN; ++I) {
Constant *Elt = Vec->getAggregateElement(I);
if (!Elt)
return nullptr;
Result[I] = Elt;
}
for (; I < IdxN + SubVecNumElements; ++I) {
Constant *Elt = SubVec->getAggregateElement(I - IdxN);
if (!Elt)
return nullptr;
Result[I] = Elt;
}
for (; I < VecNumElements; ++I) {
Constant *Elt = Vec->getAggregateElement(I);
if (!Elt)
return nullptr;
Result[I] = Elt;
}
for (unsigned I = 0; I < VecNumElements; ++I) {
Constant *Elt;
if (I >= IdxN && I < IdxN + SubVecNumElements)
Elt = SubVec->getAggregateElement(I - IdxN);
else
Elt = Vec->getAggregateElement(I);
if (!Elt)
return nullptr;
Result[I] = Elt;
}

return ConstantVector::get(Result);
}
case Intrinsic::vector_interleave2: {
auto *Vec0 = dyn_cast<Constant>(Operands[0]);
auto *Vec1 = dyn_cast<Constant>(Operands[1]);
if (!Vec0 || !Vec1)
return nullptr;
Comment on lines +3816 to +3819
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Operands is an ArrayRef<Constant *> already, are the dyn_casts redundant?


unsigned NumElements =
cast<FixedVectorType>(Vec0->getType())->getNumElements();
for (unsigned I = 0; I < NumElements; ++I) {
Constant *Elt0 = Vec0->getAggregateElement(I);
Constant *Elt1 = Vec1->getAggregateElement(I);
if (!Elt0 || !Elt1)
return nullptr;
Result[2 * I] = Elt0;
Result[2 * I + 1] = Elt1;
}
return ConstantVector::get(Result);
}
default:
break;
}
Expand Down Expand Up @@ -3872,6 +3968,26 @@ ConstantFoldStructCall(StringRef Name, Intrinsic::ID IntrinsicID,
return nullptr;
return ConstantStruct::get(StTy, SinResult, CosResult);
}
case Intrinsic::vector_deinterleave2: {
auto *Vec = dyn_cast<Constant>(Operands[0]);
if (!Vec)
return nullptr;
Comment on lines +3972 to +3974
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
auto *Vec = dyn_cast<Constant>(Operands[0]);
if (!Vec)
return nullptr;
Constant *Vec = Operands[0];


unsigned NumElements =
cast<VectorType>(Vec->getType())->getElementCount().getKnownMinValue() /
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If Vec can be scalable here, should we check if it's scalable and bail? Otherwise we're relying on getAggregateElement to return nullptr

2;
SmallVector<Constant *, 4> Res0(NumElements), Res1(NumElements);
for (unsigned I = 0; I < NumElements; ++I) {
Constant *Elt0 = Vec->getAggregateElement(2 * I);
Constant *Elt1 = Vec->getAggregateElement(2 * I + 1);
if (!Elt0 || !Elt1)
return nullptr;
Res0[I] = Elt0;
Res1[I] = Elt1;
}
return ConstantStruct::get(StTy, ConstantVector::get(Res0),
ConstantVector::get(Res1));
}
default:
// TODO: Constant folding of vector intrinsics that fall through here does
// not work (e.g. overflow intrinsics)
Expand Down
100 changes: 100 additions & 0 deletions llvm/test/Transforms/InstSimplify/ConstProp/vector-calls.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt < %s -passes=instsimplify,verify -disable-verify -S | FileCheck %s
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
; RUN: opt < %s -passes=instsimplify,verify -disable-verify -S | FileCheck %s
; RUN: opt < %s -passes=instsimplify -S | FileCheck %s

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

-disable-verify is needed to test the poison values generated by vector.insert and vector.extract when writing/reading out of bounds (OOB).


; Extracting 3 elements starting at index 3 from a constant <8 x i32> is
; expected to fold to the elements 3, 4, 5.
define <3 x i32> @fold_vector_extract() {
; CHECK-LABEL: define <3 x i32> @fold_vector_extract() {
; CHECK-NEXT: ret <3 x i32> <i32 3, i32 4, i32 5>
;
%1 = call <3 x i32> @llvm.vector.extract.v3i32.v8i32(<8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>, i64 3)
ret <3 x i32> %1
}

@a = external global i16, align 1

; The first source element is a ptrtoint constant expression on @a. Extracting
; 3 elements at index 0 is expected to keep that expression in the folded
; result.
define <3 x i32> @fold_vector_extract_constexpr() {
; CHECK-LABEL: define <3 x i32> @fold_vector_extract_constexpr() {
; CHECK-NEXT: ret <3 x i32> <i32 ptrtoint (ptr @a to i32), i32 1, i32 2>
;
%1 = call <3 x i32> @llvm.vector.extract.v3i32.v8i32(<8 x i32> <i32 ptrtoint (ptr @a to i32), i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>, i64 0)
ret <3 x i32> %1
}

; Starting at index 6, only two source elements (6 and 7) are available for
; the 3-element result; the last lane is expected to fold to poison.
define <3 x i32> @fold_vector_extract_last_poison() {
; CHECK-LABEL: define <3 x i32> @fold_vector_extract_last_poison() {
; CHECK-NEXT: ret <3 x i32> <i32 6, i32 7, i32 poison>
;
%1 = call <3 x i32> @llvm.vector.extract.v3i32.v8i32(<8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>, i64 6)
ret <3 x i32> %1
}

; Index 8 is past the end of the 8-element source vector, so the whole result
; is expected to fold to poison.
define <3 x i32> @fold_vector_extract_poison() {
; CHECK-LABEL: define <3 x i32> @fold_vector_extract_poison() {
; CHECK-NEXT: ret <3 x i32> poison
;
%1 = call <3 x i32> @llvm.vector.extract.v3i32.v8i32(<8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>, i64 8)
ret <3 x i32> %1
}

; The result type equals the source type and the index is 0, so the extract
; is expected to fold to the source vector unchanged.
; NOTE(review): the callee name carries a .v3i32 suffix although this call
; returns <8 x i32> - confirm the intended intrinsic name mangling.
define <8 x i32> @fold_vector_extract_nop() {
; CHECK-LABEL: define <8 x i32> @fold_vector_extract_nop() {
; CHECK-NEXT: ret <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
;
%1 = call <8 x i32> @llvm.vector.extract.v3i32.v8i32(<8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>, i64 0)
ret <8 x i32> %1
}

; Inserting a 4-element subvector at index 0 of an 8-element vector is
; expected to replace lanes 0-3 and keep lanes 4-7 of the original vector.
define <8 x i32> @fold_vector_insert() {
; CHECK-LABEL: define <8 x i32> @fold_vector_insert() {
; CHECK-NEXT: ret <8 x i32> <i32 9, i32 10, i32 11, i32 12, i32 5, i32 6, i32 7, i32 8>
;
%1 = call <8 x i32> @llvm.vector.insert.v8i32(<8 x i32> <i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8>, <4 x i32> <i32 9, i32 10, i32 11, i32 12>, i64 0)
ret <8 x i32> %1
}

; The subvector has the same length as the destination and the index is 0, so
; the insert is expected to fold to the subvector itself.
define <8 x i32> @fold_vector_insert_nop() {
; CHECK-LABEL: define <8 x i32> @fold_vector_insert_nop() {
; CHECK-NEXT: ret <8 x i32> <i32 11, i32 12, i32 13, i32 14, i32 15, i32 16, i32 17, i32 18>
;
%1 = call <8 x i32> @llvm.vector.insert.v8i32(<8 x i32> <i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8>, <8 x i32> <i32 11, i32 12, i32 13, i32 14, i32 15, i32 16, i32 17, i32 18>, i64 0)
ret <8 x i32> %1
}

; A 6-element subvector inserted at index 6 would run past the end of the
; 8-element destination (6 + 6 > 8); the expected result is poison.
define <8 x i32> @fold_vector_insert_poison_idx_range() {
; CHECK-LABEL: define <8 x i32> @fold_vector_insert_poison_idx_range() {
; CHECK-NEXT: ret <8 x i32> poison
;
%1 = call <8 x i32> @llvm.vector.insert.v8i32(<8 x i32> <i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8>, <6 x i32> <i32 9, i32 10, i32 11, i32 12, i32 13, i32 14>, i64 6)
ret <8 x i32> %1
}

; The index is written as i64 -2, i.e. a very large value once it is read as
; an unsigned integer; the expected result is poison.
define <8 x i32> @fold_vector_insert_poison_large_idx() {
; CHECK-LABEL: define <8 x i32> @fold_vector_insert_poison_large_idx() {
; CHECK-NEXT: ret <8 x i32> poison
;
%1 = call <8 x i32> @llvm.vector.insert.v8i32(<8 x i32> <i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8>, <6 x i32> <i32 9, i32 10, i32 11, i32 12, i32 13, i32 14>, i64 -2)
ret <8 x i32> %1
}

; Interleaving <1, 2, 3, 4> with <5, 6, 7, 8> is expected to fold to the
; alternating sequence <1, 5, 2, 6, 3, 7, 4, 8>.
define <8 x i32> @fold_vector_interleave2() {
; CHECK-LABEL: define <8 x i32> @fold_vector_interleave2() {
; CHECK-NEXT: ret <8 x i32> <i32 1, i32 5, i32 2, i32 6, i32 3, i32 7, i32 4, i32 8>
;
%1 = call<8 x i32> @llvm.vector.interleave2.v8i32(<4 x i32> <i32 1, i32 2, i32 3, i32 4>, <4 x i32> <i32 5, i32 6, i32 7, i32 8>)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
%1 = call<8 x i32> @llvm.vector.interleave2.v8i32(<4 x i32> <i32 1, i32 2, i32 3, i32 4>, <4 x i32> <i32 5, i32 6, i32 7, i32 8>)
%1 = call <8 x i32> @llvm.vector.interleave2.v8i32(<4 x i32> <i32 1, i32 2, i32 3, i32 4>, <4 x i32> <i32 5, i32 6, i32 7, i32 8>)

ret <8 x i32> %1
}

; Deinterleaving <1, 5, 2, 6, 3, 7, 4, 8> is expected to fold to a struct
; holding the even lanes <1, 2, 3, 4> and the odd lanes <5, 6, 7, 8>.
define {<4 x i32>, <4 x i32>} @fold_vector_deinterleav2() {
; CHECK-LABEL: define { <4 x i32>, <4 x i32> } @fold_vector_deinterleav2() {
; CHECK-NEXT: ret { <4 x i32>, <4 x i32> } { <4 x i32> <i32 1, i32 2, i32 3, i32 4>, <4 x i32> <i32 5, i32 6, i32 7, i32 8> }
;
%1 = call {<4 x i32>, <4 x i32>} @llvm.vector.deinterleave2.v4i32.v8i32(<8 x i32> <i32 1, i32 5, i32 2, i32 6, i32 3, i32 7, i32 4, i32 8>)
ret {<4 x i32>, <4 x i32>} %1
}

; Scalable-vector input: deinterleaving a zeroinitializer <vscale x 8 x i32>
; is expected to fold to a zeroinitializer struct of two scalable vectors.
define {<vscale x 4 x i32>, <vscale x 4 x i32>} @fold_scalable_vector_deinterleav2() {
; CHECK-LABEL: define { <vscale x 4 x i32>, <vscale x 4 x i32> } @fold_scalable_vector_deinterleav2() {
; CHECK-NEXT: ret { <vscale x 4 x i32>, <vscale x 4 x i32> } zeroinitializer
;
%1 = call {<vscale x 4 x i32>, <vscale x 4 x i32>} @llvm.vector.deinterleave2.v4i32.v8i32(<vscale x 8 x i32> zeroinitializer)
ret {<vscale x 4 x i32>, <vscale x 4 x i32>} %1
}