Skip to content

Commit 7b2ac8f

Browse files
authored
[Matrix] Pass ShapeInfo to Visit* methods (NFC). (#142487)
They all require it now.
1 parent b88e8cc commit 7b2ac8f

File tree

1 file changed

+30
-46
lines changed

1 file changed

+30
-46
lines changed

llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Lines changed: 30 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1056,19 +1056,20 @@ class LowerMatrixIntrinsics {
10561056

10571057
IRBuilder<> Builder(Inst);
10581058

1059-
if (CallInst *CInst = dyn_cast<CallInst>(Inst))
1060-
Changed |= VisitCallInst(CInst);
1059+
const ShapeInfo &SI = ShapeMap.at(Inst);
10611060

10621061
Value *Op1;
10631062
Value *Op2;
10641063
if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
1065-
VisitBinaryOperator(BinOp);
1064+
VisitBinaryOperator(BinOp, SI);
10661065
else if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
1067-
VisitUnaryOperator(UnOp);
1066+
VisitUnaryOperator(UnOp, SI);
1067+
else if (CallInst *CInst = dyn_cast<CallInst>(Inst))
1068+
VisitCallInst(CInst);
10681069
else if (match(Inst, m_Load(m_Value(Op1))))
1069-
VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
1070+
VisitLoad(cast<LoadInst>(Inst), SI, Op1, Builder);
10701071
else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
1071-
VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
1072+
VisitStore(cast<StoreInst>(Inst), SI, Op1, Op2, Builder);
10721073
else
10731074
continue;
10741075
Changed = true;
@@ -1109,10 +1110,10 @@ class LowerMatrixIntrinsics {
11091110
return Changed;
11101111
}
11111112

1112-
/// Replace intrinsic calls
1113-
bool VisitCallInst(CallInst *Inst) {
1114-
if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
1115-
return false;
1113+
/// Replace intrinsic calls.
1114+
void VisitCallInst(CallInst *Inst) {
1115+
assert(Inst->getCalledFunction() &&
1116+
Inst->getCalledFunction()->isIntrinsic());
11161117

11171118
switch (Inst->getCalledFunction()->getIntrinsicID()) {
11181119
case Intrinsic::matrix_multiply:
@@ -1128,9 +1129,9 @@ class LowerMatrixIntrinsics {
11281129
LowerColumnMajorStore(Inst);
11291130
break;
11301131
default:
1131-
return false;
1132+
llvm_unreachable(
1133+
"only intrinsics supporting shape info should be seen here");
11321134
}
1133-
return true;
11341135
}
11351136

11361137
/// Compute the alignment for a column/row \p Idx with \p Stride between them.
@@ -2107,48 +2108,36 @@ class LowerMatrixIntrinsics {
21072108
Builder);
21082109
}
21092110

2110-
/// Lower load instructions, if shape information is available.
2111-
void VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
2112-
auto I = ShapeMap.find(Inst);
2113-
assert(I != ShapeMap.end() &&
2114-
"must only visit instructions with shape info");
2115-
LowerLoad(Inst, Ptr, Inst->getAlign(),
2116-
Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
2117-
I->second);
2111+
/// Lower load instructions.
2112+
void VisitLoad(LoadInst *Inst, const ShapeInfo &SI, Value *Ptr,
2113+
IRBuilder<> &Builder) {
2114+
LowerLoad(Inst, Ptr, Inst->getAlign(), Builder.getInt64(SI.getStride()),
2115+
Inst->isVolatile(), SI);
21182116
}
21192117

2120-
void VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
2121-
IRBuilder<> &Builder) {
2122-
auto I = ShapeMap.find(Inst);
2123-
assert(I != ShapeMap.end() &&
2124-
"must only visit instructions with shape info");
2118+
void VisitStore(StoreInst *Inst, const ShapeInfo &SI, Value *StoredVal,
2119+
Value *Ptr, IRBuilder<> &Builder) {
21252120
LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
2126-
Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
2127-
I->second);
2121+
Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI);
21282122
}
21292123

2130-
/// Lower binary operators, if shape information is available.
2131-
void VisitBinaryOperator(BinaryOperator *Inst) {
2132-
auto I = ShapeMap.find(Inst);
2133-
assert(I != ShapeMap.end() &&
2134-
"must only visit instructions with shape info");
2135-
2124+
/// Lower binary operators.
2125+
void VisitBinaryOperator(BinaryOperator *Inst, const ShapeInfo &SI) {
21362126
Value *Lhs = Inst->getOperand(0);
21372127
Value *Rhs = Inst->getOperand(1);
21382128

21392129
IRBuilder<> Builder(Inst);
2140-
ShapeInfo &Shape = I->second;
21412130

21422131
MatrixTy Result;
2143-
MatrixTy A = getMatrix(Lhs, Shape, Builder);
2144-
MatrixTy B = getMatrix(Rhs, Shape, Builder);
2132+
MatrixTy A = getMatrix(Lhs, SI, Builder);
2133+
MatrixTy B = getMatrix(Rhs, SI, Builder);
21452134
assert(A.isColumnMajor() == B.isColumnMajor() &&
21462135
Result.isColumnMajor() == A.isColumnMajor() &&
21472136
"operands must agree on matrix layout");
21482137

21492138
Builder.setFastMathFlags(getFastMathFlags(Inst));
21502139

2151-
for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
2140+
for (unsigned I = 0; I < SI.getNumVectors(); ++I)
21522141
Result.addVector(Builder.CreateBinOp(Inst->getOpcode(), A.getVector(I),
21532142
B.getVector(I)));
21542143

@@ -2158,19 +2147,14 @@ class LowerMatrixIntrinsics {
21582147
Builder);
21592148
}
21602149

2161-
/// Lower unary operators, if shape information is available.
2162-
void VisitUnaryOperator(UnaryOperator *Inst) {
2163-
auto I = ShapeMap.find(Inst);
2164-
assert(I != ShapeMap.end() &&
2165-
"must only visit instructions with shape info");
2166-
2150+
/// Lower unary operators.
2151+
void VisitUnaryOperator(UnaryOperator *Inst, const ShapeInfo &SI) {
21672152
Value *Op = Inst->getOperand(0);
21682153

21692154
IRBuilder<> Builder(Inst);
2170-
ShapeInfo &Shape = I->second;
21712155

21722156
MatrixTy Result;
2173-
MatrixTy M = getMatrix(Op, Shape, Builder);
2157+
MatrixTy M = getMatrix(Op, SI, Builder);
21742158

21752159
Builder.setFastMathFlags(getFastMathFlags(Inst));
21762160

@@ -2184,7 +2168,7 @@ class LowerMatrixIntrinsics {
21842168
}
21852169
};
21862170

2187-
for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
2171+
for (unsigned I = 0; I < SI.getNumVectors(); ++I)
21882172
Result.addVector(BuildVectorOp(M.getVector(I)));
21892173

21902174
finalizeLowering(Inst,

0 commit comments

Comments
 (0)