Skip to content

[Matrix] Pass ShapeInfo to Visit* methods (NFC). #142487

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jun 5, 2025

Conversation

jroelofs
Copy link
Contributor

@jroelofs jroelofs commented Jun 2, 2025

They all require it now.

Visit* are always modifying the IR, remove the boolean result.

Co-authored by: Florian Hahn <florian_hahn@apple.com>
@llvmbot
Copy link
Member

llvmbot commented Jun 2, 2025

@llvm/pr-subscribers-llvm-transforms

Author: Jon Roelofs (jroelofs)

Changes

Visit* are always modifying the IR, remove the boolean result.

Co-authored by: Florian Hahn <florian_hahn@apple.com>


Full diff: https://github.com/llvm/llvm-project/pull/142487.diff

1 Files Affected:

  • (modified) llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp (+26-41)
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index fb5e081acf7c5..124dc54b1dba8 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -1056,19 +1056,24 @@ class LowerMatrixIntrinsics {
 
       IRBuilder<> Builder(Inst);
 
+      const ShapeInfo &SI = ShapeMap.at(Inst);
+
       if (CallInst *CInst = dyn_cast<CallInst>(Inst))
-        Changed |= VisitCallInst(CInst);
+        Changed |= tryVisitCallInst(CInst);
 
       Value *Op1;
       Value *Op2;
-      if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
-        Changed |= VisitBinaryOperator(BinOp);
-      if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
-        Changed |= VisitUnaryOperator(UnOp);
       if (match(Inst, m_Load(m_Value(Op1))))
-        Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
+        VisitLoad(cast<LoadInst>(Inst), SI, Op1, Builder);
       else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
-        Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
+        VisitStore(cast<StoreInst>(Inst), SI, Op1, Op2, Builder);
+      else if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
+        VisitBinaryOperator(BinOp, SI);
+      else if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
+        VisitUnaryOperator(UnOp, SI);
+      else
+        continue;
+      Changed = true;
     }
 
     if (ORE) {
@@ -1107,7 +1112,7 @@ class LowerMatrixIntrinsics {
   }
 
   /// Replace intrinsic calls
-  bool VisitCallInst(CallInst *Inst) {
+  bool tryVisitCallInst(CallInst *Inst) {
     if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
       return false;
 
@@ -2105,49 +2110,36 @@ class LowerMatrixIntrinsics {
   }
 
   /// Lower load instructions, if shape information is available.
-  bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
-    auto I = ShapeMap.find(Inst);
-    assert(I != ShapeMap.end() &&
-           "must only visit instructions with shape info");
+  void VisitLoad(LoadInst *Inst, const ShapeInfo &SI, Value *Ptr, IRBuilder<> &Builder) {
     LowerLoad(Inst, Ptr, Inst->getAlign(),
-              Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
-              I->second);
-    return true;
+              Builder.getInt64(SI.getStride()), Inst->isVolatile(),
+              SI);
   }
 
-  bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
+  void VisitStore(StoreInst *Inst, const ShapeInfo &SI, Value *StoredVal, Value *Ptr,
                   IRBuilder<> &Builder) {
-    auto I = ShapeMap.find(StoredVal);
-    assert(I != ShapeMap.end() &&
-           "must only visit instructions with shape info");
     LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
-               Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
-               I->second);
-    return true;
+               Builder.getInt64(SI.getStride()), Inst->isVolatile(),
+               SI);
   }
 
   /// Lower binary operators, if shape information is available.
-  bool VisitBinaryOperator(BinaryOperator *Inst) {
-    auto I = ShapeMap.find(Inst);
-    assert(I != ShapeMap.end() &&
-           "must only visit instructions with shape info");
-
+  void VisitBinaryOperator(BinaryOperator *Inst, const ShapeInfo &SI) {
     Value *Lhs = Inst->getOperand(0);
     Value *Rhs = Inst->getOperand(1);
 
     IRBuilder<> Builder(Inst);
-    ShapeInfo &Shape = I->second;
 
     MatrixTy Result;
-    MatrixTy A = getMatrix(Lhs, Shape, Builder);
-    MatrixTy B = getMatrix(Rhs, Shape, Builder);
+    MatrixTy A = getMatrix(Lhs, SI, Builder);
+    MatrixTy B = getMatrix(Rhs, SI, Builder);
     assert(A.isColumnMajor() == B.isColumnMajor() &&
            Result.isColumnMajor() == A.isColumnMajor() &&
            "operands must agree on matrix layout");
 
     Builder.setFastMathFlags(getFastMathFlags(Inst));
 
-    for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
+    for (unsigned I = 0; I < SI.getNumVectors(); ++I)
       Result.addVector(Builder.CreateBinOp(Inst->getOpcode(), A.getVector(I),
                                            B.getVector(I)));
 
@@ -2155,22 +2147,16 @@ class LowerMatrixIntrinsics {
                      Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
                                              Result.getNumVectors()),
                      Builder);
-    return true;
   }
 
   /// Lower unary operators, if shape information is available.
-  bool VisitUnaryOperator(UnaryOperator *Inst) {
-    auto I = ShapeMap.find(Inst);
-    assert(I != ShapeMap.end() &&
-           "must only visit instructions with shape info");
-
+  void VisitUnaryOperator(UnaryOperator *Inst, const ShapeInfo &SI) {
     Value *Op = Inst->getOperand(0);
 
     IRBuilder<> Builder(Inst);
-    ShapeInfo &Shape = I->second;
 
     MatrixTy Result;
-    MatrixTy M = getMatrix(Op, Shape, Builder);
+    MatrixTy M = getMatrix(Op, SI, Builder);
 
     Builder.setFastMathFlags(getFastMathFlags(Inst));
 
@@ -2184,14 +2170,13 @@ class LowerMatrixIntrinsics {
       }
     };
 
-    for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
+    for (unsigned I = 0; I < SI.getNumVectors(); ++I)
       Result.addVector(BuildVectorOp(M.getVector(I)));
 
     finalizeLowering(Inst,
                      Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
                                              Result.getNumVectors()),
                      Builder);
-    return true;
   }
 
   /// Helper to linearize a matrix expression tree into a string. Currently

@jroelofs
Copy link
Contributor Author

jroelofs commented Jun 2, 2025

subsumes builds on top of #142487

Visit* are always modifying the IR, remove the boolean result.

Depends on llvm#142416.
@jroelofs jroelofs changed the title [Matrix] Don't update Changed based on Visit* return value (NFC). [Matrix] Pass ShapeInfo to Visit* methods(NFC). Jun 5, 2025
Copy link

github-actions bot commented Jun 5, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Contributor

@fhahn fhahn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@jroelofs jroelofs changed the title [Matrix] Pass ShapeInfo to Visit* methods(NFC). [Matrix] Pass ShapeInfo to Visit* methods (NFC). Jun 5, 2025
@jroelofs jroelofs merged commit 7b2ac8f into llvm:main Jun 5, 2025
9 of 10 checks passed
@jroelofs jroelofs deleted the jroelofs/lower-matrix-retval branch June 5, 2025 18:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants