@@ -1056,19 +1056,20 @@ class LowerMatrixIntrinsics {
1056
1056
1057
1057
IRBuilder<> Builder (Inst);
1058
1058
1059
- if (CallInst *CInst = dyn_cast<CallInst>(Inst))
1060
- Changed |= VisitCallInst (CInst);
1059
+ const ShapeInfo &SI = ShapeMap.at (Inst);
1061
1060
1062
1061
Value *Op1;
1063
1062
Value *Op2;
1064
1063
if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
1065
- VisitBinaryOperator (BinOp);
1064
+ VisitBinaryOperator (BinOp, SI );
1066
1065
else if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
1067
- VisitUnaryOperator (UnOp);
1066
+ VisitUnaryOperator (UnOp, SI);
1067
+ else if (CallInst *CInst = dyn_cast<CallInst>(Inst))
1068
+ VisitCallInst (CInst);
1068
1069
else if (match (Inst, m_Load (m_Value (Op1))))
1069
- VisitLoad (cast<LoadInst>(Inst), Op1, Builder);
1070
+ VisitLoad (cast<LoadInst>(Inst), SI, Op1, Builder);
1070
1071
else if (match (Inst, m_Store (m_Value (Op1), m_Value (Op2))))
1071
- VisitStore (cast<StoreInst>(Inst), Op1, Op2, Builder);
1072
+ VisitStore (cast<StoreInst>(Inst), SI, Op1, Op2, Builder);
1072
1073
else
1073
1074
continue ;
1074
1075
Changed = true ;
@@ -1109,10 +1110,10 @@ class LowerMatrixIntrinsics {
1109
1110
return Changed;
1110
1111
}
1111
1112
1112
- /// Replace intrinsic calls
1113
- bool VisitCallInst (CallInst *Inst) {
1114
- if (! Inst->getCalledFunction () || !Inst-> getCalledFunction ()-> isIntrinsic ())
1115
- return false ;
1113
+ /// Replace intrinsic calls.
1114
+ void VisitCallInst (CallInst *Inst) {
1115
+ assert ( Inst->getCalledFunction () &&
1116
+ Inst-> getCalledFunction ()-> isIntrinsic ()) ;
1116
1117
1117
1118
switch (Inst->getCalledFunction ()->getIntrinsicID ()) {
1118
1119
case Intrinsic::matrix_multiply:
@@ -1128,9 +1129,9 @@ class LowerMatrixIntrinsics {
1128
1129
LowerColumnMajorStore (Inst);
1129
1130
break ;
1130
1131
default :
1131
- return false ;
1132
+ llvm_unreachable (
1133
+ " only intrinsics supporting shape info should be seen here" );
1132
1134
}
1133
- return true ;
1134
1135
}
1135
1136
1136
1137
/// Compute the alignment for a column/row \p Idx with \p Stride between them.
@@ -2107,48 +2108,36 @@ class LowerMatrixIntrinsics {
2107
2108
Builder);
2108
2109
}
2109
2110
2110
- /// Lower load instructions, if shape information is available.
2111
- void VisitLoad (LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
2112
- auto I = ShapeMap.find (Inst);
2113
- assert (I != ShapeMap.end () &&
2114
- " must only visit instructions with shape info" );
2115
- LowerLoad (Inst, Ptr, Inst->getAlign (),
2116
- Builder.getInt64 (I->second .getStride ()), Inst->isVolatile (),
2117
- I->second );
2111
+ /// Lower load instructions: forwards \p Inst, \p Ptr, the load's alignment and volatility, the stride of \p SI (as an i64 constant built with \p Builder) and \p SI itself to LowerLoad. The caller now supplies \p SI instead of this function looking it up in ShapeMap.
2112
+ void VisitLoad (LoadInst *Inst, const ShapeInfo &SI, Value *Ptr,
2113
+ IRBuilder<> &Builder) {
2114
+ LowerLoad (Inst, Ptr, Inst->getAlign (), Builder.getInt64 (SI.getStride ()),
2115
+ Inst->isVolatile (), SI);
2118
2116
}
2119
2117
2120
// NOTE(review): VisitStore carries no doc comment, unlike VisitLoad and the
// binary/unary operator visitors. As visible here, the new version forwards
// Inst, StoredVal, Ptr, the store's alignment and volatility, the stride of SI
// (as an i64 constant built with Builder) and SI itself to LowerStore; SI is
// now passed in by the caller rather than looked up in ShapeMap.
- void VisitStore (StoreInst *Inst, Value *StoredVal, Value *Ptr,
2121
- IRBuilder<> &Builder) {
2122
- auto I = ShapeMap.find (Inst);
2123
- assert (I != ShapeMap.end () &&
2124
- " must only visit instructions with shape info" );
2118
+ void VisitStore (StoreInst *Inst, const ShapeInfo &SI, Value *StoredVal,
2119
+ Value *Ptr, IRBuilder<> &Builder) {
2125
2120
LowerStore (Inst, StoredVal, Ptr, Inst->getAlign (),
2126
- Builder.getInt64 (I->second .getStride ()), Inst->isVolatile (),
2127
- I->second );
2121
+ Builder.getInt64 (SI.getStride ()), Inst->isVolatile (), SI);
2128
2122
}
2129
2123
2130
- /// Lower binary operators, if shape information is available.
2131
- void VisitBinaryOperator (BinaryOperator *Inst) {
2132
- auto I = ShapeMap.find (Inst);
2133
- assert (I != ShapeMap.end () &&
2134
- " must only visit instructions with shape info" );
2135
-
2124
+ /// Lower binary operators.
2125
+ void VisitBinaryOperator (BinaryOperator *Inst, const ShapeInfo &SI) {
2136
2126
Value *Lhs = Inst->getOperand (0 );
2137
2127
Value *Rhs = Inst->getOperand (1 );
2138
2128
2139
2129
IRBuilder<> Builder (Inst);
2140
- ShapeInfo &Shape = I->second ;
2141
2130
2142
2131
MatrixTy Result;
2143
- MatrixTy A = getMatrix (Lhs, Shape , Builder);
2144
- MatrixTy B = getMatrix (Rhs, Shape , Builder);
2132
+ MatrixTy A = getMatrix (Lhs, SI , Builder);
2133
+ MatrixTy B = getMatrix (Rhs, SI , Builder);
2145
2134
assert (A.isColumnMajor () == B.isColumnMajor () &&
2146
2135
Result.isColumnMajor () == A.isColumnMajor () &&
2147
2136
" operands must agree on matrix layout" );
2148
2137
2149
2138
Builder.setFastMathFlags (getFastMathFlags (Inst));
2150
2139
2151
- for (unsigned I = 0 ; I < Shape .getNumVectors (); ++I)
2140
+ for (unsigned I = 0 ; I < SI .getNumVectors (); ++I)
2152
2141
Result.addVector (Builder.CreateBinOp (Inst->getOpcode (), A.getVector (I),
2153
2142
B.getVector (I)));
2154
2143
@@ -2158,19 +2147,14 @@ class LowerMatrixIntrinsics {
2158
2147
Builder);
2159
2148
}
2160
2149
2161
- /// Lower unary operators, if shape information is available.
2162
- void VisitUnaryOperator (UnaryOperator *Inst) {
2163
- auto I = ShapeMap.find (Inst);
2164
- assert (I != ShapeMap.end () &&
2165
- " must only visit instructions with shape info" );
2166
-
2150
+ /// Lower unary operators.
2151
+ void VisitUnaryOperator (UnaryOperator *Inst, const ShapeInfo &SI) {
2167
2152
Value *Op = Inst->getOperand (0 );
2168
2153
2169
2154
IRBuilder<> Builder (Inst);
2170
- ShapeInfo &Shape = I->second ;
2171
2155
2172
2156
MatrixTy Result;
2173
- MatrixTy M = getMatrix (Op, Shape , Builder);
2157
+ MatrixTy M = getMatrix (Op, SI , Builder);
2174
2158
2175
2159
Builder.setFastMathFlags (getFastMathFlags (Inst));
2176
2160
@@ -2184,7 +2168,7 @@ class LowerMatrixIntrinsics {
2184
2168
}
2185
2169
};
2186
2170
2187
- for (unsigned I = 0 ; I < Shape .getNumVectors (); ++I)
2171
+ for (unsigned I = 0 ; I < SI .getNumVectors (); ++I)
2188
2172
Result.addVector (BuildVectorOp (M.getVector (I)));
2189
2173
2190
2174
finalizeLowering (Inst,
0 commit comments