Skip to content

Commit a2af374

Browse files
authored
[SelectionDAG] Add space-optimized forms of OPC_CheckPredicate (#77763)
We record the usage of each `Predicate` and sort them by usage. For the top 8 `Predicate`s, we will emit a `PC_CheckPredicateN` to save one byte. Overall this reduces the llc binary size with all in-tree targets by about 61K. This is a recommit of 1a57927, which was reverted in bc98c31. The CI failures occurred when doing expensive checks (with option `LLVM_ENABLE_EXPENSIVE_CHECKS` being ON). The key point here is that we need stable sorting result in the test, but doing expensive checks uncovered the non-determinism of `llvm::sort`. So `llvm::sort` is changed to `llvm::stable_sort` in this revised patch. And we use `llvm::MapVector` to keep insertion order.
1 parent 9e40ba0 commit a2af374

File tree

5 files changed

+96
-46
lines changed

5 files changed

+96
-46
lines changed

llvm/include/llvm/CodeGen/SelectionDAGISel.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,14 @@ class SelectionDAGISel : public MachineFunctionPass {
169169
OPC_CheckPatternPredicate7,
170170
OPC_CheckPatternPredicateTwoByte,
171171
OPC_CheckPredicate,
172+
OPC_CheckPredicate0,
173+
OPC_CheckPredicate1,
174+
OPC_CheckPredicate2,
175+
OPC_CheckPredicate3,
176+
OPC_CheckPredicate4,
177+
OPC_CheckPredicate5,
178+
OPC_CheckPredicate6,
179+
OPC_CheckPredicate7,
172180
OPC_CheckPredicateWithOperands,
173181
OPC_CheckOpcode,
174182
OPC_SwitchOpcode,

llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2712,9 +2712,13 @@ CheckPatternPredicate(unsigned Opcode, const unsigned char *MatcherTable,
27122712

27132713
/// CheckNodePredicate - Implements OP_CheckNodePredicate.
27142714
LLVM_ATTRIBUTE_ALWAYS_INLINE static bool
2715-
CheckNodePredicate(const unsigned char *MatcherTable, unsigned &MatcherIndex,
2716-
const SelectionDAGISel &SDISel, SDNode *N) {
2717-
return SDISel.CheckNodePredicate(N, MatcherTable[MatcherIndex++]);
2715+
CheckNodePredicate(unsigned Opcode, const unsigned char *MatcherTable,
2716+
unsigned &MatcherIndex, const SelectionDAGISel &SDISel,
2717+
SDNode *N) {
2718+
unsigned PredNo = Opcode == SelectionDAGISel::OPC_CheckPredicate
2719+
? MatcherTable[MatcherIndex++]
2720+
: Opcode - SelectionDAGISel::OPC_CheckPredicate0;
2721+
return SDISel.CheckNodePredicate(N, PredNo);
27182722
}
27192723

27202724
LLVM_ATTRIBUTE_ALWAYS_INLINE static bool
@@ -2868,7 +2872,15 @@ static unsigned IsPredicateKnownToFail(const unsigned char *Table,
28682872
Result = !::CheckPatternPredicate(Opcode, Table, Index, SDISel);
28692873
return Index;
28702874
case SelectionDAGISel::OPC_CheckPredicate:
2871-
Result = !::CheckNodePredicate(Table, Index, SDISel, N.getNode());
2875+
case SelectionDAGISel::OPC_CheckPredicate0:
2876+
case SelectionDAGISel::OPC_CheckPredicate1:
2877+
case SelectionDAGISel::OPC_CheckPredicate2:
2878+
case SelectionDAGISel::OPC_CheckPredicate3:
2879+
case SelectionDAGISel::OPC_CheckPredicate4:
2880+
case SelectionDAGISel::OPC_CheckPredicate5:
2881+
case SelectionDAGISel::OPC_CheckPredicate6:
2882+
case SelectionDAGISel::OPC_CheckPredicate7:
2883+
Result = !::CheckNodePredicate(Opcode, Table, Index, SDISel, N.getNode());
28722884
return Index;
28732885
case SelectionDAGISel::OPC_CheckOpcode:
28742886
Result = !::CheckOpcode(Table, Index, N.getNode());
@@ -3359,8 +3371,16 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch,
33593371
if (!::CheckPatternPredicate(Opcode, MatcherTable, MatcherIndex, *this))
33603372
break;
33613373
continue;
3374+
case SelectionDAGISel::OPC_CheckPredicate0:
3375+
case SelectionDAGISel::OPC_CheckPredicate1:
3376+
case SelectionDAGISel::OPC_CheckPredicate2:
3377+
case SelectionDAGISel::OPC_CheckPredicate3:
3378+
case SelectionDAGISel::OPC_CheckPredicate4:
3379+
case SelectionDAGISel::OPC_CheckPredicate5:
3380+
case SelectionDAGISel::OPC_CheckPredicate6:
3381+
case SelectionDAGISel::OPC_CheckPredicate7:
33623382
case OPC_CheckPredicate:
3363-
if (!::CheckNodePredicate(MatcherTable, MatcherIndex, *this,
3383+
if (!::CheckNodePredicate(Opcode, MatcherTable, MatcherIndex, *this,
33643384
N.getNode()))
33653385
break;
33663386
continue;

llvm/test/TableGen/address-space-patfrags.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def inst_d : Instruction {
4646
let InOperandList = (ins GPR32:$src0, GPR32:$src1);
4747
}
4848

49-
// SDAG: case 2: {
49+
// SDAG: case 1: {
5050
// SDAG-NEXT: // Predicate_pat_frag_b
5151
// SDAG-NEXT: // Predicate_truncstorei16_addrspace
5252
// SDAG-NEXT: SDNode *N = Node;
@@ -69,7 +69,7 @@ def : Pat <
6969
>;
7070

7171

72-
// SDAG: case 3: {
72+
// SDAG: case 6: {
7373
// SDAG: // Predicate_pat_frag_a
7474
// SDAG-NEXT: SDNode *N = Node;
7575
// SDAG-NEXT: (void)N;

llvm/test/TableGen/predicate-patfags.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,10 @@ def TGTmul24_oneuse : PatFrag<
3939
}
4040

4141
// SDAG: OPC_CheckOpcode, TARGET_VAL(ISD::INTRINSIC_W_CHAIN),
42-
// SDAG: OPC_CheckPredicate, 0, // Predicate_TGTmul24_oneuse
42+
// SDAG: OPC_CheckPredicate0, // Predicate_TGTmul24_oneuse
4343

4444
// SDAG: OPC_CheckOpcode, TARGET_VAL(TargetISD::MUL24),
45-
// SDAG: OPC_CheckPredicate, 0, // Predicate_TGTmul24_oneuse
45+
// SDAG: OPC_CheckPredicate0, // Predicate_TGTmul24_oneuse
4646

4747
// GISEL: GIM_CheckOpcode, /*MI*/1, GIMT_Encode2(TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS),
4848
// GISEL: GIM_CheckIntrinsicID, /*MI*/1, /*Op*/1, GIMT_Encode2(Intrinsic::tgt_mul24),

llvm/utils/TableGen/DAGISelMatcherEmitter.cpp

Lines changed: 59 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,8 @@ class MatcherTableEmitter {
5252

5353
SmallVector<unsigned, Matcher::HighestKind+1> OpcodeCounts;
5454

55-
DenseMap<TreePattern *, unsigned> NodePredicateMap;
56-
std::vector<TreePredicateFn> NodePredicates;
57-
std::vector<TreePredicateFn> NodePredicatesWithOperands;
55+
std::vector<TreePattern *> NodePredicates;
56+
std::vector<TreePattern *> NodePredicatesWithOperands;
5857

5958
// We de-duplicate the predicates by code string, and use this map to track
6059
// all the patterns with "identical" predicates.
@@ -87,7 +86,9 @@ class MatcherTableEmitter {
8786
// Record the usage of ComplexPattern.
8887
MapVector<const ComplexPattern *, unsigned> ComplexPatternUsage;
8988
// Record the usage of PatternPredicate.
90-
std::map<StringRef, unsigned> PatternPredicateUsage;
89+
MapVector<StringRef, unsigned> PatternPredicateUsage;
90+
// Record the usage of Predicate.
91+
MapVector<TreePattern *, unsigned> PredicateUsage;
9192

9293
// Iterate the whole MatcherTable once and do some statistics.
9394
std::function<void(const Matcher *)> Statistic = [&](const Matcher *N) {
@@ -105,6 +106,8 @@ class MatcherTableEmitter {
105106
++ComplexPatternUsage[&CPM->getPattern()];
106107
else if (auto *CPPM = dyn_cast<CheckPatternPredicateMatcher>(N))
107108
++PatternPredicateUsage[CPPM->getPredicate()];
109+
else if (auto *PM = dyn_cast<CheckPredicateMatcher>(N))
110+
++PredicateUsage[PM->getPredicate().getOrigPatFragRecord()];
108111
N = N->getNext();
109112
}
110113
};
@@ -127,6 +130,40 @@ class MatcherTableEmitter {
127130
});
128131
for (const auto &PatternPredicate : PatternPredicateList)
129132
PatternPredicates.push_back(PatternPredicate.first);
133+
134+
// Sort Predicates by usage.
135+
// Merge predicates with same code.
136+
for (const auto &Usage : PredicateUsage) {
137+
TreePattern *TP = Usage.first;
138+
TreePredicateFn Pred(TP);
139+
NodePredicatesByCodeToRun[Pred.getCodeToRunOnSDNode()].push_back(TP);
140+
}
141+
142+
std::vector<std::pair<TreePattern *, unsigned>> PredicateList;
143+
// Sum the usage.
144+
for (auto &Predicate : NodePredicatesByCodeToRun) {
145+
TinyPtrVector<TreePattern *> &TPs = Predicate.second;
146+
stable_sort(TPs, [](const auto *A, const auto *B) {
147+
return A->getRecord()->getName() < B->getRecord()->getName();
148+
});
149+
unsigned Uses = 0;
150+
for (TreePattern *TP : TPs)
151+
Uses += PredicateUsage[TP];
152+
153+
// We only add the first predicate here since they are with the same code.
154+
PredicateList.push_back({TPs[0], Uses});
155+
}
156+
157+
stable_sort(PredicateList, [](const auto &A, const auto &B) {
158+
return A.second > B.second;
159+
});
160+
for (const auto &Predicate : PredicateList) {
161+
TreePattern *TP = Predicate.first;
162+
if (TreePredicateFn(TP).usesOperands())
163+
NodePredicatesWithOperands.push_back(TP);
164+
else
165+
NodePredicates.push_back(TP);
166+
}
130167
}
131168

132169
unsigned EmitMatcherList(const Matcher *N, const unsigned Indent,
@@ -141,7 +178,7 @@ class MatcherTableEmitter {
141178
void EmitPatternMatchTable(raw_ostream &OS);
142179

143180
private:
144-
void EmitNodePredicatesFunction(const std::vector<TreePredicateFn> &Preds,
181+
void EmitNodePredicatesFunction(const std::vector<TreePattern *> &Preds,
145182
StringRef Decl, raw_ostream &OS);
146183

147184
unsigned SizeMatcher(Matcher *N, raw_ostream &OS);
@@ -150,33 +187,13 @@ class MatcherTableEmitter {
150187
raw_ostream &OS);
151188

152189
unsigned getNodePredicate(TreePredicateFn Pred) {
153-
TreePattern *TP = Pred.getOrigPatFragRecord();
154-
unsigned &Entry = NodePredicateMap[TP];
155-
if (Entry == 0) {
156-
TinyPtrVector<TreePattern *> &SameCodePreds =
157-
NodePredicatesByCodeToRun[Pred.getCodeToRunOnSDNode()];
158-
if (SameCodePreds.empty()) {
159-
// We've never seen a predicate with the same code: allocate an entry.
160-
if (Pred.usesOperands()) {
161-
NodePredicatesWithOperands.push_back(Pred);
162-
Entry = NodePredicatesWithOperands.size();
163-
} else {
164-
NodePredicates.push_back(Pred);
165-
Entry = NodePredicates.size();
166-
}
167-
} else {
168-
// We did see an identical predicate: re-use it.
169-
Entry = NodePredicateMap[SameCodePreds.front()];
170-
assert(Entry != 0);
171-
assert(TreePredicateFn(SameCodePreds.front()).usesOperands() ==
172-
Pred.usesOperands() &&
173-
"PatFrags with some code must have same usesOperands setting");
174-
}
175-
// In both cases, we've never seen this particular predicate before, so
176-
// mark it in the list of predicates sharing the same code.
177-
SameCodePreds.push_back(TP);
178-
}
179-
return Entry-1;
190+
// We use the first predicate.
191+
TreePattern *PredPat =
192+
NodePredicatesByCodeToRun[Pred.getCodeToRunOnSDNode()][0];
193+
return Pred.usesOperands()
194+
? llvm::find(NodePredicatesWithOperands, PredPat) -
195+
NodePredicatesWithOperands.begin()
196+
: llvm::find(NodePredicates, PredPat) - NodePredicates.begin();
180197
}
181198

182199
unsigned getPatternPredicate(StringRef PredName) {
@@ -531,6 +548,7 @@ EmitMatcher(const Matcher *N, const unsigned Indent, unsigned CurrentIdx,
531548
case Matcher::CheckPredicate: {
532549
TreePredicateFn Pred = cast<CheckPredicateMatcher>(N)->getPredicate();
533550
unsigned OperandBytes = 0;
551+
unsigned PredNo = getNodePredicate(Pred);
534552

535553
if (Pred.usesOperands()) {
536554
unsigned NumOps = cast<CheckPredicateMatcher>(N)->getNumOperands();
@@ -539,10 +557,15 @@ EmitMatcher(const Matcher *N, const unsigned Indent, unsigned CurrentIdx,
539557
OS << cast<CheckPredicateMatcher>(N)->getOperandNo(i) << ", ";
540558
OperandBytes = 1 + NumOps;
541559
} else {
542-
OS << "OPC_CheckPredicate, ";
560+
if (PredNo < 8) {
561+
OperandBytes = -1;
562+
OS << "OPC_CheckPredicate" << PredNo << ", ";
563+
} else
564+
OS << "OPC_CheckPredicate, ";
543565
}
544566

545-
OS << getNodePredicate(Pred) << ',';
567+
if (PredNo >= 8 || Pred.usesOperands())
568+
OS << PredNo << ',';
546569
if (!OmitComments)
547570
OS << " // " << Pred.getFnName();
548571
OS << '\n';
@@ -1031,8 +1054,7 @@ EmitMatcherList(const Matcher *N, const unsigned Indent, unsigned CurrentIdx,
10311054
}
10321055

10331056
void MatcherTableEmitter::EmitNodePredicatesFunction(
1034-
const std::vector<TreePredicateFn> &Preds, StringRef Decl,
1035-
raw_ostream &OS) {
1057+
const std::vector<TreePattern *> &Preds, StringRef Decl, raw_ostream &OS) {
10361058
if (Preds.empty())
10371059
return;
10381060

@@ -1042,7 +1064,7 @@ void MatcherTableEmitter::EmitNodePredicatesFunction(
10421064
OS << " default: llvm_unreachable(\"Invalid predicate in table?\");\n";
10431065
for (unsigned i = 0, e = Preds.size(); i != e; ++i) {
10441066
// Emit the predicate code corresponding to this pattern.
1045-
const TreePredicateFn PredFn = Preds[i];
1067+
TreePredicateFn PredFn(Preds[i]);
10461068
assert(!PredFn.isAlwaysTrue() && "No code in this predicate");
10471069
std::string PredFnCodeStr = PredFn.getCodeToRunOnSDNode();
10481070

0 commit comments

Comments
 (0)