-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[Matrix] Pass ShapeInfo to Visit* methods (NFC). #142487
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Visit* are always modifying the IR, remove the boolean result. Co-authored by: Florian Hahn <florian_hahn@apple.com>
@llvm/pr-subscribers-llvm-transforms Author: Jon Roelofs (jroelofs) ChangesVisit* are always modifying the IR, remove the boolean result. Co-authored by: Florian Hahn <florian_hahn@apple.com> Full diff: https://github.com/llvm/llvm-project/pull/142487.diff 1 Files Affected:
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index fb5e081acf7c5..124dc54b1dba8 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -1056,19 +1056,24 @@ class LowerMatrixIntrinsics {
IRBuilder<> Builder(Inst);
+ const ShapeInfo &SI = ShapeMap.at(Inst);
+
if (CallInst *CInst = dyn_cast<CallInst>(Inst))
- Changed |= VisitCallInst(CInst);
+ Changed |= tryVisitCallInst(CInst);
Value *Op1;
Value *Op2;
- if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
- Changed |= VisitBinaryOperator(BinOp);
- if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
- Changed |= VisitUnaryOperator(UnOp);
if (match(Inst, m_Load(m_Value(Op1))))
- Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
+ VisitLoad(cast<LoadInst>(Inst), SI, Op1, Builder);
else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
- Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
+ VisitStore(cast<StoreInst>(Inst), SI, Op1, Op2, Builder);
+ else if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
+ VisitBinaryOperator(BinOp, SI);
+ else if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
+ VisitUnaryOperator(UnOp, SI);
+ else
+ continue;
+ Changed = true;
}
if (ORE) {
@@ -1107,7 +1112,7 @@ class LowerMatrixIntrinsics {
}
/// Replace intrinsic calls
- bool VisitCallInst(CallInst *Inst) {
+ bool tryVisitCallInst(CallInst *Inst) {
if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
return false;
@@ -2105,49 +2110,36 @@ class LowerMatrixIntrinsics {
}
/// Lower load instructions, if shape information is available.
- bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
- auto I = ShapeMap.find(Inst);
- assert(I != ShapeMap.end() &&
- "must only visit instructions with shape info");
+ void VisitLoad(LoadInst *Inst, const ShapeInfo &SI, Value *Ptr, IRBuilder<> &Builder) {
LowerLoad(Inst, Ptr, Inst->getAlign(),
- Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
- I->second);
- return true;
+ Builder.getInt64(SI.getStride()), Inst->isVolatile(),
+ SI);
}
- bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
+ void VisitStore(StoreInst *Inst, const ShapeInfo &SI, Value *StoredVal, Value *Ptr,
IRBuilder<> &Builder) {
- auto I = ShapeMap.find(StoredVal);
- assert(I != ShapeMap.end() &&
- "must only visit instructions with shape info");
LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
- Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
- I->second);
- return true;
+ Builder.getInt64(SI.getStride()), Inst->isVolatile(),
+ SI);
}
/// Lower binary operators, if shape information is available.
- bool VisitBinaryOperator(BinaryOperator *Inst) {
- auto I = ShapeMap.find(Inst);
- assert(I != ShapeMap.end() &&
- "must only visit instructions with shape info");
-
+ void VisitBinaryOperator(BinaryOperator *Inst, const ShapeInfo &SI) {
Value *Lhs = Inst->getOperand(0);
Value *Rhs = Inst->getOperand(1);
IRBuilder<> Builder(Inst);
- ShapeInfo &Shape = I->second;
MatrixTy Result;
- MatrixTy A = getMatrix(Lhs, Shape, Builder);
- MatrixTy B = getMatrix(Rhs, Shape, Builder);
+ MatrixTy A = getMatrix(Lhs, SI, Builder);
+ MatrixTy B = getMatrix(Rhs, SI, Builder);
assert(A.isColumnMajor() == B.isColumnMajor() &&
Result.isColumnMajor() == A.isColumnMajor() &&
"operands must agree on matrix layout");
Builder.setFastMathFlags(getFastMathFlags(Inst));
- for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
+ for (unsigned I = 0; I < SI.getNumVectors(); ++I)
Result.addVector(Builder.CreateBinOp(Inst->getOpcode(), A.getVector(I),
B.getVector(I)));
@@ -2155,22 +2147,16 @@ class LowerMatrixIntrinsics {
Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
Result.getNumVectors()),
Builder);
- return true;
}
/// Lower unary operators, if shape information is available.
- bool VisitUnaryOperator(UnaryOperator *Inst) {
- auto I = ShapeMap.find(Inst);
- assert(I != ShapeMap.end() &&
- "must only visit instructions with shape info");
-
+ void VisitUnaryOperator(UnaryOperator *Inst, const ShapeInfo &SI) {
Value *Op = Inst->getOperand(0);
IRBuilder<> Builder(Inst);
- ShapeInfo &Shape = I->second;
MatrixTy Result;
- MatrixTy M = getMatrix(Op, Shape, Builder);
+ MatrixTy M = getMatrix(Op, SI, Builder);
Builder.setFastMathFlags(getFastMathFlags(Inst));
@@ -2184,14 +2170,13 @@ class LowerMatrixIntrinsics {
}
};
- for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
+ for (unsigned I = 0; I < SI.getNumVectors(); ++I)
Result.addVector(BuildVectorOp(M.getVector(I)));
finalizeLowering(Inst,
Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
Result.getNumVectors()),
Builder);
- return true;
}
/// Helper to linearize a matrix expression tree into a string. Currently
|
|
Visit* are always modifying the IR, remove the boolean result. Depends on llvm#142416.
✅ With the latest revision this PR passed the C/C++ code formatter. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
They all require it now.