Skip to content

[SeparateConstOffsetFromGEP] Decompose constant xor operand if possible #135788

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 193 additions & 0 deletions llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@
#include "llvm/IR/Function.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
Expand All @@ -190,6 +191,7 @@
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Local.h"
#include <cassert>
#include <cstdint>
Expand All @@ -198,6 +200,8 @@
using namespace llvm;
using namespace llvm::PatternMatch;

#define DEBUG_TYPE "separate-offset-gep"

static cl::opt<bool> DisableSeparateConstOffsetFromGEP(
"disable-separate-const-offset-from-gep", cl::init(false),
cl::desc("Do not separate the constant offset from a GEP instruction"),
Expand Down Expand Up @@ -486,6 +490,42 @@ class SeparateConstOffsetFromGEP {
DenseMap<ExprKey, SmallVector<Instruction *, 2>> DominatingSubs;
};

/// A helper class that aims to convert xor operations into or operations when
/// their operands are disjoint and the result is used in a GEP's index. This
/// can then enable further GEP optimizations by effectively turning BaseVal |
/// Const into BaseVal + Const when they are disjoint, which
/// SeparateConstOffsetFromGEP can then process. This is a common pattern that
/// sets up a grid of memory accesses across a wave where each thread acesses
/// data at various offsets.
class XorToOrDisjointTransformer {
public:
XorToOrDisjointTransformer(Function &F, DominatorTree &DT,
const DataLayout &DL)
: F(F), DT(DT), DL(DL) {}

bool run();

private:
Function &F;
DominatorTree &DT;
const DataLayout &DL;
/// Maps a common operand to all Xor instructions
using XorOpList = SmallVector<std::pair<BinaryOperator *, APInt>, 8>;
using XorBaseValInst = DenseMap<Instruction *, XorOpList>;
XorBaseValInst XorGroups;

/// Checks if the given value has at least one GetElementPtr user
static bool hasGEPUser(const Value *V);

/// Helper function to check if BaseXor dominates all XORs in the group
bool dominatesAllXors(BinaryOperator *BaseXor, const XorOpList &XorsInGroup);

/// Processes a group of XOR instructions that share the same non-constant
/// base operand. Returns true if this group's processing modified the
/// function.
bool processXorGroup(Instruction *OriginalBaseInst, XorOpList &XorsInGroup);
};

} // end anonymous namespace

char SeparateConstOffsetFromGEPLegacyPass::ID = 0;
Expand Down Expand Up @@ -1162,6 +1202,154 @@ bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) {
return true;
}

// Helper function to check if an instruction has at least one GEP user
bool XorToOrDisjointTransformer::hasGEPUser(const Value *V) {
return llvm::any_of(V->users(), [](const User *U) {
return isa<llvm::GetElementPtrInst>(U);
});
}

bool XorToOrDisjointTransformer::dominatesAllXors(
BinaryOperator *BaseXor, const XorOpList &XorsInGroup) {
return llvm::all_of(XorsInGroup, [&](const auto &XorEntry) {
BinaryOperator *XorInst = XorEntry.first;
// Do not evaluate the BaseXor, otherwise we end up cloning it.
return XorInst == BaseXor || DT.dominates(BaseXor, XorInst);
});
}

