-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[Matrix] Use FixedVectorType everywhere in LowerMatrixIntrinsics. NFC #142316
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…pass. NFC These matrix ops do not support scalable vectors, so we should be really explicit about that and avoid casting mistakes.
@llvm/pr-subscribers-llvm-transforms Author: Jon Roelofs (jroelofs) ChangesThese matrix ops do not support scalable vectors, so we should be really explicit about that and avoid casting mistakes. Full diff: https://github.com/llvm/llvm-project/pull/142316.diff 1 Files Affected:
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 756a72e6d97bc..787e107464c0a 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -383,25 +383,25 @@ class LowerMatrixIntrinsics {
return Vectors.size();
else {
assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
- return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
+ return getVectorTy()->getNumElements();
}
}
unsigned getNumRows() const {
if (isColumnMajor()) {
assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
- return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
+ return getVectorTy()->getNumElements();
} else
return Vectors.size();
}
void addVector(Value *V) { Vectors.push_back(V); }
- VectorType *getColumnTy() {
+ FixedVectorType *getColumnTy() {
assert(isColumnMajor() && "only supported for column-major matrixes");
return getVectorTy();
}
- VectorType *getVectorTy() const {
- return cast<VectorType>(Vectors[0]->getType());
+ FixedVectorType *getVectorTy() const {
+ return cast<FixedVectorType>(Vectors[0]->getType());
}
iterator_range<SmallVector<Value *, 8>::iterator> columns() {
@@ -514,7 +514,7 @@ class LowerMatrixIntrinsics {
: Func(F), DL(F.getDataLayout()), TTI(TTI), AM(AM) {}
unsigned getNumOps(Type *VT) {
- assert(isa<VectorType>(VT) && "Expected vector type");
+ assert(isa<FixedVectorType>(VT) && "Expected vector type");
return getNumOps(VT->getScalarType(),
cast<FixedVectorType>(VT)->getNumElements());
}
@@ -540,10 +540,8 @@ class LowerMatrixIntrinsics {
/// into vectors.
MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
IRBuilder<> &Builder) {
- VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
- assert(VType && "MatrixVal must be a vector type");
- assert(cast<FixedVectorType>(VType)->getNumElements() ==
- SI.NumRows * SI.NumColumns &&
+ FixedVectorType *VType = cast<FixedVectorType>(MatrixVal->getType());
+ assert(VType->getNumElements() == SI.NumRows * SI.NumColumns &&
"The vector size must match the number of matrix elements");
// Check if we lowered MatrixVal using shape information. In that case,
@@ -563,8 +561,7 @@ class LowerMatrixIntrinsics {
// Otherwise split MatrixVal.
SmallVector<Value *, 16> SplitVecs;
- for (unsigned MaskStart = 0;
- MaskStart < cast<FixedVectorType>(VType)->getNumElements();
+ for (unsigned MaskStart = 0; MaskStart < VType->getNumElements();
MaskStart += SI.getStride()) {
Value *V = Builder.CreateShuffleVector(
MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0),
@@ -1157,7 +1154,7 @@ class LowerMatrixIntrinsics {
/// vectors.
MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
- auto *VType = cast<VectorType>(Ty);
+ auto *VType = cast<FixedVectorType>(Ty);
Type *EltTy = VType->getElementType();
Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride());
Value *EltPtr = Ptr;
@@ -1239,7 +1236,7 @@ class LowerMatrixIntrinsics {
MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
MaybeAlign MAlign, Value *Stride, bool IsVolatile,
IRBuilder<> &Builder) {
- auto VType = cast<VectorType>(Ty);
+ auto *VType = cast<FixedVectorType>(Ty);
Value *EltPtr = Ptr;
for (auto Vec : enumerate(StoreVal.vectors())) {
Value *GEP = computeVectorAddr(
@@ -1377,7 +1374,7 @@ class LowerMatrixIntrinsics {
Value *LHS = MatMul->getArgOperand(0);
Value *RHS = MatMul->getArgOperand(1);
- Type *ElementType = cast<VectorType>(LHS->getType())->getElementType();
+ Type *ElementType = cast<FixedVectorType>(LHS->getType())->getElementType();
bool IsIntVec = ElementType->isIntegerTy();
// Floating point reductions require reassocation.
@@ -1475,7 +1472,7 @@ class LowerMatrixIntrinsics {
int MulOpCode = IsIntVec ? Instruction::Mul : Instruction::FMul;
InstructionCost ReductionCost =
TTI.getArithmeticReductionCost(
- AddOpCode, cast<VectorType>(LHS->getType()),
+ AddOpCode, cast<FixedVectorType>(LHS->getType()),
IsIntVec ? std::nullopt : std::optional(FMF)) +
TTI.getArithmeticInstrCost(MulOpCode, LHS->getType());
InstructionCost SequentialAddCost =
@@ -1535,8 +1532,8 @@ class LowerMatrixIntrinsics {
Result = Builder.CreateAddReduce(Mul);
else {
Result = Builder.CreateFAddReduce(
- ConstantFP::get(cast<VectorType>(LHS->getType())->getElementType(),
- 0.0),
+ ConstantFP::get(
+ cast<FixedVectorType>(LHS->getType())->getElementType(), 0.0),
Mul);
cast<Instruction>(Result)->setFastMathFlags(FMF);
}
@@ -1735,7 +1732,7 @@ class LowerMatrixIntrinsics {
const unsigned R = LShape.NumRows;
const unsigned C = RShape.NumColumns;
const unsigned M = LShape.NumColumns;
- auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
+ auto *EltType = cast<FixedVectorType>(MatMul->getType())->getElementType();
const unsigned VF = std::max<unsigned>(
TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
@@ -1771,7 +1768,7 @@ class LowerMatrixIntrinsics {
void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape,
Value *RPtr, ShapeInfo RShape, StoreInst *Store) {
- auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
+ auto *EltType = cast<FixedVectorType>(MatMul->getType())->getElementType();
// Create the main tiling loop nest.
TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize);
@@ -1842,7 +1839,7 @@ class LowerMatrixIntrinsics {
const unsigned R = LShape.NumRows;
const unsigned C = RShape.NumColumns;
const unsigned M = LShape.NumColumns;
- auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
+ auto *EltType = cast<FixedVectorType>(MatMul->getType())->getElementType();
Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);
Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
@@ -1914,7 +1911,8 @@ class LowerMatrixIntrinsics {
? match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))
: match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) {
IRBuilder<> Builder(MatMul);
- auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
+ auto *EltType =
+ cast<FixedVectorType>(MatMul->getType())->getElementType();
ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
const unsigned R = LShape.NumRows;
@@ -2045,7 +2043,7 @@ class LowerMatrixIntrinsics {
/// Lowers llvm.matrix.multiply.
void LowerMultiply(CallInst *MatMul) {
IRBuilder<> Builder(MatMul);
- auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
+ auto *EltType = cast<FixedVectorType>(MatMul->getType())->getElementType();
ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
@@ -2073,7 +2071,7 @@ class LowerMatrixIntrinsics {
MatrixTy Result;
IRBuilder<> Builder(Inst);
Value *InputVal = Inst->getArgOperand(0);
- VectorType *VectorTy = cast<VectorType>(InputVal->getType());
+ FixedVectorType *VectorTy = cast<FixedVectorType>(InputVal->getType());
ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/65/builds/17540 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/190/builds/20967 Here is the relevant piece of the build log for the reference
|
…llvm#142316) These matrix ops do not support scalable vectors, so we should be really explicit about that and avoid casting mistakes.
These matrix ops do not support scalable vectors, so we should be really explicit about that and avoid casting mistakes.