[LSV] Insert casts to vectorize mismatched types #134436


Open: wants to merge 7 commits into base: main
4 changes: 4 additions & 0 deletions llvm/include/llvm/Transforms/Utils/Local.h
@@ -430,6 +430,10 @@ void combineAAMetadata(Instruction *K, const Instruction *J);
/// replacement for the source instruction).
void copyMetadataForLoad(LoadInst &Dest, const LoadInst &Source);

/// Copy the metadata from the source instruction to the destination (the
/// replacement for the source instruction).
void copyMetadataForStore(StoreInst &Dest, const StoreInst &Source);

/// Patch the replacement so that it is not more restrictive than the value
/// being replaced. It assumes that the replacement does not get moved from
/// its original position.
45 changes: 45 additions & 0 deletions llvm/lib/Transforms/Utils/Local.cpp
@@ -3510,6 +3510,51 @@ void llvm::copyMetadataForLoad(LoadInst &Dest, const LoadInst &Source) {
}
}

void llvm::copyMetadataForStore(StoreInst &Dest, const StoreInst &Source) {
  SmallVector<std::pair<unsigned, MDNode *>, 8> MD;
  Source.getAllMetadata(MD);
  // A store instruction itself has void type, so inspect the stored value.
  Type *NewType = Dest.getValueOperand()->getType();
  for (const auto &MDPair : MD) {
    unsigned ID = MDPair.first;
    MDNode *N = MDPair.second;
    switch (ID) {
    case LLVMContext::MD_dbg:
    case LLVMContext::MD_prof:
    // TBAA is copied unchanged, as copyMetadataForLoad does. Calling
    // getStructName() on a non-struct type would assert.
    case LLVMContext::MD_tbaa:
    case LLVMContext::MD_tbaa_struct:
    case LLVMContext::MD_alias_scope:
    case LLVMContext::MD_noalias:
    case LLVMContext::MD_nontemporal:
    case LLVMContext::MD_access_group:
    case LLVMContext::MD_noundef:
    case LLVMContext::MD_noalias_addrspace:
    case LLVMContext::MD_mem_parallel_loop_access:
      Dest.setMetadata(ID, N);
      break;

    case LLVMContext::MD_nonnull:
      break;

    case LLVMContext::MD_align:
    case LLVMContext::MD_dereferenceable:
    case LLVMContext::MD_dereferenceable_or_null:
      // These only directly apply if the new type is also a pointer.
      if (NewType->isPointerTy())
        Dest.setMetadata(ID, N);
      break;

    case LLVMContext::MD_range:
      break;
    }
  }
}

void llvm::patchReplacementInstruction(Instruction *I, Value *Repl) {
auto *ReplInst = dyn_cast<Instruction>(Repl);
if (!ReplInst)
124 changes: 123 additions & 1 deletion llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp
@@ -60,6 +60,7 @@
#include "llvm/Transforms/Vectorize/LoadStoreVectorizer.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Bitset.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/PostOrderIterator.h"
@@ -237,6 +238,11 @@ void reorder(Instruction *I) {
}

class Vectorizer {

  enum ClassTyDist { Int, Float, Ptr, Other };

Function &F;
AliasAnalysis &AA;
AssumptionCache &AC;
@@ -273,6 +279,17 @@ class Vectorizer {
bool runOnEquivalenceClass(const EqClassKey &EqClassKey,
ArrayRef<Instruction *> EqClass);

  static int getTypeKind(Instruction *I) {
    // Classify by the accessed type: a store instruction itself has void type.
    unsigned ID = getLoadStoreType(I)->getTypeID();
    switch (ID) {
    case Type::IntegerTyID:
    case Type::FloatTyID:
    case Type::PointerTyID:
      return ID;
    }
    return -1;
  }

/// Runs the vectorizer on one chain, i.e. a subset of an equivalence class
/// where all instructions access a known, constant offset from the first
/// instruction.
@@ -324,6 +341,10 @@ class Vectorizer {
Instruction *ChainElem, Instruction *ChainBegin,
const DenseMap<Instruction *, APInt /*OffsetFromLeader*/> &ChainOffsets);

/// Merge equivalence classes if casts could be inserted in one to match
/// the total bitwidth of the instructions.
void insertCastsToMergeClasses(EquivalenceClassMap &EQClasses);

/// Merges the equivalence classes if they have underlying objects that differ
/// by one level of indirection (i.e., one is a getelementptr and the other is
/// the base pointer in that getelementptr).
@@ -1308,6 +1329,107 @@ std::optional<APInt> Vectorizer::getConstantOffsetSelects(
return std::nullopt;
}

void Vectorizer::insertCastsToMergeClasses(EquivalenceClassMap &EQClasses) {
  if (EQClasses.size() < 2)
    return;

  // For each class, determine the most defined type. This information will
  // help us determine the type instructions should be cast to.
  MapVector<EqClassKey, unsigned> ClassToNewTyID;
  for (const auto &C : EQClasses) {
    int FirstTypeKind = getTypeKind(EQClasses[C.first][0]);
    if (FirstTypeKind != -1 && all_of(EQClasses[C.first], [&](Instruction *I) {
          return getTypeKind(I) == FirstTypeKind;
        })) {
      ClassToNewTyID[C.first] = FirstTypeKind;
    }
  }

  // Loop over all equivalence classes and try to merge them. Keep track of
  // classes that are merged into others.
  DenseSet<EqClassKey> ClassesToErase;
  for (auto EC1 : EQClasses) {
    for (auto EC2 : EQClasses) {
      // Skip if EC2 was already merged before, EC1 follows EC2 in the
      // collection or EC1 is the same as EC2.
      if (ClassesToErase.contains(EC2.first) || EC1 <= EC2 ||
          EC1.first == EC2.first)
        continue;

      auto [Ptr1, AS1, TySize1, IsLoad1] = EC1.first;
      auto [Ptr2, AS2, TySize2, IsLoad2] = EC2.first;

      // Attempt to merge EC2 into EC1. Skip if the pointers, address spaces or
      // whether the leader instruction is a load/store are different. Also skip
      // if the scalar bitwidth of the first equivalence class is smaller than
      // the second one to avoid reconsidering the same equivalence class pair.
      if (Ptr1 != Ptr2 || AS1 != AS2 || IsLoad1 != IsLoad2 || TySize1 < TySize2)
        continue;

      // Ensure all instructions in EC2 can be bitcast to NewTy.
      // TODO: NewTyBits is needed because structured bindings cannot be
      // captured by a lambda until C++20.
      auto NewTyBits = std::get<2>(EC1.first);
      if (any_of(EC2.second, [&](Instruction *I) {
            return DL.getTypeSizeInBits(getLoadStoreType(I)) != NewTyBits;
          }))
        continue;

      // Create a new type for the equivalence class.
      auto &Ctx = EC2.second[0]->getContext();
      Type *NewTy = Type::getIntNTy(Ctx, NewTyBits);
      if (ClassToNewTyID[EC1.first] == Type::FloatTyID &&
          ClassToNewTyID[EC2.first] == Type::FloatTyID) {
        NewTy = Type::getFloatTy(Ctx);
      } else if (ClassToNewTyID[EC1.first] == Type::PointerTyID &&
                 ClassToNewTyID[EC2.first] == Type::PointerTyID) {
        NewTy = PointerType::get(Ctx, AS2);
      }

      for (Instruction *Inst : EC2.second) {
        Value *Ptr = getLoadStorePointerOperand(Inst);
        // Use the accessed type: a store instruction itself has void type.
        Type *OrigTy = getLoadStoreType(Inst);
        if (OrigTy == NewTy) {
          // No cast is needed, but the instruction still has to move to EC1
          // because EC2 is erased below.
          EQClasses[EC1.first].emplace_back(Inst);
          continue;
        }
        if (auto *LI = dyn_cast<LoadInst>(Inst)) {
          Builder.SetInsertPoint(LI->getIterator());
          auto *NewLoad = Builder.CreateLoad(NewTy, Ptr);
          auto *Cast = Builder.CreateBitOrPointerCast(
              NewLoad, OrigTy, NewLoad->getName() + ".cast");
          LI->replaceAllUsesWith(Cast);
          copyMetadataForLoad(*NewLoad, *LI);
          LI->eraseFromParent();
          EQClasses[EC1.first].emplace_back(NewLoad);
        } else {
          auto *SI = cast<StoreInst>(Inst);
          Builder.SetInsertPoint(SI->getIterator());
          auto *Cast = Builder.CreateBitOrPointerCast(
              SI->getValueOperand(), NewTy,
              SI->getValueOperand()->getName() + ".cast");
          auto *NewStore = Builder.CreateStore(
              Cast, getLoadStorePointerOperand(SI), SI->isVolatile());
          copyMetadataForStore(*NewStore, *SI);
          SI->eraseFromParent();
          EQClasses[EC1.first].emplace_back(NewStore);
        }
      }

      // Sort the instructions in the equivalence class by their order in the
      // basic block. This is important to ensure that the instructions are
      // vectorized in the correct order.
      std::sort(EQClasses[EC1.first].begin(), EQClasses[EC1.first].end(),
                [](const Instruction *A, const Instruction *B) {
                  return A && B && A->comesBefore(B);
                });
      ClassesToErase.insert(EC2.first);
    }
  }

  // Erase the equivalence classes that were merged into others.
  for (auto Key : ClassesToErase)
    EQClasses.erase(Key);
}

void Vectorizer::mergeEquivalenceClasses(EquivalenceClassMap &EQClasses) const {
if (EQClasses.size() < 2) // There is nothing to merge.
return;
@@ -1493,7 +1615,7 @@ Vectorizer::collectEquivalenceClasses(BasicBlock::iterator Begin,
/*IsLoad=*/LI != nullptr}]
.emplace_back(&I);
}

insertCastsToMergeClasses(Ret);
Contributor

Is this eagerly mutating the IR before vectorization is performed? It should only select a type here, and coerce as part of the final vectorization.

Contributor Author

After merging equivalence classes, LSV converts them into chains, at which point it is too late to introduce cast instructions.

Contributor

Why is it too late to introduce cast instructions at that point?

Contributor Author

Theoretically, it may be possible. The necessary changes seem cumbersome to me, however. After collecting classes, LSV extracts chains from each class. These chains are split based on contiguity, alignment and MayAlias instructions, before vectorization. Merging chains after these splits would require careful handling of their instructions as vectorizeChain makes certain assumptions before determining the type of vectorized load/store.

I prefer to insert casts alongside mergeEquivalenceClasses(..), as gatherChains already understands which kinds of chains vectorizeChain can handle.
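
For readers unfamiliar with the contiguity split mentioned above, here is a minimal standalone sketch of the idea. It is not LSV's actual code: `Access` and `splitByContiguity` are hypothetical names, the chain is assumed to be already sorted by offset from the leader, and alignment and may-alias splitting are left out.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for one load/store in a chain: its byte offset from
// the chain leader and its access size in bytes. This is not an LLVM type.
struct Access {
  int64_t OffsetBytes;
  uint64_t SizeBytes;
};

// Splits a chain that is sorted by offset into maximal runs in which every
// access starts exactly where the previous one ends.
std::vector<std::vector<Access>>
splitByContiguity(const std::vector<Access> &Chain) {
  std::vector<std::vector<Access>> Runs;
  for (const Access &A : Chain) {
    bool StartsNewRun = Runs.empty();
    if (!StartsNewRun) {
      const Access &Prev = Runs.back().back();
      assert(Prev.OffsetBytes <= A.OffsetBytes && "chain must be sorted");
      // A gap (or overlap) between the two accesses ends the current run.
      StartsNewRun = Prev.OffsetBytes + static_cast<int64_t>(Prev.SizeBytes) !=
                     A.OffsetBytes;
    }
    if (StartsNewRun)
      Runs.emplace_back();
    Runs.back().push_back(A);
  }
  return Runs;
}
```

In this model, four 4-byte accesses at offsets 0, 4, 12 and 16 form two runs of two accesses each, because nothing covers bytes 8 to 11.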

Contributor

It's important not to make pointless IR changes; only do this if vectorization will actually occur.
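
One way to read this request is as a plan-then-commit split: record the casts a merge would need, and materialize them only for chains that will actually be vectorized. The sketch below models that with plain data rather than LLVM IR. `PendingCast`, `PlannedChain` and `commitPlannedCasts` are hypothetical names, and "will be vectorized" is approximated as "has more than one element", which is an assumption of the sketch.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical record of a cast that would be needed if a chain is vectorized.
struct PendingCast {
  int AccessId;      // which load/store would be re-typed
  std::string NewTy; // the type it would be cast to, e.g. "i32"
};

// Hypothetical chain: the accesses it contains and the casts planned for them.
struct PlannedChain {
  std::vector<int> AccessIds;
  std::vector<PendingCast> Casts;
};

// Returns only the casts that belong to chains with at least two accesses.
// Casts planned for single-element chains are dropped, so those accesses
// would be left untouched.
std::vector<PendingCast>
commitPlannedCasts(const std::vector<PlannedChain> &Chains) {
  std::vector<PendingCast> Committed;
  for (const PlannedChain &C : Chains) {
    assert(!C.AccessIds.empty() && "a chain holds at least one access");
    if (C.AccessIds.size() < 2)
      continue; // nothing to vectorize, so do not mutate anything
    Committed.insert(Committed.end(), C.Casts.begin(), C.Casts.end());
  }
  return Committed;
}
```

The cost of this shape is that the planned casts have to be carried through chain gathering and splitting before they can be applied.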

Contributor Author

That's a good point. I will postpone merging the classes until vectorization, then. Thanks!

Contributor Author

I took a deep dive into inserting casts before vectorizeChain.

  • Firstly, all gathered chains are guaranteed to be vectorized, except when a chain has only one element.
  • gatherChains is responsible for calculating offsets. The splitChain functions further split the gathered chains based on the legality of the prospective vectorization. Inserting casts after these functions may break the legality of these chains, as the offsets may no longer be correct. There are two ways to tackle this challenge:
    • Recompute the offsets and re-run the splitChain functions. This sounds like too much overhead to me.
    • Store chains in a heap-based data structure that can preserve legality. This demands a lot of bookkeeping to replace the splitChain functions, and seems unlikely to scale to longer chains.

mergeEquivalenceClasses(Ret);
return Ret;
}