Skip to content

[Matrix] Use FixedVectorType everywhere in LowerMatrixIntrinsics. NFC #142316

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 2, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 22 additions & 24 deletions llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -383,25 +383,25 @@ class LowerMatrixIntrinsics {
return Vectors.size();
else {
assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
return getVectorTy()->getNumElements();
}
}
unsigned getNumRows() const {
if (isColumnMajor()) {
assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
return getVectorTy()->getNumElements();
} else
return Vectors.size();
}

void addVector(Value *V) { Vectors.push_back(V); }
VectorType *getColumnTy() {
FixedVectorType *getColumnTy() {
assert(isColumnMajor() && "only supported for column-major matrixes");
return getVectorTy();
}

VectorType *getVectorTy() const {
return cast<VectorType>(Vectors[0]->getType());
FixedVectorType *getVectorTy() const {
return cast<FixedVectorType>(Vectors[0]->getType());
}

iterator_range<SmallVector<Value *, 8>::iterator> columns() {
Expand Down Expand Up @@ -514,7 +514,7 @@ class LowerMatrixIntrinsics {
: Func(F), DL(F.getDataLayout()), TTI(TTI), AM(AM) {}

unsigned getNumOps(Type *VT) {
assert(isa<VectorType>(VT) && "Expected vector type");
assert(isa<FixedVectorType>(VT) && "Expected vector type");
return getNumOps(VT->getScalarType(),
cast<FixedVectorType>(VT)->getNumElements());
}
Expand All @@ -540,10 +540,8 @@ class LowerMatrixIntrinsics {
/// into vectors.
MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
IRBuilder<> &Builder) {
VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
assert(VType && "MatrixVal must be a vector type");
assert(cast<FixedVectorType>(VType)->getNumElements() ==
SI.NumRows * SI.NumColumns &&
FixedVectorType *VType = cast<FixedVectorType>(MatrixVal->getType());
assert(VType->getNumElements() == SI.NumRows * SI.NumColumns &&
"The vector size must match the number of matrix elements");

// Check if we lowered MatrixVal using shape information. In that case,
Expand All @@ -563,8 +561,7 @@ class LowerMatrixIntrinsics {

// Otherwise split MatrixVal.
SmallVector<Value *, 16> SplitVecs;
for (unsigned MaskStart = 0;
MaskStart < cast<FixedVectorType>(VType)->getNumElements();
for (unsigned MaskStart = 0; MaskStart < VType->getNumElements();
MaskStart += SI.getStride()) {
Value *V = Builder.CreateShuffleVector(
MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0),
Expand Down Expand Up @@ -1157,7 +1154,7 @@ class LowerMatrixIntrinsics {
/// vectors.
MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
auto *VType = cast<VectorType>(Ty);
auto *VType = cast<FixedVectorType>(Ty);
Type *EltTy = VType->getElementType();
Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride());
Value *EltPtr = Ptr;
Expand Down Expand Up @@ -1239,7 +1236,7 @@ class LowerMatrixIntrinsics {
MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
MaybeAlign MAlign, Value *Stride, bool IsVolatile,
IRBuilder<> &Builder) {
auto VType = cast<VectorType>(Ty);
auto *VType = cast<FixedVectorType>(Ty);
Value *EltPtr = Ptr;
for (auto Vec : enumerate(StoreVal.vectors())) {
Value *GEP = computeVectorAddr(
Expand Down Expand Up @@ -1377,7 +1374,7 @@ class LowerMatrixIntrinsics {
Value *LHS = MatMul->getArgOperand(0);
Value *RHS = MatMul->getArgOperand(1);

Type *ElementType = cast<VectorType>(LHS->getType())->getElementType();
Type *ElementType = cast<FixedVectorType>(LHS->getType())->getElementType();
bool IsIntVec = ElementType->isIntegerTy();

// Floating point reductions require reassocation.
Expand Down Expand Up @@ -1475,7 +1472,7 @@ class LowerMatrixIntrinsics {
int MulOpCode = IsIntVec ? Instruction::Mul : Instruction::FMul;
InstructionCost ReductionCost =
TTI.getArithmeticReductionCost(
AddOpCode, cast<VectorType>(LHS->getType()),
AddOpCode, cast<FixedVectorType>(LHS->getType()),
IsIntVec ? std::nullopt : std::optional(FMF)) +
TTI.getArithmeticInstrCost(MulOpCode, LHS->getType());
InstructionCost SequentialAddCost =
Expand Down Expand Up @@ -1535,8 +1532,8 @@ class LowerMatrixIntrinsics {
Result = Builder.CreateAddReduce(Mul);
else {
Result = Builder.CreateFAddReduce(
ConstantFP::get(cast<VectorType>(LHS->getType())->getElementType(),
0.0),
ConstantFP::get(
cast<FixedVectorType>(LHS->getType())->getElementType(), 0.0),
Mul);
cast<Instruction>(Result)->setFastMathFlags(FMF);
}
Expand Down Expand Up @@ -1735,7 +1732,7 @@ class LowerMatrixIntrinsics {
const unsigned R = LShape.NumRows;
const unsigned C = RShape.NumColumns;
const unsigned M = LShape.NumColumns;
auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
auto *EltType = cast<FixedVectorType>(MatMul->getType())->getElementType();

const unsigned VF = std::max<unsigned>(
TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
Expand Down Expand Up @@ -1771,7 +1768,7 @@ class LowerMatrixIntrinsics {

void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape,
Value *RPtr, ShapeInfo RShape, StoreInst *Store) {
auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
auto *EltType = cast<FixedVectorType>(MatMul->getType())->getElementType();

// Create the main tiling loop nest.
TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize);
Expand Down Expand Up @@ -1842,7 +1839,7 @@ class LowerMatrixIntrinsics {
const unsigned R = LShape.NumRows;
const unsigned C = RShape.NumColumns;
const unsigned M = LShape.NumColumns;
auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
auto *EltType = cast<FixedVectorType>(MatMul->getType())->getElementType();

Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);
Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
Expand Down Expand Up @@ -1914,7 +1911,8 @@ class LowerMatrixIntrinsics {
? match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))
: match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) {
IRBuilder<> Builder(MatMul);
auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
auto *EltType =
cast<FixedVectorType>(MatMul->getType())->getElementType();
ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
const unsigned R = LShape.NumRows;
Expand Down Expand Up @@ -2045,7 +2043,7 @@ class LowerMatrixIntrinsics {
/// Lowers llvm.matrix.multiply.
void LowerMultiply(CallInst *MatMul) {
IRBuilder<> Builder(MatMul);
auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
auto *EltType = cast<FixedVectorType>(MatMul->getType())->getElementType();
ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

Expand Down Expand Up @@ -2073,7 +2071,7 @@ class LowerMatrixIntrinsics {
MatrixTy Result;
IRBuilder<> Builder(Inst);
Value *InputVal = Inst->getArgOperand(0);
VectorType *VectorTy = cast<VectorType>(InputVal->getType());
FixedVectorType *VectorTy = cast<FixedVectorType>(InputVal->getType());
ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);

Expand Down
Loading