@@ -1295,6 +1295,24 @@ class LowerMatrixIntrinsics {
12951295 return commonAlignment (InitialAlign, ElementSizeInBits / 8 );
12961296 }
12971297
1298+ IntegerType *getIndexType (Value *Ptr) const {
1299+ return cast<IntegerType>(DL.getIndexType (Ptr->getType ()));
1300+ }
1301+
1302+ Value *getIndex (Value *Ptr, uint64_t V) const {
1303+ return ConstantInt::get (getIndexType (Ptr), V);
1304+ }
1305+
1306+ Value *castToIndexType (Value *Ptr, Value *V, IRBuilder<> &Builder) const {
1307+ assert (isa<IntegerType>(V->getType ()) &&
1308+ " Attempted to cast non-integral type to integer index" );
1309+ // In case the data layout's index type differs in width from the type of
1310+ // the value we're given, truncate or zero extend to the appropriate width.
1311+ // We zero extend here as indices are unsigned.
1312+ return Builder.CreateZExtOrTrunc (V, getIndexType (Ptr),
1313+ V->getName () + " .cast" );
1314+ }
1315+
12981316 /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between
12991317 /// vectors.
13001318 MatrixTy loadMatrix (Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
@@ -1304,6 +1322,7 @@ class LowerMatrixIntrinsics {
13041322 Type *VecTy = FixedVectorType::get (EltTy, Shape.getStride ());
13051323 Value *EltPtr = Ptr;
13061324 MatrixTy Result;
1325+ Stride = castToIndexType (Ptr, Stride, Builder);
13071326 for (unsigned I = 0 , E = Shape.getNumVectors (); I < E; ++I) {
13081327 Value *GEP = computeVectorAddr (
13091328 EltPtr, Builder.getIntN (Stride->getType ()->getScalarSizeInBits (), I),
@@ -1325,14 +1344,14 @@ class LowerMatrixIntrinsics {
13251344 ShapeInfo ResultShape, Type *EltTy,
13261345 IRBuilder<> &Builder) {
13271346 Value *Offset = Builder.CreateAdd (
1328- Builder.CreateMul (J, Builder. getInt64 ( MatrixShape.getStride ())), I);
1347+ Builder.CreateMul (J, getIndex (MatrixPtr, MatrixShape.getStride ())), I);
13291348
13301349 Value *TileStart = Builder.CreateGEP (EltTy, MatrixPtr, Offset);
13311350 auto *TileTy = FixedVectorType::get (EltTy, ResultShape.NumRows *
13321351 ResultShape.NumColumns );
13331352
13341353 return loadMatrix (TileTy, TileStart, Align,
1335- Builder. getInt64 ( MatrixShape.getStride ()), IsVolatile,
1354+ getIndex (MatrixPtr, MatrixShape.getStride ()), IsVolatile,
13361355 ResultShape, Builder);
13371356 }
13381357
@@ -1363,14 +1382,15 @@ class LowerMatrixIntrinsics {
13631382 MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
13641383 Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
13651384 Value *Offset = Builder.CreateAdd (
1366- Builder.CreateMul (J, Builder. getInt64 ( MatrixShape.getStride ())), I);
1385+ Builder.CreateMul (J, getIndex (MatrixPtr, MatrixShape.getStride ())), I);
13671386
13681387 Value *TileStart = Builder.CreateGEP (EltTy, MatrixPtr, Offset);
13691388 auto *TileTy = FixedVectorType::get (EltTy, StoreVal.getNumRows () *
13701389 StoreVal.getNumColumns ());
13711390
13721391 storeMatrix (TileTy, StoreVal, TileStart, MAlign,
1373- Builder.getInt64 (MatrixShape.getStride ()), IsVolatile, Builder);
1392+ getIndex (MatrixPtr, MatrixShape.getStride ()), IsVolatile,
1393+ Builder);
13741394 }
13751395
13761396 /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
@@ -1380,6 +1400,7 @@ class LowerMatrixIntrinsics {
13801400 IRBuilder<> &Builder) {
13811401 auto *VType = cast<FixedVectorType>(Ty);
13821402 Value *EltPtr = Ptr;
1403+ Stride = castToIndexType (Ptr, Stride, Builder);
13831404 for (auto Vec : enumerate(StoreVal.vectors ())) {
13841405 Value *GEP = computeVectorAddr (
13851406 EltPtr,
@@ -2011,18 +2032,17 @@ class LowerMatrixIntrinsics {
20112032 const unsigned TileM = std::min (M - K, unsigned (TileSize));
20122033 MatrixTy A =
20132034 loadMatrix (APtr, LoadOp0->getAlign (), LoadOp0->isVolatile (),
2014- LShape, Builder. getInt64 ( I), Builder. getInt64 ( K),
2035+ LShape, getIndex (APtr, I), getIndex (APtr, K),
20152036 {TileR, TileM}, EltType, Builder);
20162037 MatrixTy B =
20172038 loadMatrix (BPtr, LoadOp1->getAlign (), LoadOp1->isVolatile (),
2018- RShape, Builder. getInt64 ( K), Builder. getInt64 ( J),
2039+ RShape, getIndex (BPtr, K), getIndex (BPtr, J),
20192040 {TileM, TileC}, EltType, Builder);
20202041 emitMatrixMultiply (Res, A, B, Builder, true , false ,
20212042 getFastMathFlags (MatMul));
20222043 }
20232044 storeMatrix (Res, CPtr, Store->getAlign (), Store->isVolatile (), {R, M},
2024- Builder.getInt64 (I), Builder.getInt64 (J), EltType,
2025- Builder);
2045+ getIndex (CPtr, I), getIndex (CPtr, J), EltType, Builder);
20262046 }
20272047 }
20282048
@@ -2254,15 +2274,14 @@ class LowerMatrixIntrinsics {
22542274 /// Lower load instructions.
// Diff-hunk view of VisitLoad: lines marked '-' are the removed form of the
// return statement, lines marked '+' are its replacement.
// Both forms forward Inst, Ptr, Inst's alignment, Inst's volatility, SI and
// Builder to LowerLoad and return its MatrixTy result. They differ only in how
// the stride operand is built from SI.getStride():
//   - removed: Builder.getInt64(...)  (always a 64-bit constant)
//   - added:   getIndex(Ptr, ...)     (constant typed via the helper, keyed on Ptr)
22552275 MatrixTy VisitLoad (LoadInst *Inst, const ShapeInfo &SI, Value *Ptr,
22562276 IRBuilder<> &Builder) {
2257- return LowerLoad (Inst, Ptr, Inst->getAlign (),
2258- Builder.getInt64 (SI.getStride ()), Inst->isVolatile (), SI,
2259- Builder);
2277+ return LowerLoad (Inst, Ptr, Inst->getAlign (), getIndex (Ptr, SI.getStride ()),
2278+ Inst->isVolatile (), SI, Builder);
22602279 }
22612280
// Diff-hunk view of VisitStore: the '-' line is the removed stride operand,
// the '+' line is its replacement; unmarked lines are unchanged context.
// The function forwards Inst, StoredVal, Ptr, Inst's alignment, the stride,
// Inst's volatility, SI and Builder to LowerStore and returns its result.
// The stride constant built from SI.getStride() changes from
// Builder.getInt64(...) (always 64-bit) to getIndex(Ptr, ...) (typed via the
// helper, keyed on Ptr).
22622281 MatrixTy VisitStore (StoreInst *Inst, const ShapeInfo &SI, Value *StoredVal,
22632282 Value *Ptr, IRBuilder<> &Builder) {
22642283 return LowerStore (Inst, StoredVal, Ptr, Inst->getAlign (),
2265- Builder. getInt64 ( SI.getStride ()), Inst->isVolatile (), SI,
2284+ getIndex (Ptr, SI.getStride ()), Inst->isVolatile (), SI,
22662285 Builder);
22672286 }
22682287
0 commit comments