bool XorToOrDisjointTransformer::processXorGroup(Instruction *OriginalBaseInst,
XorOpList &XorsInGroup) {
bool Changed = false;
if (XorsInGroup.size() <= 1)
return false;

// Sort XorsInGroup by the constant offset value in increasing order.
llvm::sort(XorsInGroup, [](const auto &A, const auto &B) {
return A.second.slt(B.second);
});

// Dominance check
// The "base" XOR for dominance purposes is the one with the smallest
// constant.
BinaryOperator *XorWithSmallConst = XorsInGroup[0].first;

if (!dominatesAllXors(XorWithSmallConst, XorsInGroup)) {
LLVM_DEBUG(dbgs() << DEBUG_TYPE
<< ": Cloning and inserting XOR with smallest constant ("
<< *XorWithSmallConst
<< ") as it does not dominate all other XORs"
<< " in function " << F.getName() << "\n");

BinaryOperator *ClonedXor =
cast<BinaryOperator>(XorWithSmallConst->clone());
ClonedXor->setName(XorWithSmallConst->getName() + ".dom_clone");
ClonedXor->insertAfter(OriginalBaseInst);
LLVM_DEBUG(dbgs() << " Cloned Inst: " << *ClonedXor << "\n");
Changed = true;
XorWithSmallConst = ClonedXor;
}

SmallVector<Instruction *, 8> InstructionsToErase;
const APInt SmallestConst =
cast<ConstantInt>(XorWithSmallConst->getOperand(1))->getValue();

// Main transformation loop: Iterate over the original XORs in the sorted
// group.
for (const auto &XorEntry : XorsInGroup) {
BinaryOperator *XorInst = XorEntry.first; // Original XOR instruction
const APInt ConstOffsetVal = XorEntry.second;

// Do not process the one with smallest constant as it is the base.
if (XorInst == XorWithSmallConst)
continue;

// Disjointness Check 1
APInt NewConstVal = ConstOffsetVal - SmallestConst;
if ((NewConstVal & SmallestConst) != 0) {
LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Cannot transform XOR in function "
<< F.getName() << ":\n"
<< " New Const: " << NewConstVal
<< " Smallest Const: " << SmallestConst
<< " are not disjoint \n");
continue;
}

// Disjointness Check 2
if (MaskedValueIsZero(XorWithSmallConst, NewConstVal, SimplifyQuery(DL),
0)) {
LLVM_DEBUG(dbgs() << DEBUG_TYPE
<< ": Transforming XOR to OR (disjoint) in function "
<< F.getName() << ":\n"
<< " Xor: " << *XorInst << "\n"
<< " Base Val: " << *XorWithSmallConst << "\n"
<< " New Const: " << NewConstVal << "\n");

auto *NewOrInst = BinaryOperator::CreateDisjointOr(
XorWithSmallConst,
ConstantInt::get(OriginalBaseInst->getType(), NewConstVal),
XorInst->getName() + ".or_disjoint", XorInst->getIterator());

NewOrInst->copyMetadata(*XorInst);
XorInst->replaceAllUsesWith(NewOrInst);
LLVM_DEBUG(dbgs() << " New Inst: " << *NewOrInst << "\n");
InstructionsToErase.push_back(XorInst); // Mark original XOR for deletion

Changed = true;
} else {
LLVM_DEBUG(
dbgs() << DEBUG_TYPE
<< ": Cannot transform XOR (not proven disjoint) in function "
<< F.getName() << ":\n"
<< " Xor: " << *XorInst << "\n"
<< " Base Val: " << *XorWithSmallConst << "\n"
<< " New Const: " << NewConstVal << "\n");
}
}

for (Instruction *I : InstructionsToErase)
I->eraseFromParent();

return Changed;
}

// Try to transform XOR(A, B+C) in to XOR(A,C) + B where XOR(A,C) becomes
// the base for memory operations. This transformation is true under the
// following conditions
// Check 1 - B and C are disjoint.
// Check 2 - XOR(A,C) and B are disjoint.
//
// This transformation is beneficial particularly for GEPs because:
// 1. OR operations often map better to addressing modes than XOR
// 2. Disjoint OR operations preserve the semantics of the original XOR
// 3. This can enable further optimizations in the GEP offset folding pipeline
bool XorToOrDisjointTransformer::run() {
bool Changed = false;

// Collect all candidate XORs
for (Instruction &I : instructions(F)) {
Instruction *Op0 = nullptr;
ConstantInt *C1 = nullptr;
BinaryOperator *MatchedXorOp = nullptr;

// Attempt to match the instruction 'I' as XOR operation.
if (match(&I, m_CombineAnd(m_Xor(m_Instruction(Op0), m_ConstantInt(C1)),
m_BinOp(MatchedXorOp))) &&
hasGEPUser(MatchedXorOp))
XorGroups[Op0].emplace_back(MatchedXorOp, C1->getValue());
}

if (XorGroups.empty())
return false;

// Process each group of XORs
for (auto &[OriginalBaseInst, XorsInGroup] : XorGroups)
if (processXorGroup(OriginalBaseInst, XorsInGroup))
Changed = true;

return Changed;
}

bool SeparateConstOffsetFromGEPLegacyPass::runOnFunction(Function &F) {
if (skipFunction(F))
return false;
Expand All @@ -1181,6 +1369,11 @@ bool SeparateConstOffsetFromGEP::run(Function &F) {

DL = &F.getDataLayout();
bool Changed = false;

// Decompose xor in to "or disjoint" if possible.
XorToOrDisjointTransformer XorTransformer(F, *DT, *DL);
Changed |= XorTransformer.run();

for (BasicBlock &B : F) {
if (!DT->isReachableFromEntry(&B))
continue;
Expand Down
Loading
Loading