@@ -1152,8 +1152,78 @@ struct FoldI1Select : public OpRewritePattern<arith::SelectOp> {
   }
 };
 
-// Drop inner most contiguous unit dimensions from transfer_read operand.
-class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
+/// Returns the number of dims that can be folded away from transfer ops. It
+/// returns a failure if it cannot determine the number of dims to fold.
+/// Example 1: it returns "2" if `srcType` is memref<512x16x1x1xf32> and
+/// `vectorType` is vector<16x16x1x1xf32>, because the two innermost dims can
+/// be dropped by memref.subview ops.
+/// Example 2: it returns "1" if `srcType` is the same memref type but with
+/// [8192, 16, 8, 1] strides.
+static FailureOr<size_t>
+getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
+  SmallVector<int64_t> srcStrides;
+  int64_t srcOffset;
+  if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
+    return failure();
+
+  // According to vector.transfer_read/write semantics, the vector can be a
+  // slice. Thus, we have to offset the check index with `rankDiff` in
+  // `srcStrides` and source dim sizes.
+  size_t result = 0;
+  int rankDiff = srcType.getRank() - vectorType.getRank();
+  for (int64_t i = 0, e = vectorType.getRank(); i < e; ++i) {
+    // Check that the inner dim size is 1 for both the memref type and the
+    // vector slice. It can be folded only if they are 1 and the stride is 1.
+    int dim = vectorType.getRank() - i - 1;
+    if (srcStrides[dim + rankDiff] != 1 ||
+        srcType.getDimSize(dim + rankDiff) != 1 ||
+        vectorType.getDimSize(dim) != 1)
+      break;
+    result++;
+  }
+  return result;
+}
+
+/// Returns a MemRef type that drops inner `dimsToDrop` dimensions from
+/// `srcType`. E.g., if `srcType` is memref<512x16x1x1xf32> and `dimsToDrop`
+/// is two, it returns memref<512x16xf32>.
+static MemRefType getMemRefTypeWithDroppingInnerDims(OpBuilder &builder,
+                                                     MemRefType srcType,
+                                                     size_t dimsToDrop) {
+  MemRefLayoutAttrInterface layout = srcType.getLayout();
+  if (isa<AffineMapAttr>(layout) && layout.isIdentity()) {
+    return MemRefType::get(srcType.getShape().drop_back(dimsToDrop),
+                           srcType.getElementType(), nullptr,
+                           srcType.getMemorySpace());
+  }
+  if (auto strided = dyn_cast<StridedLayoutAttr>(layout)) {
+    auto strides = llvm::to_vector(strided.getStrides().drop_back(dimsToDrop));
+    auto updatedLayout = StridedLayoutAttr::get(strided.getContext(),
+                                                strided.getOffset(), strides);
+    return MemRefType::get(srcType.getShape().drop_back(dimsToDrop),
+                           srcType.getElementType(), updatedLayout,
+                           srcType.getMemorySpace());
+  }
+
+  // Generic affine layout case: pin each dropped trailing dim of the layout
+  // map to constant 0.
+  AffineMap map = srcType.getLayout().getAffineMap();
+  int numSymbols = map.getNumSymbols();
+  for (size_t i = 0; i < dimsToDrop; ++i) {
+    int dim = srcType.getRank() - i - 1;
+    map = map.replace(builder.getAffineDimExpr(dim),
+                      builder.getAffineConstantExpr(0), map.getNumDims() - 1,
+                      numSymbols);
+  }
+  return MemRefType::get(srcType.getShape().drop_back(dimsToDrop),
+                         srcType.getElementType(), AffineMapAttr::get(map),
+                         srcType.getMemorySpace());
+}
+
+/// Drop innermost contiguous unit dimensions from transfer_read operand.
+class DropInnerMostUnitDimsTransferRead
+    : public OpRewritePattern<vector::TransferReadOp> {
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
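To make the two new helpers concrete: for Example 2 in the comment above, only the innermost dim is foldable, since the second-to-last unit dim has stride 8 rather than 1. A sketch of the resulting rank-reduced view, with value names and the strided source type assumed for illustration (not taken from the patch):

    // dimsToDrop == 1 for strides [8192, 16, 8, 1]; only the last unit dim
    // has stride 1, so only it is dropped.
    %subview = memref.subview %src[0, 0, 0, 0] [512, 16, 1, 1] [1, 1, 1, 1]
        : memref<512x16x1x1xf32, strided<[8192, 16, 8, 1]>>
          to memref<512x16x1xf32, strided<[8192, 16, 8]>>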
@@ -1177,65 +1247,22 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
     if (targetType.getRank() <= 1)
       return failure();
 
-    SmallVector<int64_t> srcStrides;
-    int64_t srcOffset;
-    if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
-      return failure();
-
-    // According to vector.transfer_read semantics, the result can be a slice.
-    // It pads the indices with `1` starting from beginning. Thus, we have to
-    // offset the check index with `rankDiff` in `srcStrides` and source dim
-    // sizes.
-    size_t dimsToDrop = 0;
-    int rankDiff = srcType.getRank() - targetType.getRank();
-    for (int64_t i = 0, e = targetType.getRank(); i < e; ++i) {
-      // Check that the inner dim size is 1 for both memref/tensor type and
-      // vector slice. It can be folded only if they are 1 and the stride is 1.
-      int dim = targetType.getRank() - i - 1;
-      if (srcStrides[dim + rankDiff] == 1 &&
-          srcType.getDimSize(dim + rankDiff) == 1 &&
-          targetType.getDimSize(dim) == 1) {
-        dimsToDrop++;
-      } else {
-        break;
-      }
-    }
+    FailureOr<size_t> maybeDimsToDrop =
+        getTransferFoldableInnerUnitDims(srcType, targetType);
+    if (failed(maybeDimsToDrop))
+      return failure();
+
+    size_t dimsToDrop = maybeDimsToDrop.value();
     if (dimsToDrop == 0)
       return failure();
 
     auto resultTargetVecType =
         VectorType::get(targetType.getShape().drop_back(dimsToDrop),
                         targetType.getElementType());
 
-    MemRefType resultMemrefType;
-    MemRefLayoutAttrInterface layout = srcType.getLayout();
-    if (isa<AffineMapAttr>(layout) && layout.isIdentity()) {
-      resultMemrefType = MemRefType::get(
-          srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
-          nullptr, srcType.getMemorySpace());
-    } else {
-      MemRefLayoutAttrInterface updatedLayout;
-      if (auto strided = dyn_cast<StridedLayoutAttr>(layout)) {
-        auto strides =
-            llvm::to_vector(strided.getStrides().drop_back(dimsToDrop));
-        updatedLayout = StridedLayoutAttr::get(strided.getContext(),
-                                               strided.getOffset(), strides);
-      } else {
-        AffineMap map = srcType.getLayout().getAffineMap();
-        int numSymbols = map.getNumSymbols();
-        for (size_t i = 0; i < dimsToDrop; ++i) {
-          int dim = srcType.getRank() - i - 1;
-          map = map.replace(rewriter.getAffineDimExpr(dim),
-                            rewriter.getAffineConstantExpr(0),
-                            map.getNumDims() - 1, numSymbols);
-        }
-      }
-      resultMemrefType = MemRefType::get(
-          srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
-          updatedLayout, srcType.getMemorySpace());
-    }
-
     auto loc = readOp.getLoc();
+    MemRefType resultMemrefType =
+        getMemRefTypeWithDroppingInnerDims(rewriter, srcType, dimsToDrop);
     SmallVector<int64_t> offsets(srcType.getRank(), 0);
     SmallVector<int64_t> strides(srcType.getRank(), 1);
 
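For symmetry with the transfer_write example documented further below, here is a sketch of the rewrite the read pattern performs, with value names and types assumed for illustration (not taken from the patch):

    %v = vector.transfer_read %src[%c0, %arg2, %c0, %c0], %pad
        {in_bounds = [true, true, true]}
        : memref<512x16x1x1xf32>, vector<16x1x1xf32>

would become

    %subview = memref.subview %src[0, 0, 0, 0] [512, 16, 1, 1] [1, 1, 1, 1]
        : memref<512x16x1x1xf32> to memref<512x16xf32>
    %0 = vector.transfer_read %subview[%c0, %arg2], %pad {in_bounds = [true]}
        : memref<512x16xf32>, vector<16xf32>
    %v = vector.shape_cast %0 : vector<16xf32> to vector<16x1x1xf32>

Both trailing unit dims fold here because the memref has an identity layout, so their strides are 1.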
@@ -1261,6 +1288,88 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
   }
 };
 
+/// Drop innermost contiguous unit dimensions from transfer_write operand.
+/// E.g.,
+///    vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0]
+///      {in_bounds = [true, true, true, true, true]}
+///      : vector<1x16x16x1x1xf32>, memref<1x512x16x1x1xf32>
+///
+/// will be replaced with
+///
+///    %subview = memref.subview %arg0
+///      [0, 0, 0, 0, 0] [1, 512, 16, 1, 1] [1, 1, 1, 1, 1]
+///      : memref<1x512x16x1x1xf32> to memref<1x512x16xf32>
+///    %0 = vector.shape_cast %arg1 : vector<1x16x16x1x1xf32>
+///      to vector<1x16x16xf32>
+///    vector.transfer_write %0, %subview[%c0, %arg2, %c0]
+///      {in_bounds = [true, true, true]}
+///      : vector<1x16x16xf32>, memref<1x512x16xf32>
+class DropInnerMostUnitDimsTransferWrite
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
+                                PatternRewriter &rewriter) const override {
+    // TODO: support 0-d corner case.
+    if (writeOp.getTransferRank() == 0)
+      return failure();
+
+    // TODO: support mask.
+    if (writeOp.getMask())
+      return failure();
+
+    auto srcType = dyn_cast<MemRefType>(writeOp.getSource().getType());
+    if (!srcType || !srcType.hasStaticShape())
+      return failure();
+
+    if (!writeOp.getPermutationMap().isMinorIdentity())
+      return failure();
+
+    auto targetType = writeOp.getVectorType();
+    if (targetType.getRank() <= 1)
+      return failure();
+
+    FailureOr<size_t> maybeDimsToDrop =
+        getTransferFoldableInnerUnitDims(srcType, targetType);
+    if (failed(maybeDimsToDrop))
+      return failure();
+
+    size_t dimsToDrop = maybeDimsToDrop.value();
+    if (dimsToDrop == 0)
+      return failure();
+
+    auto resultTargetVecType =
+        VectorType::get(targetType.getShape().drop_back(dimsToDrop),
+                        targetType.getElementType());
+
+    MemRefType resultMemrefType =
+        getMemRefTypeWithDroppingInnerDims(rewriter, srcType, dimsToDrop);
+    SmallVector<int64_t> offsets(srcType.getRank(), 0);
+    SmallVector<int64_t> strides(srcType.getRank(), 1);
+    ArrayAttr inBoundsAttr =
+        writeOp.getInBounds()
+            ? rewriter.getArrayAttr(
+                  writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
+            : ArrayAttr();
+
+    Location loc = writeOp.getLoc();
+    Value rankedReducedView = rewriter.create<memref::SubViewOp>(
+        loc, resultMemrefType, writeOp.getSource(), offsets,
+        srcType.getShape(), strides);
+    auto permMap = getTransferMinorIdentityMap(
+        cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
+
+    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
+        loc, resultTargetVecType, writeOp.getVector());
+    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+        writeOp, shapeCast, rankedReducedView,
+        writeOp.getIndices().drop_back(dimsToDrop),
+        AffineMapAttr::get(permMap),
+        // TODO: support mask.
+        /*mask=*/Value(), inBoundsAttr);
+    return success();
+  }
+};
+
 /// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul
 /// semantics to a contraction suitable for MMT (matrix matrix multiplication
 /// with the RHS transposed) lowering.
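Note the mask guard in DropInnerMostUnitDimsTransferWrite above: a masked write such as the following assumed example (names and types illustrative only) is rejected up front and left untouched until the mask TODO is addressed:

    %mask = vector.create_mask %c8, %c1, %c1 : vector<16x1x1xi1>
    vector.transfer_write %v, %dest[%c0, %c0, %c0], %mask
        {in_bounds = [true, true, true]}
        : vector<16x1x1xf32>, memref<512x1x1xf32>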
@@ -1696,7 +1805,9 @@ void mlir::vector::populateVectorReductionToContractPatterns(
 void mlir::vector::
     populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
         RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<DropInnerMostUnitDims>(patterns.getContext(), benefit);
+  patterns.add<DropInnerMostUnitDimsTransferRead,
+               DropInnerMostUnitDimsTransferWrite>(patterns.getContext(),
+                                                   benefit);
 }
 
 void mlir::vector::populateSinkVectorBroadcastPatterns(
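Putting it together, once both patterns are populated (e.g. into a greedy rewrite driver), the strided-layout case from Example 2 folds exactly one unit dim on the write side. A sketch with assumed names and types, not taken from the patch:

    vector.transfer_write %v, %dest[%c0, %c0, %c0, %c0]
        {in_bounds = [true, true, true, true]}
        : vector<16x16x1x1xf32>,
          memref<512x16x1x1xf32, strided<[8192, 16, 8, 1]>>

would become

    %subview = memref.subview %dest[0, 0, 0, 0] [512, 16, 1, 1] [1, 1, 1, 1]
        : memref<512x16x1x1xf32, strided<[8192, 16, 8, 1]>>
          to memref<512x16x1xf32, strided<[8192, 16, 8]>>
    %0 = vector.shape_cast %v : vector<16x16x1x1xf32> to vector<16x16x1xf32>
    vector.transfer_write %0, %subview[%c0, %c0, %c0]
        {in_bounds = [true, true, true]}
        : vector<16x16x1xf32>, memref<512x16x1xf32, strided<[8192, 16, 8]>>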