Skip to content

[Matrix] Propagate shape information through PHI insts #141681

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 14 commits into
base: main
Choose a base branch
from

Conversation

jroelofs
Copy link
Contributor

@jroelofs jroelofs commented May 27, 2025

... and split them as we lower them, avoiding several shuffles in the process.

... and split them as we lower them, avoiding several shuffles in the process.
@llvmbot
Copy link
Member

llvmbot commented May 27, 2025

@llvm/pr-subscribers-llvm-transforms

Author: Jon Roelofs (jroelofs)

Changes

... and split them as we lower them, avoiding several shuffles in the process.


Patch is 42.12 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/141681.diff

3 Files Affected:

  • (modified) llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp (+92-1)
  • (added) llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll (+216)
  • (modified) llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backwards-unsupported.ll (+58-65)
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 56d4be513ea6f..c06d08688ab1c 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -30,6 +30,7 @@
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/CFG.h"
+#include "llvm/IR/Constants.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/DebugInfoMetadata.h"
 #include "llvm/IR/Function.h"
@@ -230,6 +231,7 @@ static bool isUniformShape(Value *V) {
     return true;
 
   switch (I->getOpcode()) {
+  case Instruction::PHI:
   case Instruction::FAdd:
   case Instruction::FSub:
   case Instruction::FMul: // Scalar multiply.
@@ -360,6 +362,33 @@ class LowerMatrixIntrinsics {
         addVector(PoisonValue::get(FixedVectorType::get(
             EltTy, isColumnMajor() ? NumRows : NumColumns)));
     }
+    MatrixTy(ConstantData *Constant, const ShapeInfo &SI)
+        : IsColumnMajor(SI.IsColumnMajor) {
+      Type *EltTy = cast<VectorType>(Constant->getType())->getElementType();
+      Type *RowTy = VectorType::get(EltTy, ElementCount::getFixed(SI.NumRows));
+
+      for (unsigned J = 0, D = SI.getNumVectors(); J < D; ++J) {
+        if (auto *CDV = dyn_cast<ConstantDataVector>(Constant)) {
+          unsigned Width = SI.getStride();
+          size_t EltSize = EltTy->getPrimitiveSizeInBits() / 8;
+          StringRef Data = CDV->getRawDataValues().substr(
+              J * Width * EltSize, Width * EltSize);
+          addVector(ConstantDataVector::getRaw(Data, Width,
+                                               CDV->getElementType()));
+        } else if (isa<PoisonValue>(Constant))
+          addVector(PoisonValue::get(RowTy));
+        else if (isa<UndefValue>(Constant))
+          addVector(UndefValue::get(RowTy));
+        else if (isa<ConstantAggregateZero>(Constant))
+          addVector(ConstantAggregateZero::get(RowTy));
+        else {
+#ifndef NDEBUG
+          Constant->dump();
+          report_fatal_error("unhandled ConstantData type");
+#endif
+        }
+      }
+    }
 
     Value *getVector(unsigned i) const { return Vectors[i]; }
     Value *getColumn(unsigned i) const {
@@ -564,6 +593,27 @@ class LowerMatrixIntrinsics {
       MatrixVal = M.embedInVector(Builder);
     }
 
+    // If it's a PHI, split it now. We'll take care of fixing up the operands
+    // later once we're in VisitPHI.
+    if (auto *PHI = dyn_cast<PHINode>(MatrixVal)) {
+      auto *EltTy = cast<VectorType>(PHI->getType())->getElementType();
+      MatrixTy PhiM{SI.NumRows, SI.NumColumns, EltTy};
+
+      IRBuilder<>::InsertPointGuard IPG(Builder);
+      Builder.SetInsertPoint(PHI);
+      for (unsigned VI = 0, VE = PhiM.getNumVectors(); VI != VE; ++VI)
+        PhiM.setVector(VI, Builder.CreatePHI(PhiM.getVectorTy(),
+                                             PHI->getNumIncomingValues(),
+                                             PHI->getName()));
+
+      Inst2ColumnMatrix[PHI] = PhiM;
+      return PhiM;
+    }
+
+    // If it's a constant, materialize the split version of it with this shape.
+    if (auto *IncomingConst = dyn_cast<ConstantData>(MatrixVal))
+      return MatrixTy(IncomingConst, SI);
+
     // Otherwise split MatrixVal.
     SmallVector<Value *, 16> SplitVecs;
     for (unsigned MaskStart = 0;
@@ -1077,6 +1127,11 @@ class LowerMatrixIntrinsics {
         Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
     }
 
+    // Fifth, lower all the PHI's with shape information.
+    for (Instruction *Inst : MatrixInsts)
+      if (auto *PHI = dyn_cast<PHINode>(Inst))
+        Changed |= VisitPHI(PHI);
+
     if (ORE) {
       RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func);
       RemarkGen.emitRemarks();
@@ -1349,7 +1404,8 @@ class LowerMatrixIntrinsics {
                         IRBuilder<> &Builder) {
     auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
     (void)inserted;
-    assert(inserted.second && "multiple matrix lowering mapping");
+    assert((inserted.second || isa<PHINode>(Inst)) &&
+           "multiple matrix lowering mapping");
 
     ToRemove.push_back(Inst);
     Value *Flattened = nullptr;
@@ -2133,6 +2189,41 @@ class LowerMatrixIntrinsics {
     return true;
   }
 
+  bool VisitPHI(PHINode *Inst) {
+    auto I = ShapeMap.find(Inst);
+    if (I == ShapeMap.end())
+      return false;
+
+    IRBuilder<> Builder(Inst);
+
+    MatrixTy PhiM = getMatrix(Inst, I->second, Builder);
+
+    for (unsigned IncomingI = 0, IncomingE = Inst->getNumIncomingValues();
+         IncomingI != IncomingE; ++IncomingI) {
+      Value *IncomingV = Inst->getIncomingValue(IncomingI);
+      BasicBlock *IncomingB = Inst->getIncomingBlock(IncomingI);
+
+      // getMatrix() may insert some instructions. The safe place to insert them
+      // is at the end of the parent block, where the register allocator would
+      // have inserted the copies that materialize the PHI.
+      if (auto *IncomingInst = dyn_cast<Instruction>(IncomingV))
+        Builder.SetInsertPoint(IncomingInst->getParent()->getTerminator());
+
+      MatrixTy OpM = getMatrix(IncomingV, I->second, Builder);
+
+      for (unsigned VI = 0, VE = PhiM.getNumVectors(); VI != VE; ++VI) {
+        PHINode *NewPHI = cast<PHINode>(PhiM.getVector(VI));
+        NewPHI->addIncoming(OpM.getVector(VI), IncomingB);
+      }
+    }
+
+    // finalizeLowering() may also insert instructions in some cases. The safe
+    // place for those is at the end of the initial block of PHIs.
+    Builder.SetInsertPoint(*Inst->getInsertionPointAfterDef());
+    finalizeLowering(Inst, PhiM, Builder);
+    return true;
+  }
+
   /// Lower binary operators, if shape information is available.
   bool VisitBinaryOperator(BinaryOperator *Inst) {
     auto I = ShapeMap.find(Inst);
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll
new file mode 100644
index 0000000000000..d49b4d1112062
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll
@@ -0,0 +1,216 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -matrix-allow-contract=false -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s
+
+define void @matrix_phi(ptr %in1, ptr %in2, i32 %count, ptr %out) {
+; CHECK-LABEL: @matrix_phi(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <3 x double>, ptr [[IN1:%.*]], align 128
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN1]], i64 3
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <3 x double>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN1]], i64 6
+; CHECK-NEXT:    [[COL_LOAD3:%.*]] = load <3 x double>, ptr [[VEC_GEP2]], align 16
+; CHECK-NEXT:    br label [[LOOP:%.*]]
+; CHECK:       loop:
+; CHECK-NEXT:    [[PHI9:%.*]] = phi <3 x double> [ [[COL_LOAD]], [[ENTRY:%.*]] ], [ [[TMP0:%.*]], [[LOOP]] ]
+; CHECK-NEXT:    [[PHI10:%.*]] = phi <3 x double> [ [[COL_LOAD1]], [[ENTRY]] ], [ [[TMP1:%.*]], [[LOOP]] ]
+; CHECK-NEXT:    [[PHI11:%.*]] = phi <3 x double> [ [[COL_LOAD3]], [[ENTRY]] ], [ [[TMP2:%.*]], [[LOOP]] ]
+; CHECK-NEXT:    [[CTR:%.*]] = phi i32 [ [[COUNT:%.*]], [[ENTRY]] ], [ [[DEC:%.*]], [[LOOP]] ]
+; CHECK-NEXT:    [[COL_LOAD4:%.*]] = load <3 x double>, ptr [[IN2:%.*]], align 128
+; CHECK-NEXT:    [[VEC_GEP5:%.*]] = getelementptr double, ptr [[IN2]], i64 3
+; CHECK-NEXT:    [[COL_LOAD6:%.*]] = load <3 x double>, ptr [[VEC_GEP5]], align 8
+; CHECK-NEXT:    [[VEC_GEP7:%.*]] = getelementptr double, ptr [[IN2]], i64 6
+; CHECK-NEXT:    [[COL_LOAD8:%.*]] = load <3 x double>, ptr [[VEC_GEP7]], align 16
+; CHECK-NEXT:    [[TMP0]] = fadd <3 x double> [[PHI9]], [[COL_LOAD4]]
+; CHECK-NEXT:    [[TMP1]] = fadd <3 x double> [[PHI10]], [[COL_LOAD6]]
+; CHECK-NEXT:    [[TMP2]] = fadd <3 x double> [[PHI11]], [[COL_LOAD8]]
+; CHECK-NEXT:    [[DEC]] = sub i32 [[CTR]], 1
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[DEC]], 0
+; CHECK-NEXT:    br i1 [[CMP]], label [[EXIT:%.*]], label [[LOOP]]
+; CHECK:       exit:
+; CHECK-NEXT:    store <3 x double> [[TMP0]], ptr [[OUT:%.*]], align 128
+; CHECK-NEXT:    [[VEC_GEP12:%.*]] = getelementptr double, ptr [[OUT]], i64 3
+; CHECK-NEXT:    store <3 x double> [[TMP1]], ptr [[VEC_GEP12]], align 8
+; CHECK-NEXT:    [[VEC_GEP13:%.*]] = getelementptr double, ptr [[OUT]], i64 6
+; CHECK-NEXT:    store <3 x double> [[TMP2]], ptr [[VEC_GEP13]], align 16
+; CHECK-NEXT:    ret void
+;
+entry:
+  %mat = load <9 x double>, ptr %in1
+  br label %loop
+
+loop:
+  %phi = phi <9 x double> [%mat, %entry], [%sum, %loop]
+  %ctr = phi i32 [%count, %entry], [%dec, %loop]
+
+  %in2v = load <9 x double>, ptr %in2
+
+  ; Give in2 the shape: 3 x 3
+  %in2t  = call <9 x double> @llvm.matrix.transpose(<9 x double> %in2v, i32 3, i32 3)
+  %in2tt = call <9 x double> @llvm.matrix.transpose(<9 x double> %in2t, i32 3, i32 3)
+
+  %sum = fadd <9 x double> %phi, %in2tt
+
+  %dec = sub i32 %ctr, 1
+  %cmp = icmp eq i32 %dec, 0
+  br i1 %cmp, label %exit, label %loop
+
+exit:
+  store <9 x double> %sum, ptr %out
+  ret void
+}
+
+define void @matrix_phi_zeroinitializer(ptr %in1, ptr %in2, i32 %count, ptr %out) {
+; CHECK-LABEL: @matrix_phi_zeroinitializer(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br label [[LOOP:%.*]]
+; CHECK:       loop:
+; CHECK-NEXT:    [[PHI4:%.*]] = phi <3 x double> [ zeroinitializer, [[ENTRY:%.*]] ], [ [[TMP0:%.*]], [[LOOP]] ]
+; CHECK-NEXT:    [[PHI5:%.*]] = phi <3 x double> [ zeroinitializer, [[ENTRY]] ], [ [[TMP1:%.*]], [[LOOP]] ]
+; CHECK-NEXT:    [[PHI6:%.*]] = phi <3 x double> [ zeroinitializer, [[ENTRY]] ], [ [[TMP2:%.*]], [[LOOP]] ]
+; CHECK-NEXT:    [[CTR:%.*]] = phi i32 [ [[COUNT:%.*]], [[ENTRY]] ], [ [[DEC:%.*]], [[LOOP]] ]
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <3 x double>, ptr [[IN2:%.*]], align 128
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN2]], i64 3
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <3 x double>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN2]], i64 6
+; CHECK-NEXT:    [[COL_LOAD3:%.*]] = load <3 x double>, ptr [[VEC_GEP2]], align 16
+; CHECK-NEXT:    [[TMP0]] = fadd <3 x double> [[PHI4]], [[COL_LOAD]]
+; CHECK-NEXT:    [[TMP1]] = fadd <3 x double> [[PHI5]], [[COL_LOAD1]]
+; CHECK-NEXT:    [[TMP2]] = fadd <3 x double> [[PHI6]], [[COL_LOAD3]]
+; CHECK-NEXT:    [[DEC]] = sub i32 [[CTR]], 1
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[DEC]], 0
+; CHECK-NEXT:    br i1 [[CMP]], label [[EXIT:%.*]], label [[LOOP]]
+; CHECK:       exit:
+; CHECK-NEXT:    store <3 x double> [[TMP0]], ptr [[OUT:%.*]], align 128
+; CHECK-NEXT:    [[VEC_GEP7:%.*]] = getelementptr double, ptr [[OUT]], i64 3
+; CHECK-NEXT:    store <3 x double> [[TMP1]], ptr [[VEC_GEP7]], align 8
+; CHECK-NEXT:    [[VEC_GEP8:%.*]] = getelementptr double, ptr [[OUT]], i64 6
+; CHECK-NEXT:    store <3 x double> [[TMP2]], ptr [[VEC_GEP8]], align 16
+; CHECK-NEXT:    ret void
+;
+entry:
+  br label %loop
+
+loop:
+  %phi = phi <9 x double> [zeroinitializer, %entry], [%sum, %loop]
+  %ctr = phi i32 [%count, %entry], [%dec, %loop]
+
+  %in2v = load <9 x double>, ptr %in2
+
+  ; Give in2 the shape: 3 x 3
+  %in2t  = call <9 x double> @llvm.matrix.transpose(<9 x double> %in2v, i32 3, i32 3)
+  %in2tt = call <9 x double> @llvm.matrix.transpose(<9 x double> %in2t, i32 3, i32 3)
+
+  %sum = fadd <9 x double> %phi, %in2tt
+
+  %dec = sub i32 %ctr, 1
+  %cmp = icmp eq i32 %dec, 0
+  br i1 %cmp, label %exit, label %loop
+
+exit:
+  store <9 x double> %sum, ptr %out
+  ret void
+}
+
+define void @matrix_phi_undef(ptr %in1, ptr %in2, i32 %count, ptr %out) {
+; CHECK-LABEL: @matrix_phi_undef(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br label [[LOOP:%.*]]
+; CHECK:       loop:
+; CHECK-NEXT:    [[PHI4:%.*]] = phi <3 x double> [ undef, [[ENTRY:%.*]] ], [ [[TMP0:%.*]], [[LOOP]] ]
+; CHECK-NEXT:    [[PHI5:%.*]] = phi <3 x double> [ undef, [[ENTRY]] ], [ [[TMP1:%.*]], [[LOOP]] ]
+; CHECK-NEXT:    [[PHI6:%.*]] = phi <3 x double> [ undef, [[ENTRY]] ], [ [[TMP2:%.*]], [[LOOP]] ]
+; CHECK-NEXT:    [[CTR:%.*]] = phi i32 [ [[COUNT:%.*]], [[ENTRY]] ], [ [[DEC:%.*]], [[LOOP]] ]
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <3 x double>, ptr [[IN2:%.*]], align 128
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN2]], i64 3
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <3 x double>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN2]], i64 6
+; CHECK-NEXT:    [[COL_LOAD3:%.*]] = load <3 x double>, ptr [[VEC_GEP2]], align 16
+; CHECK-NEXT:    [[TMP0]] = fadd <3 x double> [[PHI4]], [[COL_LOAD]]
+; CHECK-NEXT:    [[TMP1]] = fadd <3 x double> [[PHI5]], [[COL_LOAD1]]
+; CHECK-NEXT:    [[TMP2]] = fadd <3 x double> [[PHI6]], [[COL_LOAD3]]
+; CHECK-NEXT:    [[DEC]] = sub i32 [[CTR]], 1
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[DEC]], 0
+; CHECK-NEXT:    br i1 [[CMP]], label [[EXIT:%.*]], label [[LOOP]]
+; CHECK:       exit:
+; CHECK-NEXT:    store <3 x double> [[TMP0]], ptr [[OUT:%.*]], align 128
+; CHECK-NEXT:    [[VEC_GEP7:%.*]] = getelementptr double, ptr [[OUT]], i64 3
+; CHECK-NEXT:    store <3 x double> [[TMP1]], ptr [[VEC_GEP7]], align 8
+; CHECK-NEXT:    [[VEC_GEP8:%.*]] = getelementptr double, ptr [[OUT]], i64 6
+; CHECK-NEXT:    store <3 x double> [[TMP2]], ptr [[VEC_GEP8]], align 16
+; CHECK-NEXT:    ret void
+;
+entry:
+  br label %loop
+
+loop:
+  %phi = phi <9 x double> [undef, %entry], [%sum, %loop]
+  %ctr = phi i32 [%count, %entry], [%dec, %loop]
+
+  %in2v = load <9 x double>, ptr %in2
+
+  ; Give in2 the shape: 3 x 3
+  %in2t  = call <9 x double> @llvm.matrix.transpose(<9 x double> %in2v, i32 3, i32 3)
+  %in2tt = call <9 x double> @llvm.matrix.transpose(<9 x double> %in2t, i32 3, i32 3)
+
+  %sum = fadd <9 x double> %phi, %in2tt
+
+  %dec = sub i32 %ctr, 1
+  %cmp = icmp eq i32 %dec, 0
+  br i1 %cmp, label %exit, label %loop
+
+exit:
+  store <9 x double> %sum, ptr %out
+  ret void
+}
+
+define void @matrix_phi_poison(ptr %in1, ptr %in2, i32 %count, ptr %out) {
+; CHECK-LABEL: @matrix_phi_poison(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br label [[LOOP:%.*]]
+; CHECK:       loop:
+; CHECK-NEXT:    [[PHI4:%.*]] = phi <3 x double> [ poison, [[ENTRY:%.*]] ], [ [[TMP0:%.*]], [[LOOP]] ]
+; CHECK-NEXT:    [[PHI5:%.*]] = phi <3 x double> [ poison, [[ENTRY]] ], [ [[TMP1:%.*]], [[LOOP]] ]
+; CHECK-NEXT:    [[PHI6:%.*]] = phi <3 x double> [ poison, [[ENTRY]] ], [ [[TMP2:%.*]], [[LOOP]] ]
+; CHECK-NEXT:    [[CTR:%.*]] = phi i32 [ [[COUNT:%.*]], [[ENTRY]] ], [ [[DEC:%.*]], [[LOOP]] ]
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <3 x double>, ptr [[IN2:%.*]], align 128
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN2]], i64 3
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <3 x double>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN2]], i64 6
+; CHECK-NEXT:    [[COL_LOAD3:%.*]] = load <3 x double>, ptr [[VEC_GEP2]], align 16
+; CHECK-NEXT:    [[TMP0]] = fadd <3 x double> [[PHI4]], [[COL_LOAD]]
+; CHECK-NEXT:    [[TMP1]] = fadd <3 x double> [[PHI5]], [[COL_LOAD1]]
+; CHECK-NEXT:    [[TMP2]] = fadd <3 x double> [[PHI6]], [[COL_LOAD3]]
+; CHECK-NEXT:    [[DEC]] = sub i32 [[CTR]], 1
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[DEC]], 0
+; CHECK-NEXT:    br i1 [[CMP]], label [[EXIT:%.*]], label [[LOOP]]
+; CHECK:       exit:
+; CHECK-NEXT:    store <3 x double> [[TMP0]], ptr [[OUT:%.*]], align 128
+; CHECK-NEXT:    [[VEC_GEP7:%.*]] = getelementptr double, ptr [[OUT]], i64 3
+; CHECK-NEXT:    store <3 x double> [[TMP1]], ptr [[VEC_GEP7]], align 8
+; CHECK-NEXT:    [[VEC_GEP8:%.*]] = getelementptr double, ptr [[OUT]], i64 6
+; CHECK-NEXT:    store <3 x double> [[TMP2]], ptr [[VEC_GEP8]], align 16
+; CHECK-NEXT:    ret void
+;
+entry:
+  br label %loop
+
+loop:
+  %phi = phi <9 x double> [poison, %entry], [%sum, %loop]
+  %ctr = phi i32 [%count, %entry], [%dec, %loop]
+
+  %in2v = load <9 x double>, ptr %in2
+
+  ; Give in2 the shape: 3 x 3
+  %in2t  = call <9 x double> @llvm.matrix.transpose(<9 x double> %in2v, i32 3, i32 3)
+  %in2tt = call <9 x double> @llvm.matrix.transpose(<9 x double> %in2t, i32 3, i32 3)
+
+  %sum = fadd <9 x double> %phi, %in2tt
+
+  %dec = sub i32 %ctr, 1
+  %cmp = icmp eq i32 %dec, 0
+  br i1 %cmp, label %exit, label %loop
+
+exit:
+  store <9 x double> %sum, ptr %out
+  ret void
+}
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backwards-unsupported.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backwards-unsupported.ll
index 2af2c979f2065..6ed8e46d62892 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backwards-unsupported.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backwards-unsupported.ll
@@ -28,9 +28,6 @@ define <9 x double> @unsupported_phi(i1 %cond, <9 x double> %A, <9 x double> %B,
 ; CHECK-NEXT:    [[TMP15:%.*]] = insertelement <3 x double> [[TMP13]], double [[TMP14]], i64 1
 ; CHECK-NEXT:    [[TMP16:%.*]] = extractelement <3 x double> [[SPLIT5]], i64 2
 ; CHECK-NEXT:    [[TMP17:%.*]] = insertelement <3 x double> [[TMP15]], double [[TMP16]], i64 2
-; CHECK-NEXT:    [[TMP18:%.*]] = shufflevector <3 x double> [[TMP5]], <3 x double> [[TMP11]], <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
-; CHECK-NEXT:    [[TMP19:%.*]] = shufflevector <3 x double> [[TMP17]], <3 x double> poison, <6 x i32> <i32 0, i32 1, i32 2, i32 poison, i32 poison, i32 poison>
-; CHECK-NEXT:    [[TMP20:%.*]] = shufflevector <6 x double> [[TMP18]], <6 x double> [[TMP19]], <9 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8>
 ; CHECK-NEXT:    br label [[IF_END:%.*]]
 ; CHECK:       if.else:
 ; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <9 x double> [[B:%.*]], <9 x double> poison, <3 x i32> <i32 0, i32 1, i32 2>
@@ -54,183 +51,179 @@ define <9 x double> @unsupported_phi(i1 %cond, <9 x double> %A, <9 x double> %B,
 ; CHECK-NEXT:    [[TMP36:%.*]] = insertelement <3 x double> [[TMP34]], double [[TMP35]], i64 1
 ; CHECK-NEXT:    [[TMP37:%.*]] = extractelement <3 x double> [[SPLIT2]], i64 2
 ; CHECK-NEXT:    [[TMP38:%.*]] = insertelement <3 x double> [[TMP36]], double [[TMP37]], i64 2
-; CHECK-NEXT:    [[TMP39:%.*]] = shufflevector <3 x double> [[TMP26]], <3 x double> [[TMP32]], <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
-; CHECK-NEXT:    [[TMP40:%.*]] = shufflevector <3 x double> [[TMP38]], <3 x double> poison, <6 x i32> <i32 0, i32 1, i32 2, i32 poison, i32 poison, i32 poison>
-; CHECK-NEXT:    [[TMP41:%.*]] = shufflevector <6 x double> [[TMP39]], <6 x double> [[TMP40]], <9 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8>
 ; CHECK-NEXT:    br label [[IF_END]]
 ; CHECK:       if.end:
-; CHECK-NEXT:    [[MERGE:%.*]] = phi <9 x double> [ [[TMP20]], [[IF_THEN]] ], [ [[TMP41]], [[IF_ELSE]] ]
-; CHECK-NEXT:    [[SPLIT6:%.*]] = shufflevector <9 x double> [[C:%.*]], <9 x double> poison, <3 x i32> <i32 0, i32 1, i32 2>
-; CHECK-NEXT:    [[SPLIT7:%.*]] = shufflevector <9 x double> [[C]], <9 x double> poison, <3 x i32> <i32 3, i32 4, i32 5>
-; CHECK-NEXT:    [[SPLIT8:%.*]] = shufflevector <9 x double> [[C]], <9 x double> poison, <3 x i32> <i32 6, i32 7, i32 8>
-; CHECK-NEXT:    [[SPLIT9:%.*]] = shufflevector <9 x double> [[MERGE]], <9 x double> poison, <3 x i32> <i32 0, i32 1, i32 2>
+; CHECK-NEXT:    [[MERGE9:%.*]] = phi <3 x double> [ [[TMP5]], [[IF_THEN]] ], [ [[TMP26]], [[IF_ELSE]] ]
+; CHECK-NEXT:    [[MERGE10:%.*]] = phi <3 x double> [ [[TMP11]], [[IF_THEN]] ], [ [[TMP32]], [[IF_ELSE]] ]
+; CHECK-NEXT:    [[MERGE11:%.*]] = phi <3 x double> [ [[TMP17]], [[IF_THEN]] ], [ [[TMP38]], [[IF_ELSE]] ]
+; CHECK-NEXT:    [[SPLIT9:%.*]] = shufflevector <9 x double> [[MERGE:%.*]], <9 x double> poison, <3 x i32> <i32 0, i32 1, i32 2>
 ; CHECK-NEXT:    [[SPLIT10:%.*]] = shufflevector <9 x double> [[MERGE]], <9 x double> poison, <3 x i32> <i32 3, i32 4, i32 5>
 ; CHECK-NEXT:    [[SPLIT11:%.*]] = shufflevector <9 x double> [[MERGE]], <9 x double> poison, <3 x i32> <i32 6, i32 7, i32 8>
-; CHECK-NEXT:    [[BLOCK:%.*]] = shufflevector <3 x double> [[SPLIT6]], <3 x double> poison, <1 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP42:%.*]] = extractelement <3 x double> [[SPLIT9]], i64 0
+; CHECK-NEXT:    [[BLOCK:%.*]] = shufflevector <3 x double> [[SPLIT9]], <3 x double> poison, <1 x i32> zeroinitializer
+; CHECK-NEXT:    [[T...
[truncated]

Copy link

github-actions bot commented May 27, 2025

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff HEAD~1 HEAD --extensions cpp -- llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
View the diff from clang-format here.
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 00ba8651e..ec9ddb6e2 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -2184,8 +2184,8 @@ public:
           Builder.SetInsertPoint(PHI);
           for (unsigned VI = 0, VE = PhiM.getNumVectors(); VI != VE; ++VI)
             PhiM.setVector(VI, Builder.CreatePHI(PhiM.getVectorTy(),
-                                                PHI->getNumIncomingValues(),
-                                                PHI->getName()));
+                                                 PHI->getNumIncomingValues(),
+                                                 PHI->getName()));
 
           Inst2ColumnMatrix[PHI] = PhiM;
         }

Copy link

⚠️ undef deprecator found issues in your code. ⚠️

You can test this locally with the following command:
git diff -U0 --pickaxe-regex -S '([^a-zA-Z0-9#_-]undef[^a-zA-Z0-9_-]|UndefValue::get)' 'HEAD~1' HEAD llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backwards-unsupported.ll

The following files introduce new uses of undef:

  • llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
  • llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll

Undef is now deprecated and should only be used in the rare cases where no replacement is possible. For example, a load of uninitialized memory yields undef. You should use poison values for placeholders instead.

In tests, avoid using undef and having tests that trigger undefined behavior. If you need an operand with some unimportant value, you can add a new argument to the function and use that instead.

For example, this is considered a bad practice:

define void @fn() {
  ...
  br i1 undef, ...
}

Please use the following instead:

define void @fn(i1 %cond) {
  ...
  br i1 %cond, ...
}

Please refer to the Undefined Behavior Manual for more information.

@jroelofs jroelofs changed the title [Matrix] Propagate shape information through PHI instructions [Matrix] Propagate shape information through PHI insts May 28, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants