Skip to content

Commit b15b5f0

Browse files
committed
[SandboxIR] Implement SwitchInst
This patch implements sandboxir::SwitchInst mirroring llvm::SwitchInst.
1 parent f33d519 commit b15b5f0

File tree

7 files changed

+441
-0
lines changed

7 files changed

+441
-0
lines changed

llvm/include/llvm/SandboxIR/SandboxIR.h

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ class CastInst;
131131
class PtrToIntInst;
132132
class BitCastInst;
133133
class AllocaInst;
134+
class SwitchInst;
134135
class UnaryOperator;
135136
class BinaryOperator;
136137
class AtomicRMWInst;
@@ -253,6 +254,7 @@ class Value {
253254
friend class InvokeInst; // For getting `Val`.
254255
friend class CallBrInst; // For getting `Val`.
255256
friend class GetElementPtrInst; // For getting `Val`.
257+
friend class SwitchInst; // For getting `Val`.
256258
friend class UnaryOperator; // For getting `Val`.
257259
friend class BinaryOperator; // For getting `Val`.
258260
friend class AtomicRMWInst; // For getting `Val`.
@@ -672,6 +674,7 @@ class Instruction : public sandboxir::User {
672674
friend class InvokeInst; // For getTopmostLLVMInstruction().
673675
friend class CallBrInst; // For getTopmostLLVMInstruction().
674676
friend class GetElementPtrInst; // For getTopmostLLVMInstruction().
677+
friend class SwitchInst; // For getTopmostLLVMInstruction().
675678
friend class UnaryOperator; // For getTopmostLLVMInstruction().
676679
friend class BinaryOperator; // For getTopmostLLVMInstruction().
677680
friend class AtomicRMWInst; // For getTopmostLLVMInstruction().
@@ -1477,6 +1480,91 @@ class GetElementPtrInst final
14771480
// TODO: Add missing member functions.
14781481
};
14791482

1483+
class SwitchInst : public SingleLLVMInstructionImpl<llvm::SwitchInst> {
1484+
public:
1485+
SwitchInst(llvm::SwitchInst *SI, Context &Ctx)
1486+
: SingleLLVMInstructionImpl(ClassID::Switch, Opcode::Switch, SI, Ctx) {}
1487+
1488+
static constexpr const unsigned DefaultPseudoIndex =
1489+
llvm::SwitchInst::DefaultPseudoIndex;
1490+
1491+
static SwitchInst *create(Value *V, BasicBlock *Dest, unsigned NumCases,
1492+
BasicBlock::iterator WhereIt, BasicBlock *WhereBB,
1493+
Context &Ctx, const Twine &Name = "");
1494+
1495+
Value *getCondition() const;
1496+
void setCondition(Value *V);
1497+
BasicBlock *getDefaultDest() const;
1498+
bool defaultDestUndefined() const {
1499+
return cast<llvm::SwitchInst>(Val)->defaultDestUndefined();
1500+
}
1501+
void setDefaultDest(BasicBlock *DefaultCase);
1502+
unsigned getNumCases() const {
1503+
return cast<llvm::SwitchInst>(Val)->getNumCases();
1504+
}
1505+
1506+
using CaseHandle =
1507+
llvm::SwitchInst::CaseHandleImpl<SwitchInst, ConstantInt, BasicBlock>;
1508+
using ConstCaseHandle =
1509+
llvm::SwitchInst::CaseHandleImpl<const SwitchInst, const ConstantInt,
1510+
const BasicBlock>;
1511+
using CaseIt = llvm::SwitchInst::CaseIteratorImpl<CaseHandle>;
1512+
using ConstCaseIt = llvm::SwitchInst::CaseIteratorImpl<ConstCaseHandle>;
1513+
1514+
/// Returns a read/write iterator that points to the first case in the
1515+
/// SwitchInst.
1516+
CaseIt case_begin() { return CaseIt(this, 0); }
1517+
ConstCaseIt case_begin() const { return ConstCaseIt(this, 0); }
1518+
/// Returns a read/write iterator that points one past the last in the
1519+
/// SwitchInst.
1520+
CaseIt case_end() { return CaseIt(this, getNumCases()); }
1521+
ConstCaseIt case_end() const { return ConstCaseIt(this, getNumCases()); }
1522+
/// Iteration adapter for range-for loops.
1523+
iterator_range<CaseIt> cases() {
1524+
return make_range(case_begin(), case_end());
1525+
}
1526+
iterator_range<ConstCaseIt> cases() const {
1527+
return make_range(case_begin(), case_end());
1528+
}
1529+
CaseIt case_default() { return CaseIt(this, DefaultPseudoIndex); }
1530+
ConstCaseIt case_default() const {
1531+
return ConstCaseIt(this, DefaultPseudoIndex);
1532+
}
1533+
CaseIt findCaseValue(const ConstantInt *C) {
1534+
return CaseIt(
1535+
this,
1536+
const_cast<const SwitchInst *>(this)->findCaseValue(C)->getCaseIndex());
1537+
}
1538+
ConstCaseIt findCaseValue(const ConstantInt *C) const {
1539+
ConstCaseIt I = llvm::find_if(cases(), [C](const ConstCaseHandle &Case) {
1540+
return Case.getCaseValue() == C;
1541+
});
1542+
if (I != case_end())
1543+
return I;
1544+
return case_default();
1545+
}
1546+
ConstantInt *findCaseDest(BasicBlock *BB);
1547+
1548+
void addCase(ConstantInt *OnVal, BasicBlock *Dest);
1549+
/// This method removes the specified case and its successor from the switch
1550+
/// instruction. Note that this operation may reorder the remaining cases at
1551+
/// index idx and above.
1552+
/// Note:
1553+
/// This action invalidates iterators for all cases following the one removed,
1554+
/// including the case_end() iterator. It returns an iterator for the next
1555+
/// case.
1556+
CaseIt removeCase(CaseIt It);
1557+
1558+
unsigned getNumSuccessors() const {
1559+
return cast<llvm::SwitchInst>(Val)->getNumSuccessors();
1560+
}
1561+
BasicBlock *getSuccessor(unsigned Idx) const;
1562+
void setSuccessor(unsigned Idx, BasicBlock *NewSucc);
1563+
static bool classof(const Value *From) {
1564+
return From->getSubclassID() == ClassID::Switch;
1565+
}
1566+
};
1567+
14801568
class UnaryOperator : public UnaryInstruction {
14811569
static Opcode getUnaryOpcode(llvm::Instruction::UnaryOps UnOp) {
14821570
switch (UnOp) {
@@ -2113,6 +2201,8 @@ class Context {
21132201
friend CallBrInst; // For createCallBrInst()
21142202
GetElementPtrInst *createGetElementPtrInst(llvm::GetElementPtrInst *I);
21152203
friend GetElementPtrInst; // For createGetElementPtrInst()
2204+
SwitchInst *createSwitchInst(llvm::SwitchInst *I);
2205+
friend SwitchInst; // For createSwitchInst()
21162206
UnaryOperator *createUnaryOperator(llvm::UnaryOperator *I);
21172207
friend UnaryOperator; // For createUnaryOperator()
21182208
BinaryOperator *createBinaryOperator(llvm::BinaryOperator *I);

llvm/include/llvm/SandboxIR/SandboxIRValues.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ DEF_INSTR(Call, OP(Call), CallInst)
4646
DEF_INSTR(Invoke, OP(Invoke), InvokeInst)
4747
DEF_INSTR(CallBr, OP(CallBr), CallBrInst)
4848
DEF_INSTR(GetElementPtr, OP(GetElementPtr), GetElementPtrInst)
49+
DEF_INSTR(Switch, OP(Switch), SwitchInst)
4950
DEF_INSTR(UnOp, OPCODES( \
5051
OP(FNeg) \
5152
), UnaryOperator)

llvm/include/llvm/SandboxIR/Tracker.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ class StoreInst;
5959
class Instruction;
6060
class Tracker;
6161
class AllocaInst;
62+
class SwitchInst;
63+
class ConstantInt;
6264

6365
/// The base class for IR Change classes.
6466
class IRChangeBase {
@@ -261,6 +263,37 @@ class GenericSetterWithIdx final : public IRChangeBase {
261263
#endif
262264
};
263265

266+
class SwitchAddCase : public IRChangeBase {
267+
SwitchInst *Switch;
268+
ConstantInt *Val;
269+
270+
public:
271+
SwitchAddCase(SwitchInst *Switch, ConstantInt *Val)
272+
: Switch(Switch), Val(Val) {}
273+
void revert(Tracker &Tracker) final;
274+
void accept() final {}
275+
#ifndef NDEBUG
276+
void dump(raw_ostream &OS) const final { OS << "SwitchAddCase"; }
277+
LLVM_DUMP_METHOD void dump() const final;
278+
#endif // NDEBUG
279+
};
280+
281+
class SwitchRemoveCase : public IRChangeBase {
282+
SwitchInst *Switch;
283+
ConstantInt *Val;
284+
BasicBlock *Dest;
285+
286+
public:
287+
SwitchRemoveCase(SwitchInst *Switch, ConstantInt *Val, BasicBlock *Dest)
288+
: Switch(Switch), Val(Val), Dest(Dest) {}
289+
void revert(Tracker &Tracker) final;
290+
void accept() final {}
291+
#ifndef NDEBUG
292+
void dump(raw_ostream &OS) const final { OS << "SwitchRemoveCase"; }
293+
LLVM_DUMP_METHOD void dump() const final;
294+
#endif // NDEBUG
295+
};
296+
264297
class MoveInstr : public IRChangeBase {
265298
/// The instruction that moved.
266299
Instruction *MovedI;

llvm/lib/SandboxIR/SandboxIR.cpp

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1236,6 +1236,84 @@ static llvm::Instruction::UnaryOps getLLVMUnaryOp(Instruction::Opcode Opc) {
12361236
}
12371237
}
12381238

1239+
SwitchInst *SwitchInst::create(Value *V, BasicBlock *Dest, unsigned NumCases,
1240+
BasicBlock::iterator WhereIt,
1241+
BasicBlock *WhereBB, Context &Ctx,
1242+
const Twine &Name) {
1243+
auto &Builder = Ctx.getLLVMIRBuilder();
1244+
if (WhereIt != WhereBB->end())
1245+
Builder.SetInsertPoint((*WhereIt).getTopmostLLVMInstruction());
1246+
else
1247+
Builder.SetInsertPoint(cast<llvm::BasicBlock>(WhereBB->Val));
1248+
llvm::SwitchInst *LLVMSwitch =
1249+
Builder.CreateSwitch(V->Val, cast<llvm::BasicBlock>(Dest->Val), NumCases);
1250+
return Ctx.createSwitchInst(LLVMSwitch);
1251+
}
1252+
1253+
Value *SwitchInst::getCondition() const {
1254+
return Ctx.getValue(cast<llvm::SwitchInst>(Val)->getCondition());
1255+
}
1256+
1257+
void SwitchInst::setCondition(Value *V) {
1258+
Ctx.getTracker()
1259+
.emplaceIfTracking<
1260+
GenericSetter<&SwitchInst::getCondition, &SwitchInst::setCondition>>(
1261+
this);
1262+
cast<llvm::SwitchInst>(Val)->setCondition(V->Val);
1263+
}
1264+
1265+
BasicBlock *SwitchInst::getDefaultDest() const {
1266+
return cast<BasicBlock>(
1267+
Ctx.getValue(cast<llvm::SwitchInst>(Val)->getDefaultDest()));
1268+
}
1269+
1270+
void SwitchInst::setDefaultDest(BasicBlock *DefaultCase) {
1271+
Ctx.getTracker()
1272+
.emplaceIfTracking<GenericSetter<&SwitchInst::getDefaultDest,
1273+
&SwitchInst::setDefaultDest>>(this);
1274+
cast<llvm::SwitchInst>(Val)->setDefaultDest(
1275+
cast<llvm::BasicBlock>(DefaultCase->Val));
1276+
}
1277+
ConstantInt *SwitchInst::findCaseDest(BasicBlock *BB) {
1278+
auto *LLVMC = cast<llvm::SwitchInst>(Val)->findCaseDest(
1279+
cast<llvm::BasicBlock>(BB->Val));
1280+
return LLVMC != nullptr ? cast<ConstantInt>(Ctx.getValue(LLVMC)) : nullptr;
1281+
}
1282+
1283+
void SwitchInst::addCase(ConstantInt *OnVal, BasicBlock *Dest) {
1284+
Ctx.getTracker().emplaceIfTracking<SwitchAddCase>(this, OnVal);
1285+
// TODO: Track this!
1286+
cast<llvm::SwitchInst>(Val)->addCase(cast<llvm::ConstantInt>(OnVal->Val),
1287+
cast<llvm::BasicBlock>(Dest->Val));
1288+
}
1289+
1290+
SwitchInst::CaseIt SwitchInst::removeCase(CaseIt It) {
1291+
auto &Case = *It;
1292+
Ctx.getTracker().emplaceIfTracking<SwitchRemoveCase>(
1293+
this, Case.getCaseValue(), Case.getCaseSuccessor());
1294+
1295+
auto *LLVMSwitch = cast<llvm::SwitchInst>(Val);
1296+
unsigned CaseNum = It - case_begin();
1297+
llvm::SwitchInst::CaseIt LLVMIt(LLVMSwitch, CaseNum);
1298+
auto LLVMCaseIt = LLVMSwitch->removeCase(LLVMIt);
1299+
unsigned Num = LLVMCaseIt - LLVMSwitch->case_begin();
1300+
return CaseIt(this, Num);
1301+
}
1302+
1303+
BasicBlock *SwitchInst::getSuccessor(unsigned Idx) const {
1304+
return cast<BasicBlock>(
1305+
Ctx.getValue(cast<llvm::SwitchInst>(Val)->getSuccessor(Idx)));
1306+
}
1307+
1308+
void SwitchInst::setSuccessor(unsigned Idx, BasicBlock *NewSucc) {
1309+
Ctx.getTracker()
1310+
.emplaceIfTracking<GenericSetterWithIdx<&SwitchInst::getSuccessor,
1311+
&SwitchInst::setSuccessor>>(this,
1312+
Idx);
1313+
cast<llvm::SwitchInst>(Val)->setSuccessor(
1314+
Idx, cast<llvm::BasicBlock>(NewSucc->Val));
1315+
}
1316+
12391317
Value *UnaryOperator::create(Instruction::Opcode Op, Value *OpV,
12401318
BBIterator WhereIt, BasicBlock *WhereBB,
12411319
Context &Ctx, const Twine &Name) {
@@ -1875,6 +1953,12 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
18751953
new GetElementPtrInst(LLVMGEP, *this));
18761954
return It->second.get();
18771955
}
1956+
case llvm::Instruction::Switch: {
1957+
auto *LLVMSwitchInst = cast<llvm::SwitchInst>(LLVMV);
1958+
It->second =
1959+
std::unique_ptr<SwitchInst>(new SwitchInst(LLVMSwitchInst, *this));
1960+
return It->second.get();
1961+
}
18781962
case llvm::Instruction::FNeg: {
18791963
auto *LLVMUnaryOperator = cast<llvm::UnaryOperator>(LLVMV);
18801964
It->second = std::unique_ptr<UnaryOperator>(
@@ -2033,6 +2117,10 @@ Context::createGetElementPtrInst(llvm::GetElementPtrInst *I) {
20332117
std::unique_ptr<GetElementPtrInst>(new GetElementPtrInst(I, *this));
20342118
return cast<GetElementPtrInst>(registerValue(std::move(NewPtr)));
20352119
}
2120+
SwitchInst *Context::createSwitchInst(llvm::SwitchInst *I) {
2121+
auto NewPtr = std::unique_ptr<SwitchInst>(new SwitchInst(I, *this));
2122+
return cast<SwitchInst>(registerValue(std::move(NewPtr)));
2123+
}
20362124
UnaryOperator *Context::createUnaryOperator(llvm::UnaryOperator *I) {
20372125
auto NewPtr = std::unique_ptr<UnaryOperator>(new UnaryOperator(I, *this));
20382126
return cast<UnaryOperator>(registerValue(std::move(NewPtr)));

llvm/lib/SandboxIR/Tracker.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,27 @@ void RemoveFromParent::dump() const {
160160
}
161161
#endif
162162

163+
void SwitchRemoveCase::revert(Tracker &Tracker) { Switch->addCase(Val, Dest); }
164+
165+
#ifndef NDEBUG
166+
void SwitchRemoveCase::dump() const {
167+
dump(dbgs());
168+
dbgs() << "\n";
169+
}
170+
#endif // NDEBUG
171+
172+
void SwitchAddCase::revert(Tracker &Tracker) {
173+
auto It = Switch->findCaseValue(Val);
174+
Switch->removeCase(It);
175+
}
176+
177+
#ifndef NDEBUG
178+
void SwitchAddCase::dump() const {
179+
dump(dbgs());
180+
dbgs() << "\n";
181+
}
182+
#endif // NDEBUG
183+
163184
MoveInstr::MoveInstr(Instruction *MovedI) : MovedI(MovedI) {
164185
if (auto *NextI = MovedI->getNextNode())
165186
NextInstrOrBB = NextI;

0 commit comments

Comments
 (0)